From 384e2b4547830cccb962f96a554fce581973b737 Mon Sep 17 00:00:00 2001
From: Ming Du
Date: Wed, 22 May 2024 14:30:21 -0500
Subject: [PATCH] Deserialize JSON and create class objects when loading

---
Review notes (text between the "---" marker and the diff is ignored by git am):
  * The line breaks and indentation of this patch were lost; they have been
    rebuilt from the hunk headers. Blank context lines are inferred from the
    hunk line counts.
  * OptionContainer.string_to_object: the class-handle pattern was empty, so
    every string matched and res.groups()[0] raised IndexError. It now matches
    the text produced by str() on a class. The int/float casters are applied to
    strings only, so native JSON floats are no longer truncated by int().
  * InferenceConfig.string_to_object: now converts through the parent class and
    returns the value; before, every loaded entry became None.
  * TrainingConfig.string_to_object: eval() is attempted on strings only and
    loss_tracker_params is rebuilt from dicts only.
  * The blob ids on the "index" line are those of the original patch.

 generic_trainer/configs.py | 85 ++++++++++++++++++++++++++++++++++++--
 1 file changed, 81 insertions(+), 4 deletions(-)

diff --git a/generic_trainer/configs.py b/generic_trainer/configs.py
index 8bda505..a3d89d0 100644
--- a/generic_trainer/configs.py
+++ b/generic_trainer/configs.py
@@ -3,6 +3,8 @@
 from typing import Any, Callable, Optional, Union
 import json
 import os
+import re
+import importlib
 
 import torch
 from torch.utils.data import Dataset
@@ -15,6 +17,10 @@
 
 @dataclasses.dataclass
 class OptionContainer:
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.globals = {}
+
     def __str__(self):
         s = ''
         for key in self.__dict__.keys():
@@ -40,6 +46,10 @@ def get_serializable_dict(self):
             d[key] = v
         return d
 
+    def deserizalize_dict(self, d):
+        for key in d.keys():
+            self.__dict__[key] = self.string_to_object(key, d[key])
+
     def dump_to_json(self, filename):
         try:
             f = open(filename, 'w')
@@ -49,14 +59,16 @@ def dump_to_json(self, filename):
         except:
             print('Failed to dump json.')
 
-    def load_from_json(self, filename):
+    def load_from_json(self, filename, namespace=None):
         """
         This function only overwrites entries contained in the JSON file. Unspecified entries are unaffected.
         """
+        if namespace is not None:
+            for key in namespace.keys():
+                globals()[key] = namespace[key]
         f = open(filename, 'r')
         d = json.load(f)
-        for key in d.keys():
-            self.__dict__[key] = self.string_to_object(key, d[key])
+        self.deserizalize_dict(d)
         f.close()
 
     def string_to_object(self, key, value):
@@ -66,6 +78,22 @@ def string_to_object(self, key, value):
         :param value: str.
         :return: object.
         """
+        # A string of the form "<class 'pkg.module.Name'>" is a class handle.
+        if isinstance(value, (list, tuple)):
+            value = [self.string_to_object(key, v) for v in value]
+        elif isinstance(value, str) and (res := re.match(r"<class '(.+)'>", value)):
+            class_import_path = res.groups()[0].split('.')
+            value = getattr(importlib.import_module('.'.join(class_import_path[:-1])), class_import_path[-1])
+        elif value in ['True', 'False']:
+            value = True if value == 'True' else False
+        elif isinstance(value, str):
+            # Only strings are cast; int() on a native JSON float would truncate it.
+            for caster in (int, float):
+                try:
+                    value = caster(value)
+                    break
+                except (ValueError, TypeError):
+                    pass
         return value
 
     def object_to_string(self, key, value):
@@ -77,8 +105,12 @@ def object_to_string(self, key, value):
         """
         if isinstance(value, OptionContainer):
             value = value.get_serializable_dict()
+        elif isinstance(value, dict):
+            value = value
         elif isinstance(value, (tuple, list)):
             value = [self.object_to_string(key, x) for x in value]
+        elif value is None:
+            value = None
         else:
             value = str(value)
         return value
@@ -90,7 +122,14 @@ def object_to_string(self, key, value):
 
 @dataclasses.dataclass
 class ModelParameters(OptionContainer):
-    pass
+
+    def string_to_object(self, key, value):
+        value = super().string_to_object(key, value)
+        if 'model_params' in key and isinstance(value, dict):
+            config = ModelParameters()
+            config.deserizalize_dict(value)
+            value = config
+        return value
 
 
 # =============================
@@ -161,6 +200,18 @@ class Config(OptionContainer):
     Task type. Can be 'classification', 'regression'. Currently this only affects the logging of the loss tracker.
     """
 
+    def string_to_object(self, key, value):
+        value = super().string_to_object(key, value)
+        if 'model_params' in key and isinstance(value, dict):
+            config = ModelParameters()
+            config.deserizalize_dict(value)
+            value = config
+        if key == 'parallelization_params' and isinstance(value, dict):
+            config = ParallelizationConfig()
+            config.deserizalize_dict(value)
+            value = config
+        return value
+
 
 @dataclasses.dataclass
 class InferenceConfig(Config):
@@ -185,6 +236,13 @@ class InferenceConfig(Config):
     processed variables.
     """
 
+    def string_to_object(self, key, value):
+        # Convert through the parent class and always return the result; otherwise every entry loads as None.
+        value = super().string_to_object(key, value)
+        if key == 'model_save_dir' and isinstance(value, str):
+            self.pretrained_model_path = os.path.join(value, 'best_model.pth')
+        return value
+
 
 @dataclasses.dataclass
 class TrainingConfig(Config):
@@ -289,6 +347,25 @@ class TrainingConfig(Config):
     save_onnx: bool = False
     """If True, ONNX models are saved along with state dicts."""
 
+    def string_to_object(self, key, value):
+        value = super().string_to_object(key, value)
+        if key == 'loss_function' and isinstance(value, str):
+            try:
+                # NOTE: eval() executes arbitrary code; only load JSON files from trusted sources.
+                value = eval(value)
+            except Exception as e:
+                print(
+                    "When loading loss_function from JSON, the following error occurred:'\n{}\n"
+                    "To create a loss function object, its class name must be in the global namespace. You can "
+                    "import the proper classes in your driver script using from ... import ..., and pass "
+                    "globals() to load_from_json:\n"
+                    " configs.load_from_json(filename, namespace=globals())\n".format(e)
+                )
+        elif key == 'loss_tracker_params' and isinstance(value, dict):
+            configs = LossTrackerParameters()
+            configs.deserizalize_dict(value)
+            value = configs
+        return value
 
 @dataclasses.dataclass
 class PretrainingConfig(TrainingConfig):