-
Notifications
You must be signed in to change notification settings - Fork 279
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
ghstack-source-id: 421bdd3b507197725e55992b0a11ecd7704da1c0 Pull Request resolved: #836
- Loading branch information
Showing
11 changed files
with
156 additions
and
91 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
# | ||
# Copyright (c) Meta Platforms, Inc. All Rights Reserved. | ||
|
||
import pickle | ||
from abc import ABC, abstractmethod | ||
from typing import Any, Callable, TypeAlias | ||
|
||
from torch.distributed.checkpoint.stateful import Stateful | ||
from torch.utils.data import IterableDataset | ||
from torchdata.stateful_dataloader import StatefulDataLoader | ||
|
||
|
||
class BaseDataLoader(Stateful, ABC): | ||
"""Base class for all dataloaders. | ||
This is used to enforce that all dataloaders have the methods defined in ``Stateful``, | ||
``state_dict()`` and ``load_state_dict()``. | ||
""" | ||
|
||
@abstractmethod | ||
def __iter__(self): | ||
... | ||
|
||
|
||
class ParallelAwareDataloader(StatefulDataLoader, BaseDataLoader): | ||
"""Dataloader that is aware of distributed data parallelism. | ||
This dataloader is used to load data in a distributed data parallel fashion. It also | ||
utilizes ``torchdata.stateful_dataloader.StatefulDataLoader`` to implement the necessary | ||
methods such as ``__iter__``. | ||
Args: | ||
dataset (IterableDataset): The dataset to iterate over. | ||
dp_rank: Data parallelism rank for this dataloader. | ||
dp_world_size: The world size of the data parallelism. | ||
batch_size: The batch size to use for each iteration. | ||
""" | ||
|
||
dp_rank: int | ||
dp_world_size: int | ||
batch_size: int | ||
|
||
def __init__( | ||
self, | ||
dataset: IterableDataset, | ||
dp_rank: int, | ||
dp_world_size: int, | ||
batch_size: int, | ||
): | ||
self.dp_world_size = dp_world_size | ||
self.dp_rank = dp_rank | ||
self.batch_size = batch_size | ||
super().__init__(dataset, batch_size) | ||
self._rank_id = f"dp_rank_{dp_rank}" | ||
|
||
def state_dict(self) -> dict[str, Any]: | ||
# Store state only for dp rank to avoid replicating the same state across other dimensions. | ||
return { | ||
# We don't have to use pickle as DCP will serialize the state_dict. However, | ||
# we have to keep this for backward compatibility. | ||
self._rank_id: pickle.dumps(super().state_dict()), | ||
"world_size": self.dp_world_size, | ||
} | ||
|
||
def load_state_dict(self, state_dict: dict[str, Any]) -> None: | ||
# State being empty is valid. | ||
if not state_dict: | ||
return | ||
|
||
if self._rank_id not in state_dict: | ||
logger.warning( | ||
f"DataLoader state is empty for dp rank {self.dp_rank}, " | ||
"expected key {self._rank_id}" | ||
) | ||
return | ||
|
||
assert self.dp_world_size == state_dict["world_size"], ( | ||
"dp_degree is inconsistent before and after checkpoint, " | ||
"dataloader resharding is not supported yet." | ||
) | ||
# We don't have to use pickle as DCP will serialize the state_dict. However, we have to | ||
# keep this for backward compatibility. | ||
super().load_state_dict(pickle.loads(state_dict[self._rank_id])) | ||
|
||
|
||
DataLoaderBuilder: TypeAlias = Callable[[...], BaseDataLoader] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.