From 4e6d566cab7536c8761709217a487753c83d0529 Mon Sep 17 00:00:00 2001
From: Ming Du
Date: Sat, 11 May 2024 10:55:26 -0500
Subject: [PATCH] Allow evaluating the test set after each epoch

---
 generic_trainer/configs.py |  5 +++++
 generic_trainer/trainer.py | 55 +++++++++++++++++++++++++++++++++++++++
 2 files changed, 60 insertions(+)

diff --git a/generic_trainer/configs.py b/generic_trainer/configs.py
index d0cebe7..26aa5db 100644
--- a/generic_trainer/configs.py
+++ b/generic_trainer/configs.py
@@ -179,6 +179,11 @@ class TrainingConfig(Config):
     validation_dataset: Optional[Dataset] = None
     """The validation dataset. See the docstring of `training_dataset` for more details."""
 
+    test_dataset: Optional[Dataset] = None
+    """
+    The test dataset. It has no influence on training; it only provides a way to check test performance after each epoch.
+    """
+
     batch_size_per_process: int = 64
     """
     The batch size per process. With this value denoted by `n_bspp`, the trainer behaves as the following:
diff --git a/generic_trainer/trainer.py b/generic_trainer/trainer.py
index 677dec9..2136d27 100644
--- a/generic_trainer/trainer.py
+++ b/generic_trainer/trainer.py
@@ -72,6 +72,7 @@ def __init__(self, pred_names=('cs', 'eg', 'sg'), *args, **kwargs):
         self['loss'] = []
         self['val_loss'] = []
         self['best_val_loss'] = np.inf
+        self['test_loss'] = []
         self['lrs'] = []
         self['epoch_best_val_loss'] = 0
         self.current_epoch = 0
@@ -80,8 +81,10 @@ def __init__(self, pred_names=('cs', 'eg', 'sg'), *args, **kwargs):
             self['loss_{}'.format(pred_name)] = []
             self['val_loss_{}'.format(pred_name)] = []
             self['best_val_loss_{}'.format(pred_name)] = np.inf
+            self['test_loss_{}'.format(pred_name)] = []
             self['train_acc_{}'.format(pred_name)] = []
             self['val_acc_{}'.format(pred_name)] = []
+            self['test_acc_{}'.format(pred_name)] = []
             self['classification_preds_{}'.format(pred_name)] = []
             self['classification_labels_{}'.format(pred_name)] = []
 
@@ -264,6 +267,7 @@ def __init__(self, configs: Union[TrainingConfig, Config], rank=None, num_proces
         self.dataset = self.configs.dataset
         self.training_dataset = None
         self.validation_dataset = None
+        self.test_dataset = None
         self.validation_ratio = self.configs.validation_ratio
         self.model = None
         self.model_params = None
@@ -272,6 +276,8 @@ def __init__(self, configs: Union[TrainingConfig, Config], rank=None, num_proces
         self.training_dataloader = None
         self.validation_sampler = None
         self.validation_dataloader = None
+        self.test_sampler = None
+        self.test_dataloader = None
         self.num_local_devices = self.get_num_local_devices()
         self.num_processes = num_processes
         self.rank = rank
@@ -460,6 +466,19 @@ def build_dataloaders(self):
                 sampler=self.validation_sampler,
                 collate_fn=lambda x: x
             )
+            if self.configs.test_dataset is not None:
+                self.test_sampler = torch.utils.data.distributed.DistributedSampler(
+                    self.test_dataset,
+                    num_replicas=self.num_processes,
+                    rank=self.rank,
+                    drop_last=False
+                )
+                self.test_dataloader = DistributedDataLoader(
+                    self.test_dataset,
+                    batch_size=self.configs.batch_size_per_process,
+                    sampler=self.test_sampler,
+                    collate_fn=lambda x: x
+                )
         else:
             # ALCF documentation mentions that there is a bug in Pytorch's multithreaded data loaders with
             # distributed training across multiple nodes. Therefore, `num_workers` is set to 0. See also:
@@ -474,6 +493,12 @@ def build_dataloaders(self):
                                                       collate_fn=lambda x: x, worker_init_fn=self.get_worker_seed_func(),
                                                       generator=self.get_dataloader_generator(), num_workers=0,
                                                       drop_last=False)
+            if self.configs.test_dataset is not None:
+                self.test_dataloader = DataLoader(self.test_dataset, shuffle=True,
+                                                  batch_size=self.all_proc_batch_size,
+                                                  collate_fn=lambda x: x, worker_init_fn=self.get_worker_seed_func(),
+                                                  generator=self.get_dataloader_generator(), num_workers=0,
+                                                  drop_last=False)
 
     def run_training(self):
         for self.current_epoch in range(self.current_epoch, self.num_epochs):
@@ -487,6 +512,9 @@ def run_training(self):
 
             self.model.eval()
             self.run_validation()
+            if self.test_dataset is not None:
+                self.run_test()
+
             if self.verbose and self.rank == 0:
                 self.loss_tracker.print_losses()
             self.write_training_info()
@@ -694,6 +722,31 @@ def run_validation(self):
         if self.configs.post_validation_epoch_hook is not None:
             self.configs.post_validation_epoch_hook()
 
+    def run_test(self):
+        losses = self.get_epoch_loss_buffer()
+        n_batches = 0
+        if self.configs.task_type == 'classification':
+            self.loss_tracker.clear_classification_results_and_labels()
+        for j, data_and_labels in enumerate(self.test_dataloader):
+            losses, _, preds, labels = self.load_data_and_get_loss(data_and_labels, losses)
+            if self.configs.task_type == 'classification':
+                self.loss_tracker.update_classification_results_and_labels(preds, labels)
+            n_batches += 1
+        if n_batches == 0:
+            logging.warning('Test set might be too small; at least 1 rank did not get any test data.')
+        n_batches = np.max([n_batches, 1])
+
+        losses = [self.communicate_value_across_ranks(l / n_batches, mode='average') for l in losses]
+        self.loss_tracker.update_losses(losses, epoch=self.current_epoch, type='test_loss')
+
+        if self.configs.task_type == 'classification':
+            self.loss_tracker.sync_classification_preds_and_labels_across_ranks()
+            acc_dict = self.loss_tracker.calculate_classification_accuracy()
+            self.loss_tracker.update_accuracy_history(acc_dict, 'test')
+
+        if self.configs.post_validation_epoch_hook is not None:
+            self.configs.post_validation_epoch_hook()
+
     def run_model_update_step(self, loss_node):
         self.optimizer.zero_grad()
         self.grad_scaler.scale(loss_node).backward()
@@ -719,6 +772,8 @@ def build_split_datasets(self):
         logging.info('Training set size = {}; validation set size = {}.'.format(
             len(self.training_dataset), len(self.validation_dataset))
         )
+        if self.configs.test_dataset is not None:
+            self.test_dataset = self.configs.test_dataset
 
     def build_optimizer(self):
         if self.configs.pretrained_model_path is not None and self.configs.load_pretrained_encoder_only:
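
Usage note (illustrative; not part of the patch): with this change, supplying a
`test_dataset` in `TrainingConfig` is all that is needed to get a test pass
after every epoch; per-epoch results land in the loss tracker's `test_loss*`
and `test_acc_*` entries. A minimal sketch follows, assuming the trainer class
is named `Trainer` (the class name is not shown in the patch) and eliding the
other config fields (model, optimizer, etc.) a real run would require; the toy
TensorDatasets merely stand in for real data.

    import torch
    from torch.utils.data import TensorDataset

    from generic_trainer.configs import TrainingConfig
    from generic_trainer.trainer import Trainer  # class name assumed

    # Toy datasets standing in for real training and test data.
    train_set = TensorDataset(torch.randn(256, 8), torch.randint(0, 3, (256,)))
    test_set = TensorDataset(torch.randn(64, 8), torch.randint(0, 3, (64,)))

    configs = TrainingConfig(
        dataset=train_set,              # split into train/validation internally
        validation_ratio=0.1,           # fraction of `dataset` held out for validation
        test_dataset=test_set,          # optional; triggers run_test() after each epoch
        batch_size_per_process=64,
    )
    trainer = Trainer(configs)
    trainer.run_training()              # run_test() is invoked automatically per epoch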