Skip to content

Commit

Permalink
Tester bug fix
Browse files Browse the repository at this point in the history
  • Loading branch information
Ming Du committed Apr 5, 2024
1 parent c05a8fd commit 41ba643
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions generic_trainer/tester.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ def __init__(self, configs: InferenceConfig):
self.dataset = None
self.sampler = None
self.dataloader = None
self.parallelization_type = self.configs.parallelization_params.parallelization_type

def build(self):
self.build_ranks()
Expand All @@ -35,7 +36,7 @@ def build_dataset(self):

def build_dataloaders(self):
if self.parallelization_type == 'multi_node':
self.sampler = torch.utils.data.distributed.DistributedSampler(self.training_dataset,
self.sampler = torch.utils.data.distributed.DistributedSampler(self.dataset,
num_replicas=self.num_processes,
rank=self.rank,
drop_last=False,
Expand All @@ -45,7 +46,7 @@ def build_dataloaders(self):
sampler=self.sampler,
collate_fn=lambda x: x)
else:
self.dataloader = DataLoader(self.training_dataset, shuffle=False,
self.dataloader = DataLoader(self.dataset, shuffle=False,
batch_size=self.all_proc_batch_size,
collate_fn=lambda x: x, worker_init_fn=self.get_worker_seed_func(),
generator=self.get_dataloader_generator(), num_workers=0,
Expand Down

0 comments on commit 41ba643

Please sign in to comment.