Go back to using DistributedSampler for multi-node mode after fixing bug
mdw771 committed Mar 29, 2024
1 parent 31a3208 commit 882c77b
Showing 1 changed file with 37 additions and 44 deletions.
81 changes: 37 additions & 44 deletions generic_trainer/trainer.py
@@ -411,26 +411,37 @@ def build_scalable_parameters(self):
         self.learning_rate = self.configs.learning_rate_per_process * self.num_processes
 
     def build_dataloaders(self):
-        drop_last = False
-        # Need double check on this.
         if self.parallelization_type == 'multi_node':
-            # PyTorch documentation recommends the use of DistributedSampler for DDP, but I found it yields
-            # identical data for all ranks for some reason, so we just use an ordinary data loader and distribute
-            # data to ranks in process_data_loader_yield().
-            drop_last = True
-        # ALCF documentation mentions that there is a bug in Pytorch's multithreaded data loaders with
-        # distributed training across multiple nodes. Therefore, `num_workers` is set to 0. See also:
-        # https://docs.alcf.anl.gov/polaris/data-science-workflows/frameworks/pytorch/.
-        self.training_dataloader = DataLoader(self.training_dataset, shuffle=True,
-                                              batch_size=self.all_proc_batch_size,
-                                              collate_fn=lambda x: x, worker_init_fn=self.get_worker_seed_func(),
-                                              generator=self.get_dataloader_generator(), num_workers=0,
-                                              drop_last=drop_last)
-        self.validation_dataloader = DataLoader(self.validation_dataset, shuffle=True,
-                                                batch_size=self.all_proc_batch_size,
-                                                collate_fn=lambda x: x, worker_init_fn=self.get_worker_seed_func(),
-                                                generator=self.get_dataloader_generator(), num_workers=0,
-                                                drop_last=drop_last)
+            training_sampler = torch.utils.data.distributed.DistributedSampler(self.training_dataset,
+                                                                               num_replicas=self.num_processes,
+                                                                               rank=self.rank,
+                                                                               drop_last=True)
+            self.training_dataloader = DataLoader(self.training_dataset,
+                                                  batch_size=self.configs.batch_size_per_process,
+                                                  sampler=training_sampler,
+                                                  collate_fn=lambda x: x)
+            validation_sampler = torch.utils.data.distributed.DistributedSampler(self.validation_dataset,
+                                                                                 num_replicas=self.num_processes,
+                                                                                 rank=self.rank,
+                                                                                 drop_last=True)
+            self.validation_dataloader = DataLoader(self.validation_dataset,
+                                                    batch_size=self.all_proc_batch_size,
+                                                    sampler=validation_sampler,
+                                                    collate_fn=lambda x: x)
+        else:
+            # ALCF documentation mentions that there is a bug in Pytorch's multithreaded data loaders with
+            # distributed training across multiple nodes. Therefore, `num_workers` is set to 0. See also:
+            # https://docs.alcf.anl.gov/polaris/data-science-workflows/frameworks/pytorch/.
+            self.training_dataloader = DataLoader(self.training_dataset, shuffle=True,
+                                                  batch_size=self.all_proc_batch_size,
+                                                  collate_fn=lambda x: x, worker_init_fn=self.get_worker_seed_func(),
+                                                  generator=self.get_dataloader_generator(), num_workers=0,
+                                                  drop_last=False)
+            self.validation_dataloader = DataLoader(self.validation_dataset, shuffle=True,
+                                                    batch_size=self.all_proc_batch_size,
+                                                    collate_fn=lambda x: x, worker_init_fn=self.get_worker_seed_func(),
+                                                    generator=self.get_dataloader_generator(), num_workers=0,
+                                                    drop_last=False)
 
     def run_training(self):
         for self.current_epoch in range(self.current_epoch, self.num_epochs):
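
Context for the hunk above: with DistributedSampler, each rank's loader yields only that rank's shard of the dataset, which is why the new code passes the per-rank batch size (configs.batch_size_per_process) directly. Below is a minimal, self-contained sketch (not part of this commit; the toy dataset and names are illustrative) of how the sampler partitions indices across ranks, including the set_epoch() call that per-epoch reshuffling depends on; whether the trainer makes that call is not visible in this hunk.

# Standalone sketch (not from this repo): how DistributedSampler shards a
# dataset across ranks. Passing num_replicas/rank explicitly, as the new
# build_dataloaders() does, works even without an initialized process group.
import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

dataset = TensorDataset(torch.arange(10))

for rank in range(2):
    sampler = DistributedSampler(dataset, num_replicas=2, rank=rank,
                                 shuffle=True, drop_last=True)
    sampler.set_epoch(0)  # reseeds the shuffle; call once per epoch
    loader = DataLoader(dataset, batch_size=5, sampler=sampler,
                        collate_fn=lambda x: x)
    for batch in loader:
        print(rank, [int(sample[0]) for sample in batch])
# Each rank prints a disjoint 5-element subset of the 10 samples.

drop_last=True on the sampler keeps shard sizes equal across ranks when the dataset size is not divisible by the number of processes, which avoids ranks falling out of step in collective operations.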
@@ -546,32 +557,14 @@ def process_data_loader_yield(self, data_and_labels: Any, data_label_separation_
             to be data.
         :returns tuple[torch.tensor], tuple[torch.tensor]. 2 tuples for data and labels.
         """
-        if self.parallelization_type == 'multi_node':
-            bsize_per_rank = self.configs.batch_size_per_process
-            if isinstance(data_and_labels[0], tuple):
-                # In this case, data_and_labels is organized in a sample-then-item order:
-                # it is a tuple of samples. Each element of the tuple is another tuple
-                # containing the data and labels of that sample.
-                data_and_labels = data_and_labels[self.rank * bsize_per_rank:(self.rank + 1) * bsize_per_rank]
-                data_list, label_list = self.process_data_loader_yield_sample_first(data_and_labels,
-                                                                                    data_label_separation_index)
-            else:
-                # In this case, data_and_labels is organized in a item-then-sample order:
-                # it is a tuple of items. Each element of the tuple is a tensor of
-                # [n_total_batch_size, ...].
-                data_and_labels = [d[self.rank * bsize_per_rank:(self.rank + 1) * bsize_per_rank]
-                                   for d in data_and_labels]
-                data_list, label_list = self.process_data_loader_yield_item_first(data_and_labels,
-                                                                                  data_label_separation_index)
-        else:
-            if isinstance(data_and_labels[0], tuple):
-                # In this case, data_and_labels is in sample-then-item order.
-                data_list, label_list = self.process_data_loader_yield_sample_first(data_and_labels,
-                                                                                    data_label_separation_index)
-            else:
-                # In this case, data_and_labels is in item-then-sample order.
-                data_list, label_list = self.process_data_loader_yield_item_first(data_and_labels,
-                                                                                  data_label_separation_index)
+        if isinstance(data_and_labels[0], tuple):
+            # In this case, data_and_labels is in sample-then-item order.
+            data_list, label_list = self.process_data_loader_yield_sample_first(data_and_labels,
+                                                                                data_label_separation_index)
+        else:
+            # In this case, data_and_labels is in item-then-sample order.
+            data_list, label_list = self.process_data_loader_yield_item_first(data_and_labels,
+                                                                              data_label_separation_index)
         return data_list, label_list
 
     def get_epoch_loss_buffer(self):
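
Since each rank's sampler-backed loader now yields only local data, the manual rank-based slicing in the removed multi_node branch became redundant; what remains only distinguishes the two batch layouts that collate_fn=lambda x: x can hand over. A hypothetical standalone sketch of that distinction follows (the *_sample_first/*_item_first helpers themselves are not shown in this diff, so the splitting logic below is an assumption for illustration only).

# Hypothetical illustration (not from this repo) of the two batch layouts.
import torch

# Sample-then-item: a list of samples, each a (data, label) tuple. This is
# what collate_fn=lambda x: x yields from a map-style dataset.
batch_sample_first = [(torch.rand(3), torch.tensor(0)) for _ in range(4)]

# Item-then-sample: one tensor per field, already stacked over the batch,
# as PyTorch's default collate_fn would produce.
batch_item_first = (torch.rand(4, 3), torch.zeros(4, dtype=torch.long))

def split_item_first(batch, separation_index=1):
    # Fields before separation_index are data; the rest are labels.
    return tuple(batch[:separation_index]), tuple(batch[separation_index:])

def split_sample_first(batch, separation_index=1):
    # Transpose sample-major to item-major with zip(*...), stack, then split.
    fields = tuple(torch.stack(items) for items in zip(*batch))
    return split_item_first(fields, separation_index)

data, labels = split_sample_first(batch_sample_first)
print(data[0].shape, labels[0].shape)  # torch.Size([4, 3]) torch.Size([4])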
