diff --git a/generic_trainer/configs.py b/generic_trainer/configs.py index 7fb284f..bcb2ce5 100644 --- a/generic_trainer/configs.py +++ b/generic_trainer/configs.py @@ -207,7 +207,7 @@ def string_to_object(self, key, value): config = ModelParameters() config.deserizalize_dict(value) value = config - if key == 'parallelization_params' and isinstance(value, dict): + elif key == 'parallelization_params' and isinstance(value, dict): config = ParallelizationConfig() config.deserizalize_dict(value) value = config @@ -238,6 +238,7 @@ class InferenceConfig(Config): """ def string_to_object(self, key, value): + value = super().string_to_object(key, value) if key == 'model_save_dir': self.pretrained_model_path = os.path.join(value, 'best_model.pth') return value