From 33cf6944ad5e632847358e517ad17a096771dd5f Mon Sep 17 00:00:00 2001 From: Ming Du Date: Mon, 4 Mar 2024 15:09:40 -0600 Subject: [PATCH] Regularizer funcs are called first with kwargs then positional args --- generic_trainer/trainer.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/generic_trainer/trainer.py b/generic_trainer/trainer.py index 51fccbb..85cedcc 100644 --- a/generic_trainer/trainer.py +++ b/generic_trainer/trainer.py @@ -430,7 +430,10 @@ def compute_losses(self, loss_records, preds, labels): pred_dict = self.get_pred_dict(preds) for i in range(len(preds), len(self.loss_criterion)): this_loss_func = self.loss_criterion[i] - this_loss_tensor = this_loss_func(**pred_dict) + try: + this_loss_tensor = this_loss_func(**pred_dict) + except TypeError: + this_loss_tensor = this_loss_func(*preds) total_loss_tensor = total_loss_tensor + this_loss_tensor loss_records[i + 1] += this_loss_tensor.detach().item() loss_records[0] += total_loss_tensor.detach().item()