From 176b0f44f737d478ef6cf3a6c8bee1e557a547cf Mon Sep 17 00:00:00 2001 From: Nina Andrejevic Date: Thu, 23 Jan 2025 12:57:01 -0600 Subject: [PATCH] Add regression loss and error tracking. --- generic_trainer/trainer.py | 123 +++++++++++++++++++++++++++++++------ 1 file changed, 103 insertions(+), 20 deletions(-) diff --git a/generic_trainer/trainer.py b/generic_trainer/trainer.py index fa34db2..a3b439f 100644 --- a/generic_trainer/trainer.py +++ b/generic_trainer/trainer.py @@ -18,6 +18,7 @@ from torch.utils.data import Dataset, DataLoader, Subset from sklearn.model_selection import train_test_split +from sklearn.metrics import r2_score from tqdm import tqdm import matplotlib.pyplot as plt @@ -158,7 +159,7 @@ def get_all_losses(self, type='loss'): def get_metric_names_for_hyperparam_tuning(self): """ - Get a list of keys taht can be used as the objective for hyperparameter tuning. + Get a list of keys that can be used as the objective for hyperparameter tuning. :return: list[str]. """ @@ -194,6 +195,11 @@ def clear_classification_results_and_labels(self): self['classification_preds_{}'.format(pred_name)] = [] self['classification_labels_{}'.format(pred_name)] = [] + def clear_regression_results_and_labels(self): + for pred_name in self.regr_pred_names: + self['regression_preds_{}'.format(pred_name)] = [] + self['regression_labels_{}'.format(pred_name)] = [] + def update_classification_results_and_labels(self, preds, labels): """ Update the classification results recorded for the current epoch with the predictions and labels of @@ -213,6 +219,15 @@ def update_classification_results_and_labels(self, preds, labels): self['classification_preds_{}'.format(pred_name)] += inds_pred.tolist() self['classification_labels_{}'.format(pred_name)] += inds_label.tolist() + def update_regression_results_and_labels(self, preds, labels): + pred_dict = {} + label_dict = {} + for i, pred_name in enumerate(self.regr_pred_names): + pred_dict[pred_name] = preds[i] + label_dict[pred_name] = labels[i] + self['regression_preds_{}'.format(pred_name)] += preds[i].tolist() + self['regression_labels_{}'.format(pred_name)] += labels[i].tolist() + def calculate_classification_accuracy(self): """ Calculate classification accuracies at the end of an epoch using the recorded predictions and labels. @@ -224,8 +239,20 @@ def calculate_classification_accuracy(self): acc = np.mean((np.array(inds_pred) == np.array(inds_label))) acc_dict[pred_name] = acc return acc_dict + + def calculate_regression_accuracy(self): + """ + Calculate regression accuracies at the end of an epoch using the recorded predictions and labels. + """ + acc_dict = {} + for i, pred_name in enumerate(self.regr_pred_names): + preds = self['regression_preds_{}'.format(pred_name)] + labels = self['regression_labels_{}'.format(pred_name)] + acc = r2_score(labels, preds, multioutput='uniform_average') + acc_dict[pred_name] = acc + return acc_dict - def update_accuracy_history(self, acc_dict, type='train'): + def update_classification_accuracy_history(self, acc_dict, type='train'): """ Update accuracy history. @@ -238,6 +265,19 @@ def update_accuracy_history(self, acc_dict, type='train'): self['{}_acc_{}'.format(type, pred_name)] = [] self['{}_acc_{}'.format(type, pred_name)].append(acc_dict[pred_name]) + def update_regression_accuracy_history(self, acc_dict, type='train'): + """ + Update accuracy history. + + :param acc_dict: dict. 
A dictionary where each key is in regr_pred_names, and the corresponding value is
+        the R^2 score of that prediction for all samples in the current epoch.
+        :param type: str. Can be 'train' or 'val'.
+        """
+        for i, pred_name in enumerate(self.regr_pred_names):
+            if '{}_acc_{}'.format(type, pred_name) not in self.keys():
+                self['{}_acc_{}'.format(type, pred_name)] = []
+            self['{}_acc_{}'.format(type, pred_name)].append(acc_dict[pred_name])
+
     def sync_classification_preds_and_labels_across_ranks(self):
         if MPI is None:
             return
@@ -252,6 +292,21 @@
             assert isinstance(self['classification_labels_{}'.format(pred_name)], list)
             self['classification_labels_{}'.format(pred_name)] = (
                 comm.allreduce(self['classification_labels_{}'.format(pred_name)], op=MPI.SUM))
+
+    def sync_regression_preds_and_labels_across_ranks(self):
+        if MPI is None:
+            return
+        comm = MPI.COMM_WORLD
+        n_ranks = comm.Get_size()
+        if n_ranks == 1:
+            return
+        for i, pred_name in enumerate(self.regr_pred_names):
+            assert isinstance(self['regression_preds_{}'.format(pred_name)], list)
+            self['regression_preds_{}'.format(pred_name)] = (
+                comm.allreduce(self['regression_preds_{}'.format(pred_name)], op=MPI.SUM))
+            assert isinstance(self['regression_labels_{}'.format(pred_name)], list)
+            self['regression_labels_{}'.format(pred_name)] = (
+                comm.allreduce(self['regression_labels_{}'.format(pred_name)], op=MPI.SUM))
 
     def dump(self, path):
         f = open(path, 'w')
@@ -307,6 +362,7 @@ def __init__(self, configs: Union[TrainingConfig, Config], rank=None, num_proces
         self.num_processes = num_processes
         self.rank = rank
         self.device = self.get_device()
+        self.num_workers = self.configs.num_workers if self.configs.num_workers is not None else 0
        self.all_proc_batch_size = self.configs.batch_size_per_process
         self.learning_rate = self.configs.learning_rate_per_process
         self.num_epochs = self.configs.num_epochs
@@ -509,20 +565,20 @@ def build_dataloaders(self):
         # distributed training across multiple nodes. Therefore, `num_workers` is set to 0. See also:
         # https://docs.alcf.anl.gov/polaris/data-science-workflows/frameworks/pytorch/.
self.training_dataloader = DataLoader(self.training_dataset, shuffle=True, - batch_size=self.all_proc_batch_size, + batch_size=self.all_proc_batch_size, prefetch_factor=10, collate_fn=lambda x: x, worker_init_fn=self.get_worker_seed_func(), - generator=self.get_dataloader_generator(), num_workers=0, + generator=self.get_dataloader_generator(), num_workers=self.num_workers, drop_last=False, pin_memory=self.configs.pin_memory_for_dataloader) self.validation_dataloader = DataLoader(self.validation_dataset, shuffle=True, - batch_size=self.all_proc_batch_size, + batch_size=self.all_proc_batch_size, prefetch_factor=10, collate_fn=lambda x: x, worker_init_fn=self.get_worker_seed_func(), - generator=self.get_dataloader_generator(), num_workers=0, + generator=self.get_dataloader_generator(), num_workers=self.num_workers, drop_last=False, pin_memory=self.configs.pin_memory_for_dataloader) if self.test_dataset is not None: self.test_dataloader = DataLoader(self.test_dataset, shuffle=True, - batch_size=self.all_proc_batch_size, + batch_size=self.all_proc_batch_size, prefetch_factor=10, collate_fn=lambda x: x, worker_init_fn=self.get_worker_seed_func(), - generator=self.get_dataloader_generator(), num_workers=0, + generator=self.get_dataloader_generator(), num_workers=self.num_workers, drop_last=False, pin_memory=self.configs.pin_memory_for_dataloader) def run_training(self): @@ -692,12 +748,17 @@ def load_data_and_get_loss(self, data_and_labels, loss_buffer, *args, **kwargs): def run_training_epoch(self): losses = self.get_epoch_loss_buffer() n_batches = 0 - if self.configs.task_type == 'classification': + if 'classification' in self.configs.task_type: self.loss_tracker.clear_classification_results_and_labels() + if 'regression' in self.configs.task_type: + self.loss_tracker.clear_regression_results_and_labels() + for i, data_and_labels in enumerate(tqdm(self.training_dataloader, disable=(not self.verbose))): losses, total_loss_tensor, preds, labels = self.load_data_and_get_loss(data_and_labels, losses) - if self.configs.task_type == 'classification': + if 'classification' in self.configs.task_type: self.loss_tracker.update_classification_results_and_labels(preds, labels) + if 'regression' in self.configs.task_type: + self.loss_tracker.update_regression_results_and_labels(preds, labels) # Zero current grads and do backprop self.run_model_update_step(total_loss_tensor) @@ -714,10 +775,14 @@ def run_training_epoch(self): losses = [self.communicate_value_across_ranks(l / n_batches, mode='average') for l in losses] self.loss_tracker.update_losses(losses, type='loss', epoch=self.current_epoch) - if self.configs.task_type == 'classification': + if 'classification' in self.configs.task_type: self.loss_tracker.sync_classification_preds_and_labels_across_ranks() acc_dict = self.loss_tracker.calculate_classification_accuracy() - self.loss_tracker.update_accuracy_history(acc_dict, 'train') + self.loss_tracker.update_classification_accuracy_history(acc_dict, 'train') + if 'regression' in self.configs.task_type: + self.loss_tracker.sync_regression_preds_and_labels_across_ranks() + acc_dict = self.loss_tracker.calculate_regression_accuracy() + self.loss_tracker.update_regression_accuracy_history(acc_dict, 'train') if self.configs.post_training_epoch_hook is not None: self.configs.post_training_epoch_hook() @@ -725,12 +790,16 @@ def run_training_epoch(self): def run_validation(self): losses = self.get_epoch_loss_buffer() n_batches = 0 - if self.configs.task_type == 'classification': + if 'classification' in 
self.configs.task_type: self.loss_tracker.clear_classification_results_and_labels() + if 'regression' in self.configs.task_type: + self.loss_tracker.clear_regression_results_and_labels() for j, data_and_labels in enumerate(self.validation_dataloader): losses, _, preds, labels = self.load_data_and_get_loss(data_and_labels, losses) - if self.configs.task_type == 'classification': + if 'classification' in self.configs.task_type: self.loss_tracker.update_classification_results_and_labels(preds, labels) + if 'regression' in self.configs.task_type: + self.loss_tracker.update_regression_results_and_labels(preds, labels) n_batches += 1 if n_batches == 0: logging.warning('Validation set might be too small that at least 1 rank did not get any validation data.') @@ -746,10 +815,14 @@ def run_validation(self): last_best_val_loss, self.loss_tracker['best_val_loss'])) self.update_saved_model(filename='best_model.pth', save_onnx=self.configs.save_onnx) - if self.configs.task_type == 'classification': + if 'classification' in self.configs.task_type: self.loss_tracker.sync_classification_preds_and_labels_across_ranks() acc_dict = self.loss_tracker.calculate_classification_accuracy() - self.loss_tracker.update_accuracy_history(acc_dict, 'val') + self.loss_tracker.update_classification_accuracy_history(acc_dict, 'val') + if 'regression' in self.configs.task_type: + self.loss_tracker.sync_regression_preds_and_labels_across_ranks() + acc_dict = self.loss_tracker.calculate_regression_accuracy() + self.loss_tracker.update_regression_accuracy_history(acc_dict, 'val') if self.configs.post_validation_epoch_hook is not None: self.configs.post_validation_epoch_hook() @@ -757,12 +830,18 @@ def run_validation(self): def run_test(self): losses = self.get_epoch_loss_buffer() n_batches = 0 - if self.configs.task_type == 'classification': + if 'classification' in self.configs.task_type: self.loss_tracker.clear_classification_results_and_labels() + if 'regression' in self.configs.task_type: + self.loss_tracker.clear_regression_results_and_labels() + for j, data_and_labels in enumerate(self.test_dataloader): losses, _, preds, labels = self.load_data_and_get_loss(data_and_labels, losses) - if self.configs.task_type == 'classification': + if 'classification' in self.configs.task_type: self.loss_tracker.update_classification_results_and_labels(preds, labels) + if 'regression' in self.configs.task_type: + self.loss_tracker.update_regression_results_and_labels(preds, labels) + n_batches += 1 if n_batches == 0: logging.warning('Test set might be too small that at least 1 rank did not get any test data.') @@ -771,10 +850,14 @@ def run_test(self): losses = [self.communicate_value_across_ranks(l / n_batches, mode='average') for l in losses] self.loss_tracker.update_losses(losses, epoch=self.current_epoch, type='test_loss') - if self.configs.task_type == 'classification': + if 'classification' in self.configs.task_type: self.loss_tracker.sync_classification_preds_and_labels_across_ranks() acc_dict = self.loss_tracker.calculate_classification_accuracy() - self.loss_tracker.update_accuracy_history(acc_dict, 'test') + self.loss_tracker.update_classification_accuracy_history(acc_dict, 'test') + if 'regression' in self.configs.task_type: + self.loss_tracker.sync_regression_preds_and_labels_across_ranks() + acc_dict = self.loss_tracker.calculate_regression_accuracy() + self.loss_tracker.update_regression_accuracy_history(acc_dict, 'test') if self.configs.post_test_epoch_hook is not None: self.configs.post_test_epoch_hook()
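
The new LossTracker methods accumulate per-head regression predictions and labels over an epoch and score them with scikit-learn's r2_score. The standalone sketch below mirrors that bookkeeping outside the trainer; the head names and toy arrays are made up for illustration, and only the r2_score call with multioutput='uniform_average' is taken directly from the patch.

import numpy as np
from sklearn.metrics import r2_score

regr_pred_names = ('lattice', 'angle')   # hypothetical prediction heads
buffers = {name: {'preds': [], 'labels': []} for name in regr_pred_names}

def update(preds, labels):
    # preds[i] / labels[i]: arrays of shape (batch, n_targets) for head i,
    # accumulated as plain lists like update_regression_results_and_labels().
    for i, name in enumerate(regr_pred_names):
        buffers[name]['preds'] += np.asarray(preds[i]).tolist()
        buffers[name]['labels'] += np.asarray(labels[i]).tolist()

def epoch_accuracy():
    # 'uniform_average' averages R^2 over target dimensions, so each head
    # gets one scalar per epoch, as in calculate_regression_accuracy().
    return {name: r2_score(buffers[name]['labels'], buffers[name]['preds'],
                           multioutput='uniform_average')
            for name in regr_pred_names}

update([np.random.rand(8, 3), np.random.rand(8, 1)],
       [np.random.rand(8, 3), np.random.rand(8, 1)])
print(epoch_accuracy())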
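
sync_regression_preds_and_labels_across_ranks follows the existing classification pattern: mpi4py's object-mode comm.allreduce with MPI.SUM adds the per-rank Python lists together, i.e. concatenates them, so every rank ends up holding the full set of predictions and labels before R^2 is computed. A minimal sketch of that mechanism, separate from the trainer (the script name is illustrative):

# Run under MPI, e.g.: mpiexec -n 4 python sync_sketch.py
from mpi4py import MPI

comm = MPI.COMM_WORLD
rank = comm.Get_rank()

local_preds = [float(rank)] * 2        # stand-in for this rank's predictions
local_labels = [float(rank)] * 2       # stand-in for this rank's labels

# Object-mode allreduce with MPI.SUM applies '+' to the lists, so the
# per-rank lists are combined into one list visible on every rank.
all_preds = comm.allreduce(local_preds, op=MPI.SUM)
all_labels = comm.allreduce(local_labels, op=MPI.SUM)
print(rank, all_preds, all_labels)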
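
The comparisons on configs.task_type switch from equality to membership tests, so a value that names both tasks enables classification and regression tracking in the same run. The patch does not show how task_type is declared; the snippet below only illustrates that both a combined string and a tuple of names behave the same way under the 'in' test.

for task_type in ('classification',
                  'regression',
                  'classification_regression',          # hypothetical combined value
                  ('classification', 'regression')):    # hypothetical tuple value
    track_classification = 'classification' in task_type
    track_regression = 'regression' in task_type
    print(task_type, track_classification, track_regression)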
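
One thing to watch in the dataloader change: torch.utils.data.DataLoader only accepts a prefetch_factor when num_workers is greater than zero, and self.num_workers still defaults to 0 when configs.num_workers is unset, in which case passing prefetch_factor=10 raises a ValueError in current PyTorch releases. A hedged sketch of a guarded construction follows; make_loader and its parameters are illustrative, not part of the project.

from torch.utils.data import DataLoader

def make_loader(dataset, batch_size, num_workers, generator=None,
                worker_init_fn=None, pin_memory=False, prefetch_factor=10):
    # Mirror the keyword arguments used in build_dataloaders(), but only
    # forward prefetch_factor when worker processes are actually in use.
    kwargs = dict(shuffle=True, batch_size=batch_size, collate_fn=lambda x: x,
                  worker_init_fn=worker_init_fn, generator=generator,
                  num_workers=num_workers, drop_last=False,
                  pin_memory=pin_memory)
    if num_workers > 0:
        kwargs['prefetch_factor'] = prefetch_factor
    return DataLoader(dataset, **kwargs)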