Skip to content

Commit

Permalink
Removed AlphaDiffract-specific code in LossTracker
Browse files Browse the repository at this point in the history
  • Loading branch information
mdw771 committed Mar 1, 2024
1 parent a1acba5 commit dfd25ea
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 46 deletions.
2 changes: 2 additions & 0 deletions generic_trainer/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from generic_trainer.trainer import *
import generic_trainer.metrics
47 changes: 1 addition & 46 deletions generic_trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

class LossTracker(dict):

def __init__(self, pred_names=('cs', 'eg', 'sg'), require_cs_labels=False, *args, **kwargs):
def __init__(self, pred_names=('cs', 'eg', 'sg'), *args, **kwargs):
"""
A dictionary-like object that stores the values of losses.
Expand All @@ -37,14 +37,9 @@ def __init__(self, pred_names=('cs', 'eg', 'sg'), require_cs_labels=False, *args
which are updated every opech.*
:param pred_names: tuple(str). Names of predicted quantities.
:param require_cs_labels: bool. When True, the loss tracker will expect CS labels as the last element in the
label list when calculating accuracies. This is usually
used for calculating the accuracies of CS deduced from SG predictions, without
an actual CS classification head.
"""
super().__init__(*args, **kwargs)
self.pred_names = pred_names
self.require_cs_labels = require_cs_labels
self.n_preds = len(pred_names)
self['epochs'] = []
self['loss'] = []
Expand All @@ -62,13 +57,6 @@ def __init__(self, pred_names=('cs', 'eg', 'sg'), require_cs_labels=False, *args
self['val_acc_{}'.format(pred_name)] = []
self['classification_preds_{}'.format(pred_name)] = []
self['classification_labels_{}'.format(pred_name)] = []
# Also calculate the CS accuracy deduced from SG predictions.
if 'sg' in self.pred_names:
self['classification_preds_cs_from_sg'] = []
self['train_acc_cs_from_sg'] = []
self['val_acc_cs_from_sg'] = []
if self.require_cs_labels and ('cs' not in self.pred_names):
self['classification_labels_cs'] = []

def update_losses(self, losses, type='loss', epoch=None, lr=None):
"""
Expand Down Expand Up @@ -152,10 +140,6 @@ def clear_classification_results_and_labels(self):
for pred_name in self.pred_names:
self['classification_preds_{}'.format(pred_name)] = []
self['classification_labels_{}'.format(pred_name)] = []
if 'classification_preds_cs_from_sg' in self.keys():
self['classification_preds_cs_from_sg'] = []
if self.require_cs_labels:
self['classification_labels_cs'] = []

def update_classification_results_and_labels(self, preds, labels):
"""
Expand All @@ -175,28 +159,6 @@ def update_classification_results_and_labels(self, preds, labels):
label_dict[pred_name] = inds_label
self['classification_preds_{}'.format(pred_name)] += inds_pred.tolist()
self['classification_labels_{}'.format(pred_name)] += inds_label.tolist()
if 'classification_preds_cs_from_sg' in self.keys():
inds_cs_from_sg = self.get_cs_from_sg_predictions(pred_dict['sg'])
self['classification_preds_cs_from_sg'] += inds_cs_from_sg.tolist()
if self.require_cs_labels and 'cs' not in self.pred_names:
assert len(labels) > len(self.pred_names), ('require_cs_labels is True, so I am expecting CS labels even '
'pred_names does not include it. Make sure dataset returns '
'CS labels at last. ')
inds_label = torch.argmax(labels[-1], dim=1)
self['classification_labels_cs'] += inds_label.tolist()

def get_cs_from_sg_predictions(self, inds_sg):
"""
Get the predicted classes for CS from the predictions of SG using their hierarchical relation.
:param inds_sg: torch.tensor. The predicted DG indices (not one-hot).
:return: torch.tensor. 1D tensor of predicted CS indices.
"""
sg_group_starting_inds = torch.tensor(consts.cs_to_sg_index_bracket_array[:, 0], device=inds_sg.device)
diff_array = inds_sg.view(-1, 1) - sg_group_starting_inds
diff_array[diff_array < 0] = diff_array.max() + 1
cs_inds = torch.argmin(diff_array, dim=1)
return cs_inds

def calculate_classification_accuracy(self):
"""
Expand All @@ -208,11 +170,6 @@ def calculate_classification_accuracy(self):
inds_label = self['classification_labels_{}'.format(pred_name)]
acc = np.mean((np.array(inds_pred) == np.array(inds_label)))
acc_dict[pred_name] = acc
if 'classification_preds_cs_from_sg' in self.keys():
inds_pred = self['classification_preds_cs_from_sg']
inds_label = self['classification_labels_cs']
acc = np.mean((np.array(inds_pred) == np.array(inds_label)))
acc_dict['cs_from_sg'] = acc
return acc_dict

def update_accuracy_history(self, acc_dict, type='train'):
Expand All @@ -225,8 +182,6 @@ def update_accuracy_history(self, acc_dict, type='train'):
"""
for i, pred_name in enumerate(self.pred_names):
self['{}_acc_{}'.format(type, pred_name)].append(acc_dict[pred_name])
if 'cs_from_sg' in acc_dict.keys():
self['{}_acc_cs_from_sg'.format(type)].append(acc_dict['cs_from_sg'])


class Trainer:
Expand Down

0 comments on commit dfd25ea

Please sign in to comment.