Skip to content

Commit

Permalink
Home-made dataloader for HPC training
Browse files Browse the repository at this point in the history
  • Loading branch information
mdw771 committed Apr 11, 2024
1 parent 36493df commit 0503b22
Show file tree
Hide file tree
Showing 2 changed files with 97 additions and 16 deletions.
72 changes: 72 additions & 0 deletions generic_trainer/data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import torch


class DistributedDataLoader:
"""
A dataloader for distributed training on HPCs.
When using the built-in DataLoader from PyTorch on HPCs (mainly Polaris), we noticed that it took long to finish
loading a batch (by calling dataloader.__iter__().__next__()), even though the __getitem__ calls for all samples
in the batch would take just a fraction of that time. ALCF documentation pointed out a bug that requires setting
num_workers = 0, but that didn't help. Also, we found the built-in dataloader not using the __getitems__ method
in the dataset even when it is defined. So we created this minimalistic dataloader with only essential operations
(mainly dataset.__getitems__) to work around that.
"""
def __init__(self, dataset, batch_size, sampler, collate_fn=None):
self.dataset = dataset
self.batch_size = batch_size
self.sampler = sampler
self.i_sample = 0
self.i_batch = 0
self.sampler_iter = iter(self.sampler)
self.iter = self.__iter__()

def __len__(self):
return len(self.sampler) // self.batch_size

def __iter__(self):
return DistributedDataLoaderIterator(self.dataset, self.batch_size, self.sampler)

def __next__(self):
data = next(self.iter)
return data


class DistributedDataLoaderIterator:

def __init__(self, dataset, batch_size, sampler, collate_fn=None):
self.dataset = dataset
self.batch_size = batch_size
self.sampler = sampler
self.i_sample = 0
self.i_batch = 0
self.sampler_iter = iter(self.sampler)

def reset(self):
self.i_sample = 0
self.i_batch = 0

def __len__(self):
return len(self.sampler) // self.batch_size

def __next__(self):
if self.i_batch >= self.__len__():
self.reset()
raise StopIteration

# Get indices for this batch
inds = []
for i in range(self.batch_size):
inds.append(next(self.sampler_iter))
inds = tuple(inds)

self.i_sample += self.batch_size
self.i_batch += 1
try:
data = self.dataset.__getitems__(inds)
except AttributeError:
raw_data = [self.dataset[i] for i in inds]
data = []
for i in range(len(raw_data[0])):
data.append(torch.cat([raw_data[j][i] for j in range(len(raw_data))]))
return data
41 changes: 25 additions & 16 deletions generic_trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from .configs import *
from .util import *
from .compat import *
from .data import *


class MultirankGateKeeper:
Expand Down Expand Up @@ -420,22 +421,30 @@ def build_scalable_parameters(self):

def build_dataloaders(self):
if self.parallelization_type == 'multi_node':
self.training_sampler = torch.utils.data.distributed.DistributedSampler(self.training_dataset,
num_replicas=self.num_processes,
rank=self.rank,
drop_last=True)
self.training_dataloader = DataLoader(self.training_dataset,
batch_size=self.configs.batch_size_per_process,
sampler=self.training_sampler,
collate_fn=lambda x: x)
self.validation_sampler = torch.utils.data.distributed.DistributedSampler(self.validation_dataset,
num_replicas=self.num_processes,
rank=self.rank,
drop_last=True)
self.validation_dataloader = DataLoader(self.validation_dataset,
batch_size=self.all_proc_batch_size,
sampler=self.validation_sampler,
collate_fn=lambda x: x)
self.training_sampler = torch.utils.data.distributed.DistributedSampler(
self.training_dataset,
num_replicas=self.num_processes,
rank=self.rank,
drop_last=True
)
self.training_dataloader = DistributedDataLoader(
self.training_dataset,
batch_size=self.configs.batch_size_per_process,
sampler=self.training_sampler,
collate_fn=lambda x: x
)
self.validation_sampler = torch.utils.data.distributed.DistributedSampler(
self.validation_dataset,
num_replicas=self.num_processes,
rank=self.rank,
drop_last=True
)
self.validation_dataloader = DistributedDataLoader(
self.validation_dataset,
batch_size=self.configs.batch_size_per_process,
sampler=self.validation_sampler,
collate_fn=lambda x: x
)
else:
# ALCF documentation mentions that there is a bug in Pytorch's multithreaded data loaders with
# distributed training across multiple nodes. Therefore, `num_workers` is set to 0. See also:
Expand Down

0 comments on commit 0503b22

Please sign in to comment.