Skip to content

Commit

Permalink
Remove field decorator
Browse files Browse the repository at this point in the history
  • Loading branch information
mdw771 committed Jul 17, 2024
1 parent 06a6c35 commit 9c6923f
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 2 deletions.
21 changes: 21 additions & 0 deletions generic_trainer/configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,27 @@

from generic_trainer.metrics import *


def remove(*fields):
"""
A decorator that removes specified fields from a dataclass. Note: it may not work for multiple inheritance.
"""
def _(cls):
print(cls)
print(list(cls.__dataclass_fields__.keys()))
fields_copy = copy.copy(cls.__dataclass_fields__)
annotations_copy = copy.deepcopy(cls.__annotations__)
for field in fields:
try:
del fields_copy[field]
del annotations_copy[field]
except KeyError:
pass
d_cls = dataclasses.make_dataclass(cls.__name__, annotations_copy)
d_cls.__dataclass_fields__ = fields_copy
return d_cls
return _

# =============================
# Base class for all
# =============================
Expand Down
3 changes: 1 addition & 2 deletions generic_trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,7 @@ def __init__(self, pred_names_and_types=(('cs', 'cls'), ('eg', 'cls'), ('sg', 'c
self['classification_labels_{}'.format(pred_name)] = []

def categorize_predictions(self):
assert (len(self.pred_names_and_types[0]) > 1,
'Prediction names and types should be both given.')
assert len(self.pred_names_and_types[0]) > 1, 'Prediction names and types should be both given.'
for x in self.pred_names_and_types:
self.pred_names.append(x[0])
if x[1] == 'cls':
Expand Down

0 comments on commit 9c6923f

Please sign in to comment.