Skip to content

Commit

Permalink
Merge pull request #210 from google-research/rajat_dev
Browse files: browse the repository at this point in the history
updating benchmarks to point to v2.0
  • Loading branch information
rajatsen91 authored Jan 8, 2025
2 parents ff2a350 + 8fe4a0a commit e628f0f
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 14 deletions.
34 changes: 26 additions & 8 deletions experiments/extended_benchmarks/run_timesfm.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,10 @@
"hospital",
]

context_dict = {

context_dict_v2 = {}

context_dict_v1 = {
"cif_2016": 32,
"tourism_yearly": 64,
"covid_deaths": 64,
Expand All @@ -71,7 +74,7 @@
"m4_yearly": 64,
}

_MODEL_PATH = flags.DEFINE_string("model_path", "google/timesfm-1.0-200m",
_MODEL_PATH = flags.DEFINE_string("model_path", "google/timesfm-2.0-500m-jax",
"Path to model")
_BATCH_SIZE = flags.DEFINE_integer("batch_size", 64, "Batch size")
_HORIZON = flags.DEFINE_integer("horizon", 128, "Horizon")
Expand All @@ -84,14 +87,27 @@

def main():
results_list = []
model_path = _MODEL_PATH.value
num_layers = 20
max_context_len = 512
use_positional_embedding = True
context_dict = context_dict_v1
if "2.0" in model_path:
num_layers = 50
use_positional_embedding = False
max_context_len = 2048
context_dict = context_dict_v2

tfm = timesfm.TimesFm(
hparams=timesfm.TimesFmHparams(
backend=_BACKEND.value,
per_core_batch_size=_BATCH_SIZE.value,
horizon_len=_HORIZON.value,
backend="gpu",
per_core_batch_size=32,
horizon_len=128,
num_layers=num_layers,
context_len=max_context_len,
use_positional_embedding=use_positional_embedding,
),
checkpoint=timesfm.TimesFmCheckpoint(
huggingface_repo_id=_MODEL_PATH.value),
checkpoint=timesfm.TimesFmCheckpoint(huggingface_repo_id=model_path),
)
run_id = np.random.randint(100000)
model_name = "timesfm"
Expand All @@ -102,7 +118,8 @@ def main():
if dataset in context_dict:
context_len = context_dict[dataset]
else:
context_len = 512
context_len = max_context_len

train_df = exp.train_df
freq = exp.freq
init_time = time.time()
Expand All @@ -113,6 +130,7 @@ def main():
model_name=model_name,
forecast_context_len=context_len,
num_jobs=_NUM_JOBS.value,
normalize=True,
)
total_time = time.time() - init_time
time_df = pd.DataFrame({"time": [total_time], "model": model_name})
Expand Down
15 changes: 9 additions & 6 deletions experiments/long_horizon_benchmarks/run_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@
"Batch size for the randomly sampled batch")
_DATASET = flags.DEFINE_string("dataset", "etth1", "The name of the dataset.")

_MODEL_PATH = flags.DEFINE_string("model_path", "./timesfm_q10_20240501",
"The name of the dataset.")
_MODEL_PATH = flags.DEFINE_string("model_path", "google/timesfm-2.0-500m-jax",
"The name of the model.")
_DATETIME_COL = flags.DEFINE_string("datetime_col", "date",
"Column having datetime.")
_NUM_COV_COLS = flags.DEFINE_list("num_cov_cols", None,
Expand All @@ -43,7 +43,7 @@
_TS_COLS = flags.DEFINE_list("ts_cols", None, "Columns of time-series features")
_NORMALIZE = flags.DEFINE_bool("normalize", True,
"normalize data for eval or not")
_CONTEXT_LEN = flags.DEFINE_integer("context_len", 512,
_CONTEXT_LEN = flags.DEFINE_integer("context_len", 2048,
"Length of the context window")
_PRED_LEN = flags.DEFINE_integer("pred_len", 96, "prediction length.")
_BACKEND = flags.DEFINE_string("backend", "gpu", "backend to use")
Expand Down Expand Up @@ -177,9 +177,12 @@ def eval():
else:
model = timesfm.TimesFm(
hparams=timesfm.TimesFmHparams(
backend=_BACKEND.value,
per_core_batch_size=_BATCH_SIZE.value,
horizon_len=_PRED_LEN.value,
backend="gpu",
per_core_batch_size=32,
horizon_len=128,
num_layers=50,
context_len=_CONTEXT_LEN.value,
use_positional_embedding=False,
),
checkpoint=timesfm.TimesFmCheckpoint(huggingface_repo_id=model_path),
)
Expand Down
3 changes: 3 additions & 0 deletions src/timesfm/timesfm_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -636,6 +636,7 @@ def forecast_on_df(
model_name: str = "timesfm",
window_size: int | None = None,
num_jobs: int = 1,
normalize: bool = False,
verbose: bool = True,
) -> pd.DataFrame:
"""Forecasts on a list of time series.
Expand All @@ -654,6 +655,7 @@ def forecast_on_df(
window_size: window size of trend + residual decomposition. If None then
we do not do decomposition.
num_jobs: number of parallel processes to use for dataframe processing.
normalize: whether to normalize the context before forecasting.
verbose: output model states in terminal.
Returns:
Expand Down Expand Up @@ -698,6 +700,7 @@ def forecast_on_df(
freq_inps = [freq_map(freq)] * len(new_inputs)
_, full_forecast = self.forecast(new_inputs,
freq=freq_inps,
normalize=normalize,
window_size=window_size)
if verbose:
print("Finished forecasting.")
Expand Down

0 comments on commit e628f0f

Please sign in to comment.