Skip to content

Commit

Permalink
Add regression loss and error tracking.
Browse files Browse the repository at this point in the history
  • Loading branch information
Nina Andrejevic committed Jan 23, 2025
1 parent d0fdaeb commit 176b0f4
Showing 1 changed file with 103 additions and 20 deletions.
123 changes: 103 additions & 20 deletions generic_trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from torch.utils.data import Dataset, DataLoader, Subset

from sklearn.model_selection import train_test_split
from sklearn.metrics import r2_score

from tqdm import tqdm
import matplotlib.pyplot as plt
Expand Down Expand Up @@ -158,7 +159,7 @@ def get_all_losses(self, type='loss'):

def get_metric_names_for_hyperparam_tuning(self):
"""
Get a list of keys taht can be used as the objective for hyperparameter tuning.
Get a list of keys that can be used as the objective for hyperparameter tuning.
:return: list[str].
"""
Expand Down Expand Up @@ -194,6 +195,11 @@ def clear_classification_results_and_labels(self):
self['classification_preds_{}'.format(pred_name)] = []
self['classification_labels_{}'.format(pred_name)] = []

def clear_regression_results_and_labels(self):
    """
    Reset the recorded regression predictions and labels.

    For every name in `self.regr_pred_names`, the entries
    'regression_preds_<name>' and 'regression_labels_<name>' are replaced
    with new empty lists.
    """
    key_templates = ('regression_preds_{}', 'regression_labels_{}')
    for name in self.regr_pred_names:
        for template in key_templates:
            # A fresh list per key so that no two entries share storage.
            self[template.format(name)] = []

def update_classification_results_and_labels(self, preds, labels):
"""
Update the classification results recorded for the current epoch with the predictions and labels of
Expand All @@ -213,6 +219,15 @@ def update_classification_results_and_labels(self, preds, labels):
self['classification_preds_{}'.format(pred_name)] += inds_pred.tolist()
self['classification_labels_{}'.format(pred_name)] += inds_label.tolist()

def update_regression_results_and_labels(self, preds, labels):
    """
    Append one batch of regression predictions and labels to the per-epoch records.

    :param preds: indexable sequence; `preds[i]` holds the predictions for the i-th
                  name in `self.regr_pred_names` and must provide `.tolist()`.
    :param labels: indexable sequence laid out like `preds`, holding the labels.
    """
    for i, pred_name in enumerate(self.regr_pred_names):
        # `.tolist()` converts each batch to plain Python values before it is
        # accumulated in the 'regression_preds_*' / 'regression_labels_*' lists.
        self['regression_preds_{}'.format(pred_name)] += preds[i].tolist()
        self['regression_labels_{}'.format(pred_name)] += labels[i].tolist()

def calculate_classification_accuracy(self):
"""
Calculate classification accuracies at the end of an epoch using the recorded predictions and labels.
Expand All @@ -224,8 +239,20 @@ def calculate_classification_accuracy(self):
acc = np.mean((np.array(inds_pred) == np.array(inds_label)))
acc_dict[pred_name] = acc
return acc_dict

def calculate_regression_accuracy(self):
    """
    Calculate the regression score of every prediction at the end of an epoch.

    The score is the coefficient of determination (R^2) returned by
    `sklearn.metrics.r2_score` with `multioutput='uniform_average'`, computed
    from the predictions and labels recorded for the current epoch.

    :return: dict. Maps each name in `self.regr_pred_names` to its R^2 score.
             The score is NaN when nothing has been recorded for that name.
    """
    acc_dict = {}
    for pred_name in self.regr_pred_names:
        preds = self['regression_preds_{}'.format(pred_name)]
        labels = self['regression_labels_{}'.format(pred_name)]
        if len(preds) == 0 or len(labels) == 0:
            # r2_score rejects empty input; report NaN instead of aborting the
            # epoch when no samples were recorded.
            acc_dict[pred_name] = float('nan')
            continue
        acc_dict[pred_name] = r2_score(labels, preds, multioutput='uniform_average')
    return acc_dict

def update_accuracy_history(self, acc_dict, type='train'):
def update_classification_accuracy_history(self, acc_dict, type='train'):
"""
Update accuracy history.
Expand All @@ -238,6 +265,19 @@ def update_accuracy_history(self, acc_dict, type='train'):
self['{}_acc_{}'.format(type, pred_name)] = []
self['{}_acc_{}'.format(type, pred_name)].append(acc_dict[pred_name])

def update_regression_accuracy_history(self, acc_dict, type='train'):
    """
    Append the current epoch's regression scores to the stored history.

    :param acc_dict: dict. Must contain every name in `self.regr_pred_names`; each
                     value is the score of that category for the current epoch.
    :param type: str. Prefix of the history key, e.g. 'train', 'val' or 'test'.
                 The history is kept under '<type>_acc_<pred_name>'.
    """
    for name in self.regr_pred_names:
        history_key = '{}_acc_{}'.format(type, name)
        if history_key not in self:
            # First score recorded under this key: start a new history list.
            self[history_key] = []
        self[history_key].append(acc_dict[name])

def sync_classification_preds_and_labels_across_ranks(self):
if MPI is None:
return
Expand All @@ -252,6 +292,21 @@ def sync_classification_preds_and_labels_across_ranks(self):
assert isinstance(self['classification_labels_{}'.format(pred_name)], list)
self['classification_labels_{}'.format(pred_name)] = (
comm.allreduce(self['classification_labels_{}'.format(pred_name)], op=MPI.SUM))

def sync_regression_preds_and_labels_across_ranks(self):
    """
    Combine the recorded regression predictions and labels of all MPI ranks.

    Does nothing when `MPI` is None or when the world communicator has a single
    rank. Otherwise each 'regression_preds_<name>' and 'regression_labels_<name>'
    list is replaced with the result of `allreduce(..., op=MPI.SUM)`.
    """
    if MPI is None:
        return
    world = MPI.COMM_WORLD
    if world.Get_size() == 1:
        return
    for name in self.regr_pred_names:
        record_keys = (
            'regression_preds_{}'.format(name),
            'regression_labels_{}'.format(name),
        )
        for key in record_keys:
            # NOTE(review): presumably SUM over Python lists concatenates the
            # per-rank lists -- confirm against the mpi4py documentation.
            assert isinstance(self[key], list)
            self[key] = world.allreduce(self[key], op=MPI.SUM)

def dump(self, path):
f = open(path, 'w')
Expand Down Expand Up @@ -307,6 +362,7 @@ def __init__(self, configs: Union[TrainingConfig, Config], rank=None, num_proces
self.num_processes = num_processes
self.rank = rank
self.device = self.get_device()
self.num_workers = self.configs.num_workers if self.configs.num_workers is not None else 0
self.all_proc_batch_size = self.configs.batch_size_per_process
self.learning_rate = self.configs.learning_rate_per_process
self.num_epochs = self.configs.num_epochs
Expand Down Expand Up @@ -509,20 +565,20 @@ def build_dataloaders(self):
# distributed training across multiple nodes. Therefore, `num_workers` is set to 0. See also:
# https://docs.alcf.anl.gov/polaris/data-science-workflows/frameworks/pytorch/.
self.training_dataloader = DataLoader(self.training_dataset, shuffle=True,
batch_size=self.all_proc_batch_size,
batch_size=self.all_proc_batch_size, prefetch_factor=10,
collate_fn=lambda x: x, worker_init_fn=self.get_worker_seed_func(),
generator=self.get_dataloader_generator(), num_workers=0,
generator=self.get_dataloader_generator(), num_workers=self.num_workers,
drop_last=False, pin_memory=self.configs.pin_memory_for_dataloader)
self.validation_dataloader = DataLoader(self.validation_dataset, shuffle=True,
batch_size=self.all_proc_batch_size,
batch_size=self.all_proc_batch_size, prefetch_factor=10,
collate_fn=lambda x: x, worker_init_fn=self.get_worker_seed_func(),
generator=self.get_dataloader_generator(), num_workers=0,
generator=self.get_dataloader_generator(), num_workers=self.num_workers,
drop_last=False, pin_memory=self.configs.pin_memory_for_dataloader)
if self.test_dataset is not None:
self.test_dataloader = DataLoader(self.test_dataset, shuffle=True,
batch_size=self.all_proc_batch_size,
batch_size=self.all_proc_batch_size, prefetch_factor=10,
collate_fn=lambda x: x, worker_init_fn=self.get_worker_seed_func(),
generator=self.get_dataloader_generator(), num_workers=0,
generator=self.get_dataloader_generator(), num_workers=self.num_workers,
drop_last=False, pin_memory=self.configs.pin_memory_for_dataloader)

def run_training(self):
Expand Down Expand Up @@ -692,12 +748,17 @@ def load_data_and_get_loss(self, data_and_labels, loss_buffer, *args, **kwargs):
def run_training_epoch(self):
losses = self.get_epoch_loss_buffer()
n_batches = 0
if self.configs.task_type == 'classification':
if 'classification' in self.configs.task_type:
self.loss_tracker.clear_classification_results_and_labels()
if 'regression' in self.configs.task_type:
self.loss_tracker.clear_regression_results_and_labels()

for i, data_and_labels in enumerate(tqdm(self.training_dataloader, disable=(not self.verbose))):
losses, total_loss_tensor, preds, labels = self.load_data_and_get_loss(data_and_labels, losses)
if self.configs.task_type == 'classification':
if 'classification' in self.configs.task_type:
self.loss_tracker.update_classification_results_and_labels(preds, labels)
if 'regression' in self.configs.task_type:
self.loss_tracker.update_regression_results_and_labels(preds, labels)

# Zero current grads and do backprop
self.run_model_update_step(total_loss_tensor)
Expand All @@ -714,23 +775,31 @@ def run_training_epoch(self):
losses = [self.communicate_value_across_ranks(l / n_batches, mode='average') for l in losses]
self.loss_tracker.update_losses(losses, type='loss', epoch=self.current_epoch)

if self.configs.task_type == 'classification':
if 'classification' in self.configs.task_type:
self.loss_tracker.sync_classification_preds_and_labels_across_ranks()
acc_dict = self.loss_tracker.calculate_classification_accuracy()
self.loss_tracker.update_accuracy_history(acc_dict, 'train')
self.loss_tracker.update_classification_accuracy_history(acc_dict, 'train')
if 'regression' in self.configs.task_type:
self.loss_tracker.sync_regression_preds_and_labels_across_ranks()
acc_dict = self.loss_tracker.calculate_regression_accuracy()
self.loss_tracker.update_regression_accuracy_history(acc_dict, 'train')

if self.configs.post_training_epoch_hook is not None:
self.configs.post_training_epoch_hook()

def run_validation(self):
losses = self.get_epoch_loss_buffer()
n_batches = 0
if self.configs.task_type == 'classification':
if 'classification' in self.configs.task_type:
self.loss_tracker.clear_classification_results_and_labels()
if 'regression' in self.configs.task_type:
self.loss_tracker.clear_regression_results_and_labels()
for j, data_and_labels in enumerate(self.validation_dataloader):
losses, _, preds, labels = self.load_data_and_get_loss(data_and_labels, losses)
if self.configs.task_type == 'classification':
if 'classification' in self.configs.task_type:
self.loss_tracker.update_classification_results_and_labels(preds, labels)
if 'regression' in self.configs.task_type:
self.loss_tracker.update_regression_results_and_labels(preds, labels)
n_batches += 1
if n_batches == 0:
logging.warning('Validation set might be too small that at least 1 rank did not get any validation data.')
Expand All @@ -746,23 +815,33 @@ def run_validation(self):
last_best_val_loss, self.loss_tracker['best_val_loss']))
self.update_saved_model(filename='best_model.pth', save_onnx=self.configs.save_onnx)

if self.configs.task_type == 'classification':
if 'classification' in self.configs.task_type:
self.loss_tracker.sync_classification_preds_and_labels_across_ranks()
acc_dict = self.loss_tracker.calculate_classification_accuracy()
self.loss_tracker.update_accuracy_history(acc_dict, 'val')
self.loss_tracker.update_classification_accuracy_history(acc_dict, 'val')
if 'regression' in self.configs.task_type:
self.loss_tracker.sync_regression_preds_and_labels_across_ranks()
acc_dict = self.loss_tracker.calculate_regression_accuracy()
self.loss_tracker.update_regression_accuracy_history(acc_dict, 'val')

if self.configs.post_validation_epoch_hook is not None:
self.configs.post_validation_epoch_hook()

def run_test(self):
losses = self.get_epoch_loss_buffer()
n_batches = 0
if self.configs.task_type == 'classification':
if 'classification' in self.configs.task_type:
self.loss_tracker.clear_classification_results_and_labels()
if 'regression' in self.configs.task_type:
self.loss_tracker.clear_regression_results_and_labels()

for j, data_and_labels in enumerate(self.test_dataloader):
losses, _, preds, labels = self.load_data_and_get_loss(data_and_labels, losses)
if self.configs.task_type == 'classification':
if 'classification' in self.configs.task_type:
self.loss_tracker.update_classification_results_and_labels(preds, labels)
if 'regression' in self.configs.task_type:
self.loss_tracker.update_regression_results_and_labels(preds, labels)

n_batches += 1
if n_batches == 0:
logging.warning('Test set might be too small that at least 1 rank did not get any test data.')
Expand All @@ -771,10 +850,14 @@ def run_test(self):
losses = [self.communicate_value_across_ranks(l / n_batches, mode='average') for l in losses]
self.loss_tracker.update_losses(losses, epoch=self.current_epoch, type='test_loss')

if self.configs.task_type == 'classification':
if 'classification' in self.configs.task_type:
self.loss_tracker.sync_classification_preds_and_labels_across_ranks()
acc_dict = self.loss_tracker.calculate_classification_accuracy()
self.loss_tracker.update_accuracy_history(acc_dict, 'test')
self.loss_tracker.update_classification_accuracy_history(acc_dict, 'test')
if 'regression' in self.configs.task_type:
self.loss_tracker.sync_regression_preds_and_labels_across_ranks()
acc_dict = self.loss_tracker.calculate_regression_accuracy()
self.loss_tracker.update_regression_accuracy_history(acc_dict, 'test')

if self.configs.post_test_epoch_hook is not None:
self.configs.post_test_epoch_hook()
Expand Down

0 comments on commit 176b0f4

Please sign in to comment.