diff --git a/generic_trainer/trainer.py b/generic_trainer/trainer.py index 601fee2..9f2987b 100644 --- a/generic_trainer/trainer.py +++ b/generic_trainer/trainer.py @@ -5,6 +5,7 @@ import copy import re import warnings +from contextlib import ExitStack try: from mpi4py import MPI @@ -170,7 +171,7 @@ def clear_classification_results_and_labels(self): def update_classification_results_and_labels(self, preds, labels): """ Update the classification results recorded for the current epoch with the predictions and labels of - the current iteration. + the current iteration.testtr :param preds: list[torch.tensor]. Each tensor should be of shape [n_batch, n_classes]. :param labels: list[torch.tensor]. Each tensor should be of shape [n_batch, n_classes]. @@ -610,7 +611,11 @@ def load_data_and_get_loss(self, data_and_labels, loss_buffer, *args, **kwargs): :return: loss_buffer, total_loss_tensor, preds, labels """ data, labels = self.process_data_loader_yield(data_and_labels) - with torch.autocast(device_type=self.device.type, dtype=torch.float16, enabled=self.use_torch_amp): + with ExitStack() as es: + # torch.autocast would raise an exception about CPU data type even when enabled == False, so + # we use ExitStack to control the entrance of this context on a higher level. + if self.use_torch_amp: + es.enter_context(torch.autocast(device_type=self.device.type, dtype=torch.float16)) preds = self.model(*data) losses, total_loss_tensor = self.compute_losses(loss_buffer, preds, labels) return loss_buffer, total_loss_tensor, preds, labels