From 9c6923f93b2ded1a68bbbd053e2d090a5f6521c9 Mon Sep 17 00:00:00 2001 From: Ming Du Date: Wed, 17 Jul 2024 15:39:25 -0500 Subject: [PATCH] Remove field decorator --- generic_trainer/configs.py | 21 +++++++++++++++++++++ generic_trainer/trainer.py | 3 +-- 2 files changed, 22 insertions(+), 2 deletions(-) diff --git a/generic_trainer/configs.py b/generic_trainer/configs.py index dbbdb40..0e940f6 100644 --- a/generic_trainer/configs.py +++ b/generic_trainer/configs.py @@ -12,6 +12,27 @@ from generic_trainer.metrics import * + +def remove(*fields): + """ + A decorator that removes specified fields from a dataclass. Note: it may not work for multiple inheritance. + """ + def _(cls): + print(cls) + print(list(cls.__dataclass_fields__.keys())) + fields_copy = copy.copy(cls.__dataclass_fields__) + annotations_copy = copy.deepcopy(cls.__annotations__) + for field in fields: + try: + del fields_copy[field] + del annotations_copy[field] + except KeyError: + pass + d_cls = dataclasses.make_dataclass(cls.__name__, annotations_copy) + d_cls.__dataclass_fields__ = fields_copy + return d_cls + return _ + # ============================= # Base class for all # ============================= diff --git a/generic_trainer/trainer.py b/generic_trainer/trainer.py index 541fe80..0065f94 100644 --- a/generic_trainer/trainer.py +++ b/generic_trainer/trainer.py @@ -97,8 +97,7 @@ def __init__(self, pred_names_and_types=(('cs', 'cls'), ('eg', 'cls'), ('sg', 'c self['classification_labels_{}'.format(pred_name)] = [] def categorize_predictions(self): - assert (len(self.pred_names_and_types[0]) > 1, - 'Prediction names and types should be both given.') + assert len(self.pred_names_and_types[0]) > 1, 'Prediction names and types should be both given.' for x in self.pred_names_and_types: self.pred_names.append(x[0]) if x[1] == 'cls':