From 75be2ae073d241a555dfa1ccdf0ee91e5923e3e3 Mon Sep 17 00:00:00 2001 From: Ming Du Date: Fri, 26 Apr 2024 16:27:44 -0500 Subject: [PATCH] Save_ONNX defaults to False --- generic_trainer/configs.py | 2 +- generic_trainer/trainer.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/generic_trainer/configs.py b/generic_trainer/configs.py index bdd3fff..8fee46f 100644 --- a/generic_trainer/configs.py +++ b/generic_trainer/configs.py @@ -252,7 +252,7 @@ class TrainingConfig(Config): automatic_mixed_precision: bool = False """Automatic mixed precision and gradient scaling are enabled if True.""" - save_onnx: bool = True + save_onnx: bool = False """If True, ONNX models are saved along with state dicts.""" diff --git a/generic_trainer/trainer.py b/generic_trainer/trainer.py index bc6e3eb..263a7de 100644 --- a/generic_trainer/trainer.py +++ b/generic_trainer/trainer.py @@ -836,7 +836,7 @@ def save_model(self, path, subcomponent=None): m = getattr(m, subcomponent) torch.save(m.state_dict(), path) - def update_saved_model(self, filename='best_model.pth', save_configs=True, save_onnx=True, subcomponent=None, + def update_saved_model(self, filename='best_model.pth', save_configs=True, save_onnx=False, subcomponent=None, run_with_only_rank_0=True): """ Updates saved model if validation loss is minimum.