Skip to content

Commit

Permalink
Fix bug in Pretrainer process_data_loader_yield
Browse files — browse the repository at this point in the history
  • Loading branch information
mdw771 committed Mar 30, 2024
1 parent 882c77b commit b9d9cb5
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions generic_trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -955,7 +955,7 @@ def __init__(self, configs: TrainingConfig, rank=None, num_processes=None, *args

def process_data_loader_yield(self, data, **kwargs):
# All that the dataloader yield are supposed to be data. No label.
super().process_data_loader_yield(data, data_label_separation_index=None)
data = super().process_data_loader_yield(data, data_label_separation_index=None)
return data

def compute_losses(self, loss_records, preds, *args, **kwargs):
Expand Down Expand Up @@ -986,7 +986,7 @@ def compute_losses(self, loss_records, preds, *args, **kwargs):

def load_data_and_get_loss(self, data, loss_buffer):
# elements of data are supposed to be 2 different augmentations.
data = self.process_data_loader_yield(data)
data, _ = self.process_data_loader_yield(data)
preds = self.model(*data)
losses, total_loss_tensor = self.compute_losses(loss_buffer, preds)
return loss_buffer, total_loss_tensor, preds, None
Expand Down

0 comments on commit b9d9cb5

Please sign in to comment.