From 8fe4a0a511013142fda7cae23b3c013c8403c6c9 Mon Sep 17 00:00:00 2001
From: Rajat Sen
Date: Wed, 8 Jan 2025 18:51:47 +0000
Subject: [PATCH] updating benchmarks to point to v2.0

---
 .../extended_benchmarks/run_timesfm.py       | 34 ++++++++++++++-----
 .../long_horizon_benchmarks/run_eval.py      | 15 ++++----
 src/timesfm/timesfm_base.py                  |  3 ++
 3 files changed, 38 insertions(+), 14 deletions(-)

diff --git a/experiments/extended_benchmarks/run_timesfm.py b/experiments/extended_benchmarks/run_timesfm.py
index 589b63d..e8878c8 100644
--- a/experiments/extended_benchmarks/run_timesfm.py
+++ b/experiments/extended_benchmarks/run_timesfm.py
@@ -54,7 +54,10 @@
     "hospital",
 ]
 
-context_dict = {
+
+context_dict_v2 = {}
+
+context_dict_v1 = {
     "cif_2016": 32,
     "tourism_yearly": 64,
     "covid_deaths": 64,
@@ -71,7 +74,7 @@
     "m4_yearly": 64,
 }
 
-_MODEL_PATH = flags.DEFINE_string("model_path", "google/timesfm-1.0-200m",
+_MODEL_PATH = flags.DEFINE_string("model_path", "google/timesfm-2.0-500m-jax",
                                   "Path to model")
 _BATCH_SIZE = flags.DEFINE_integer("batch_size", 64, "Batch size")
 _HORIZON = flags.DEFINE_integer("horizon", 128, "Horizon")
@@ -84,14 +87,27 @@ def main():
   results_list = []
+  model_path = _MODEL_PATH.value
+  num_layers = 20
+  max_context_len = 512
+  use_positional_embedding = True
+  context_dict = context_dict_v1
+  if "2.0" in model_path:
+    num_layers = 50
+    use_positional_embedding = False
+    max_context_len = 2048
+    context_dict = context_dict_v2
+
   tfm = timesfm.TimesFm(
       hparams=timesfm.TimesFmHparams(
-          backend=_BACKEND.value,
-          per_core_batch_size=_BATCH_SIZE.value,
-          horizon_len=_HORIZON.value,
+          backend="gpu",
+          per_core_batch_size=32,
+          horizon_len=128,
+          num_layers=num_layers,
+          context_len=max_context_len,
+          use_positional_embedding=use_positional_embedding,
       ),
-      checkpoint=timesfm.TimesFmCheckpoint(
-          huggingface_repo_id=_MODEL_PATH.value),
+      checkpoint=timesfm.TimesFmCheckpoint(huggingface_repo_id=model_path),
   )
   run_id = np.random.randint(100000)
   model_name = "timesfm"
@@ -102,7 +118,8 @@ def main():
     if dataset in context_dict:
       context_len = context_dict[dataset]
     else:
-      context_len = 512
+      context_len = max_context_len
+
     train_df = exp.train_df
     freq = exp.freq
     init_time = time.time()
@@ -113,6 +130,7 @@
         model_name=model_name,
         forecast_context_len=context_len,
         num_jobs=_NUM_JOBS.value,
+        normalize=True,
     )
     total_time = time.time() - init_time
     time_df = pd.DataFrame({"time": [total_time], "model": model_name})
diff --git a/experiments/long_horizon_benchmarks/run_eval.py b/experiments/long_horizon_benchmarks/run_eval.py
index b715257..24a0368 100644
--- a/experiments/long_horizon_benchmarks/run_eval.py
+++ b/experiments/long_horizon_benchmarks/run_eval.py
@@ -32,8 +32,8 @@
                                    "Batch size for the randomly sampled batch")
 _DATASET = flags.DEFINE_string("dataset", "etth1", "The name of the dataset.")
-_MODEL_PATH = flags.DEFINE_string("model_path", "./timesfm_q10_20240501",
-                                  "The name of the dataset.")
+_MODEL_PATH = flags.DEFINE_string("model_path", "google/timesfm-2.0-500m-jax",
+                                  "The name of the model.")
 _DATETIME_COL = flags.DEFINE_string("datetime_col", "date",
                                     "Column having datetime.")
 _NUM_COV_COLS = flags.DEFINE_list("num_cov_cols", None,
                                   "Column names of numerical variables.")
@@ -43,7 +43,7 @@
 _TS_COLS = flags.DEFINE_list("ts_cols", None, "Columns of time-series features")
 _NORMALIZE = flags.DEFINE_bool("normalize", True,
                                "normalize data for eval or not")
-_CONTEXT_LEN = flags.DEFINE_integer("context_len", 512,
+_CONTEXT_LEN = flags.DEFINE_integer("context_len", 2048,
                                     "Length of the context window")
 _PRED_LEN = flags.DEFINE_integer("pred_len", 96, "prediction length.")
 _BACKEND = flags.DEFINE_string("backend", "gpu", "backend to use")
@@ -177,9 +177,12 @@ def eval():
   else:
     model = timesfm.TimesFm(
         hparams=timesfm.TimesFmHparams(
-            backend=_BACKEND.value,
-            per_core_batch_size=_BATCH_SIZE.value,
-            horizon_len=_PRED_LEN.value,
+            backend="gpu",
+            per_core_batch_size=32,
+            horizon_len=128,
+            num_layers=50,
+            context_len=_CONTEXT_LEN.value,
+            use_positional_embedding=False,
         ),
         checkpoint=timesfm.TimesFmCheckpoint(huggingface_repo_id=model_path),
     )
diff --git a/src/timesfm/timesfm_base.py b/src/timesfm/timesfm_base.py
index 3e05dd4..9e2c0ec 100644
--- a/src/timesfm/timesfm_base.py
+++ b/src/timesfm/timesfm_base.py
@@ -636,6 +636,7 @@ def forecast_on_df(
       model_name: str = "timesfm",
       window_size: int | None = None,
       num_jobs: int = 1,
+      normalize: bool = False,
       verbose: bool = True,
   ) -> pd.DataFrame:
     """Forecasts on a list of time series.
@@ -654,6 +655,7 @@
       window_size: window size of trend + residual decomposition. If None then
         we do not do decomposition.
      num_jobs: number of parallel processes to use for dataframe processing.
+      normalize: normalize context before forecasting or not.
       verbose: output model states in terminal.
 
     Returns:
@@ -698,6 +700,7 @@
     freq_inps = [freq_map(freq)] * len(new_inputs)
     _, full_forecast = self.forecast(new_inputs,
                                      freq=freq_inps,
+                                     normalize=normalize,
                                      window_size=window_size)
     if verbose:
       print("Finished forecasting.")
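
Usage note (reviewer sketch, not part of the patch): the snippet below shows how the updated code path might be exercised after this change. The hyperparameter values mirror those hard-coded in the benchmark scripts above, and `normalize=True` exercises the new `forecast_on_df` flag added in this patch. The remaining keyword names and the input-frame columns ("unique_id", "ds", "y") follow my reading of the upstream TimesFM API and are assumptions, not part of this diff.

import pandas as pd
import timesfm

# TimesFM 2.0 checkpoint: 50 layers, up to 2048 context points, no positional
# embedding (the v1.0 settings in the old scripts were 20 layers / 512 context).
tfm = timesfm.TimesFm(
    hparams=timesfm.TimesFmHparams(
        backend="gpu",  # assumption: "cpu" should also work, just slower
        per_core_batch_size=32,
        horizon_len=128,
        num_layers=50,
        context_len=2048,
        use_positional_embedding=False,
    ),
    checkpoint=timesfm.TimesFmCheckpoint(
        huggingface_repo_id="google/timesfm-2.0-500m-jax"),
)

# Long-format input: one row per (series id, timestamp, value). Column names
# here are illustrative.
input_df = pd.DataFrame({
    "unique_id": ["series_1"] * 256,
    "ds": pd.date_range("2024-01-01", periods=256, freq="D"),
    "y": [float(i) for i in range(256)],
})

# normalize=True triggers the context normalization added in this patch; it
# defaults to False, so existing callers are unaffected.
forecast_df = tfm.forecast_on_df(
    inputs=input_df,
    freq="D",
    value_name="y",
    num_jobs=1,
    normalize=True,
)
print(forecast_df.head())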