diff --git a/generic_trainer/trainer.py b/generic_trainer/trainer.py index 201419e..fa34db2 100644 --- a/generic_trainer/trainer.py +++ b/generic_trainer/trainer.py @@ -646,7 +646,7 @@ def process_data_loader_yield(self, data_and_labels: Any, data_label_separation_ to be data. :returns tuple[torch.tensor], tuple[torch.tensor]. 2 tuples for data and labels. """ - if isinstance(data_and_labels[0], tuple): + if isinstance(data_and_labels[0], (tuple, list)): # In this case, data_and_labels is in sample-then-item order. data_list, label_list = self.process_data_loader_yield_sample_first(data_and_labels, data_label_separation_index)