diff --git a/generic_trainer/configs.py b/generic_trainer/configs.py index 84105cc..8bda505 100644 --- a/generic_trainer/configs.py +++ b/generic_trainer/configs.py @@ -36,13 +36,7 @@ def get_serializable_dict(self): d = {} for key in self.__dict__.keys(): v = self.__dict__[key] - if not self.__class__.is_jsonable(v): - if isinstance(v, (tuple, list)): - v = [str(x) for x in v] - elif isinstance(v, OptionContainer): - v = v.get_serializable_dict() - else: - v = str(v) + v = self.object_to_string(key, v) d[key] = v return d @@ -62,7 +56,7 @@ def load_from_json(self, filename): f = open(filename, 'r') d = json.load(f) for key in d.keys(): - self.__dict__[key] = d[key] + self.__dict__[key] = self.string_to_object(key, d[key]) f.close() def string_to_object(self, key, value): @@ -72,15 +66,22 @@ def string_to_object(self, key, value): :param value: str. :return: object. """ - pass + return value - def object_to_string(self, key): + def object_to_string(self, key, value): """ Convert an object in a config key to string. :param key: str. + :param value: object. :return: str. """ - pass + if isinstance(value, OptionContainer): + value = value.get_serializable_dict() + elif isinstance(value, (tuple, list)): + value = [self.object_to_string(key, x) for x in value] + else: + value = str(value) + return value # =============================