diff --git a/generic_trainer/trainer.py b/generic_trainer/trainer.py index e3020db..9a4ea1c 100644 --- a/generic_trainer/trainer.py +++ b/generic_trainer/trainer.py @@ -592,8 +592,8 @@ def get_pred_dict(self, preds): :return: dict. """ d = {} - for i, name in enumerate(self.configs.pred_names_and_types): - d[name] = preds[i][0] + for i, name_and_type in enumerate(self.configs.pred_names_and_types): + d[name_and_type[0]] = preds[i] return d def process_data_loader_yield_sample_first(self, data_and_labels, data_label_separation_index):