Skip to content

Commit

Permalink
Control torch.autocast entrance using ExitStack
Browse files Browse the repository at this point in the history
  • Loading branch information
mdw771 committed Apr 11, 2024
1 parent 62417a3 commit 948c66b
Showing 1 changed file with 7 additions and 2 deletions.
9 changes: 7 additions & 2 deletions generic_trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import copy
import re
import warnings
from contextlib import ExitStack

try:
from mpi4py import MPI
Expand Down Expand Up @@ -170,7 +171,7 @@ def clear_classification_results_and_labels(self):
def update_classification_results_and_labels(self, preds, labels):
"""
Update the classification results recorded for the current epoch with the predictions and labels of
the current iteration.
the current iteration.testtr
:param preds: list[torch.tensor]. Each tensor should be of shape [n_batch, n_classes].
:param labels: list[torch.tensor]. Each tensor should be of shape [n_batch, n_classes].
Expand Down Expand Up @@ -610,7 +611,11 @@ def load_data_and_get_loss(self, data_and_labels, loss_buffer, *args, **kwargs):
:return: loss_buffer, total_loss_tensor, preds, labels
"""
data, labels = self.process_data_loader_yield(data_and_labels)
with torch.autocast(device_type=self.device.type, dtype=torch.float16, enabled=self.use_torch_amp):
with ExitStack() as es:
# torch.autocast would raise an exception about CPU data type even when enabled == False, so
# we use ExitStack to control the entrance of this context on a higher level.
if self.use_torch_amp:
es.enter_context(torch.autocast(device_type=self.device.type, dtype=torch.float16))
preds = self.model(*data)
losses, total_loss_tensor = self.compute_losses(loss_buffer, preds, labels)
return loss_buffer, total_loss_tensor, preds, labels
Expand Down

0 comments on commit 948c66b

Please sign in to comment.