diff --git a/generic_trainer/configs.py b/generic_trainer/configs.py index b5a910a..7a128b1 100644 --- a/generic_trainer/configs.py +++ b/generic_trainer/configs.py @@ -79,7 +79,26 @@ def string_to_object(self, key, value): :return: object. """ # Value is a class handle - if isinstance(value, (list, tuple)): + if isinstance(value, dict): + # Only convert the dict to an OptionContainer object if they are supposed to. Otherwise, leave it as a dict. + if 'model_params' in key or key in ['loss_tracker_params', 'parallelization_params']: + assert 'config_class' in value.keys(), ('The value of {} is supposed to be an object of a subclass ' + 'of OptionContainer, but I cannot find the ' + 'class name.'.format(key)) + try: + config = globals()[value['config_class']]() + config.deserizalize_dict(value) + value = config + except KeyError as e: + raise ModuleNotFoundError( + "When loading {} from JSON, the following error occurred when attempting to create the " + "OptionContainer object for it:'\n{}\n" + "To create an OptionContainer object, its class name must be in the global namespace. You can " + "import the proper classes in your driver script using from ... import ..., and pass " + "globals() to load_from_json:\n" + " configs.load_from_json(filename, namespace=globals())\n".format(key, e) + ) + elif isinstance(value, (list, tuple)): value = [self.string_to_object(key, v) for v in value] elif isinstance(value, str) and (res := re.match(r"", value)): class_import_path = res.groups()[0].split('.') @@ -105,7 +124,9 @@ def object_to_string(self, key, value): :return: str. """ if isinstance(value, OptionContainer): + config_class_name = value.__class__.__name__ value = value.get_serializable_dict() + value['config_class'] = config_class_name elif isinstance(value, (dict, int, float, bool)): value = value elif isinstance(value, (tuple, list)): @@ -123,14 +144,7 @@ def object_to_string(self, key, value): @dataclasses.dataclass class ModelParameters(OptionContainer): - - def string_to_object(self, key, value): - value = super().string_to_object(key, value) - if 'model_params' in key and isinstance(value, dict): - config = ModelParameters() - config.deserizalize_dict(value) - value = config - return value + pass # ============================= @@ -201,18 +215,6 @@ class Config(OptionContainer): Task type. Can be 'classification', 'regression'. Currently this only affects the logging of the loss tracker. """ - def string_to_object(self, key, value): - value = super().string_to_object(key, value) - if 'model_params' in key and isinstance(value, dict): - config = ModelParameters() - config.deserizalize_dict(value) - value = config - elif key == 'parallelization_params' and isinstance(value, dict): - config = ParallelizationConfig() - config.deserizalize_dict(value) - value = config - return value - @dataclasses.dataclass class InferenceConfig(Config): @@ -355,17 +357,13 @@ def string_to_object(self, key, value): try: value = eval(value) except Exception as e: - print( + raise ModuleNotFoundError( "When loading loss_function from JSON, the following error occurred:'\n{}\n" "To create a loss function object, its class name must be in the global namespace. You can " "import the proper classes in your driver script using from ... import ..., and pass " "globals() to load_from_json:\n" " configs.load_from_json(filename, namespace=globals())\n".format(e) ) - elif key == 'loss_tracker_params': - configs = LossTrackerParameters() - configs.deserizalize_dict(value) - value = configs return value @dataclasses.dataclass