Skip to content

Commit

Permalink
Regularizer funcs are called first with kwargs then positional args
Browse files Browse the repository at this point in the history
  • Loading branch information
mdw771 committed Mar 4, 2024
1 parent dfd25ea commit 33cf694
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion generic_trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -430,7 +430,10 @@ def compute_losses(self, loss_records, preds, labels):
pred_dict = self.get_pred_dict(preds)
for i in range(len(preds), len(self.loss_criterion)):
this_loss_func = self.loss_criterion[i]
this_loss_tensor = this_loss_func(**pred_dict)
try:
this_loss_tensor = this_loss_func(**pred_dict)
except TypeError:
this_loss_tensor = this_loss_func(*preds)
total_loss_tensor = total_loss_tensor + this_loss_tensor
loss_records[i + 1] += this_loss_tensor.detach().item()
loss_records[0] += total_loss_tensor.detach().item()
Expand Down

0 comments on commit 33cf694

Please sign in to comment.