From 41ba6430d4512edd5e09f68925efe5026b968790 Mon Sep 17 00:00:00 2001 From: Ming Du Date: Fri, 5 Apr 2024 13:05:09 -0500 Subject: [PATCH] Tester bug fix --- generic_trainer/tester.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/generic_trainer/tester.py b/generic_trainer/tester.py index 3fd36d0..448ca5a 100644 --- a/generic_trainer/tester.py +++ b/generic_trainer/tester.py @@ -17,6 +17,7 @@ def __init__(self, configs: InferenceConfig): self.dataset = None self.sampler = None self.dataloader = None + self.parallelization_type = self.configs.parallelization_params.parallelization_type def build(self): self.build_ranks() @@ -35,7 +36,7 @@ def build_dataset(self): def build_dataloaders(self): if self.parallelization_type == 'multi_node': - self.sampler = torch.utils.data.distributed.DistributedSampler(self.training_dataset, + self.sampler = torch.utils.data.distributed.DistributedSampler(self.dataset, num_replicas=self.num_processes, rank=self.rank, drop_last=False, @@ -45,7 +46,7 @@ def build_dataloaders(self): sampler=self.sampler, collate_fn=lambda x: x) else: - self.dataloader = DataLoader(self.training_dataset, shuffle=False, + self.dataloader = DataLoader(self.dataset, shuffle=False, batch_size=self.all_proc_batch_size, collate_fn=lambda x: x, worker_init_fn=self.get_worker_seed_func(), generator=self.get_dataloader_generator(), num_workers=0,