Allow users to customize dataloader
ghstack-source-id: 421bdd3b507197725e55992b0a11ecd7704da1c0
Pull Request resolved: #836
fegin committed Feb 13, 2025
1 parent ab94a99 commit 0f04646
Showing 11 changed files with 156 additions and 91 deletions.
5 changes: 1 addition & 4 deletions scripts/estimate/estimation.py
@@ -16,10 +16,8 @@

from torchtitan import utils
from torchtitan.config_manager import JobConfig
from torchtitan.datasets import build_tokenizer
from torchtitan.float8 import Float8Handler
from torchtitan.logging import init_logger, logger
from torchtitan.models import model_name_to_tokenizer
from torchtitan.optimizer import build_lr_schedulers, build_optimizers
from torchtitan.parallelisms import ParallelDims
from torchtitan.train_spec import get_train_spec
@@ -83,8 +81,7 @@ def estimate_memory(job_config: JobConfig):
model_name = job_config.model.name

# build tokenizer
tokenizer_type = model_name_to_tokenizer[model_name]
tokenizer = build_tokenizer(tokenizer_type, job_config.model.tokenizer_path)
tokenizer = train_spec.tokenizer_cls(job_config.model.tokenizer_path)

train_context = utils.get_train_context(
parallel_dims.loss_parallel_enabled,
12 changes: 5 additions & 7 deletions tests/unit_tests/test_dataset_checkpointing.py
@@ -5,8 +5,7 @@
# LICENSE file in the root directory of this source tree.

import torch
from torchtitan.datasets.hf_datasets import build_hf_data_loader
from torchtitan.datasets.tokenizer import build_tokenizer
from torchtitan.datasets.hf_datasets import build_hf_dataloader


class TestDatasetCheckpointing:
@@ -41,13 +40,12 @@ def test_c4_resumption(self):
def _build_dataloader(
self, dataset_name, dataset_path, batch_size, seq_len, world_size, rank
):
tokenizer = build_tokenizer("tiktoken", "./tests/assets/test_tiktoken.model")
return build_hf_data_loader(
return build_hf_dataloader(
dataset_name=dataset_name,
dataset_path=dataset_path,
tokenizer=tokenizer,
tokenizer_path="./tests/assets/test_tiktoken.model",
batch_size=1,
seq_len=1024,
world_size=4,
rank=0,
dp_world_size=4,
dp_rank=0,
)
3 changes: 3 additions & 0 deletions tests/unit_tests/test_train_spec.py
@@ -10,6 +10,7 @@
import torch
import torch.nn as nn
from torchtitan.config_manager import JobConfig
from torchtitan.datasets import build_hf_dataloader
from torchtitan.models.llama import parallelize_llama, pipeline_llama
from torchtitan.optimizer import (
build_lr_schedulers,
@@ -60,6 +61,7 @@ def test_register_train_spec(self):
pipelining_fn=pipeline_llama,
build_optimizers_fn=build_optimizers,
build_lr_schedulers_fn=build_lr_schedulers,
build_dataloader_fn=build_hf_dataloader,
)
register_train_spec(spec)
new_spec = get_train_spec("fake")
@@ -78,6 +80,7 @@ def test_optim_hook(self):
pipelining_fn=pipeline_llama,
build_optimizers_fn=fake_build_optimizers,
build_lr_schedulers_fn=build_lr_schedulers,
build_dataloader_fn=build_hf_dataloader,
)
register_train_spec(spec)
new_spec = get_train_spec("fake2")
91 changes: 91 additions & 0 deletions torchtitan/dataloader.py
@@ -0,0 +1,91 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
#
# Copyright (c) Meta Platforms, Inc. All Rights Reserved.

import pickle
from abc import ABC, abstractmethod
from typing import Any, Callable, TypeAlias

from torch.distributed.checkpoint.stateful import Stateful
from torch.utils.data import IterableDataset
from torchdata.stateful_dataloader import StatefulDataLoader

from torchtitan.logging import logger


class BaseDataLoader(Stateful, ABC):
"""Base class for all dataloaders.
This is used to enforce that all dataloaders have the methods defined in ``Stateful``,
``state_dict()`` and ``load_state_dict()``.
"""

@abstractmethod
def __iter__(self):
...


class ParallelAwareDataloader(StatefulDataLoader, BaseDataLoader):
"""Dataloader that is aware of distributed data parallelism.
This dataloader is used to load data in a distributed data parallel fashion. It also
utilizes ``torchdata.stateful_dataloader.StatefulDataLoader`` to implement the necessary
methods such as ``__iter__``.
Args:
dataset (IterableDataset): The dataset to iterate over.
dp_rank: Data parallelism rank for this dataloader.
dp_world_size: The world size of the data parallelism.
batch_size: The batch size to use for each iteration.
"""

dp_rank: int
dp_world_size: int
batch_size: int

def __init__(
self,
dataset: IterableDataset,
dp_rank: int,
dp_world_size: int,
batch_size: int,
):
self.dp_world_size = dp_world_size
self.dp_rank = dp_rank
self.batch_size = batch_size
super().__init__(dataset, batch_size)
self._rank_id = f"dp_rank_{dp_rank}"

def state_dict(self) -> dict[str, Any]:
# Store state only for dp rank to avoid replicating the same state across other dimensions.
return {
# We don't have to use pickle as DCP will serialize the state_dict. However,
# we have to keep this for backward compatibility.
self._rank_id: pickle.dumps(super().state_dict()),
"world_size": self.dp_world_size,
}

def load_state_dict(self, state_dict: dict[str, Any]) -> None:
# State being empty is valid.
if not state_dict:
return

if self._rank_id not in state_dict:
            logger.warning(
                f"DataLoader state is empty for dp rank {self.dp_rank}, "
                f"expected key {self._rank_id}"
            )
return

assert self.dp_world_size == state_dict["world_size"], (
"dp_degree is inconsistent before and after checkpoint, "
"dataloader resharding is not supported yet."
)
# We don't have to use pickle as DCP will serialize the state_dict. However, we have to
# keep this for backward compatibility.
super().load_state_dict(pickle.loads(state_dict[self._rank_id]))


DataLoaderBuilder: TypeAlias = Callable[..., BaseDataLoader]
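
For illustration, here is a minimal sketch of a user-defined dataloader written against the ``BaseDataLoader`` interface and the ``DataLoaderBuilder`` alias introduced above. The ``SyntheticDataLoader`` class, its token shapes, and the builder signature are hypothetical and not part of this commit.

# Hypothetical example: a synthetic-data loader that satisfies BaseDataLoader
# and could be supplied to a TrainSpec via ``build_dataloader_fn``.
from typing import Any

import torch

from torchtitan.dataloader import BaseDataLoader


class SyntheticDataLoader(BaseDataLoader):
    """Yields random (input, label) token batches; useful for smoke tests."""

    def __init__(self, batch_size: int, seq_len: int, vocab_size: int = 32000):
        self.batch_size = batch_size
        self.seq_len = seq_len
        self.vocab_size = vocab_size
        self._step = 0

    def __iter__(self):
        while True:
            # Sample seq_len + 1 tokens so inputs and labels are shifted by one.
            tokens = torch.randint(
                0, self.vocab_size, (self.batch_size, self.seq_len + 1)
            )
            self._step += 1
            yield tokens[:, :-1], tokens[:, 1:]

    def state_dict(self) -> dict[str, Any]:
        # Only the step counter needs to survive a checkpoint.
        return {"step": self._step}

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        self._step = state_dict.get("step", 0)


def build_synthetic_dataloader(
    batch_size: int, seq_len: int, **kwargs: Any
) -> BaseDataLoader:
    # Any callable returning a BaseDataLoader matches the DataLoaderBuilder alias.
    return SyntheticDataLoader(batch_size=batch_size, seq_len=seq_len)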
4 changes: 2 additions & 2 deletions torchtitan/datasets/__init__.py
@@ -4,10 +4,10 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from torchtitan.datasets.hf_datasets import build_hf_data_loader
from torchtitan.datasets.hf_datasets import build_hf_dataloader
from torchtitan.datasets.tokenizer import build_tokenizer

__all__ = [
"build_hf_data_loader",
"build_hf_dataloader",
"build_tokenizer",
]
85 changes: 30 additions & 55 deletions torchtitan/datasets/hf_datasets.py
@@ -4,28 +4,28 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import pickle
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional
from typing import Any, Callable, Optional

import torch

from datasets import Dataset, load_dataset
from datasets.distributed import split_dataset_by_node
from torch.distributed.checkpoint.stateful import Stateful
from torch.utils.data import IterableDataset
from torchdata.stateful_dataloader import StatefulDataLoader

from torchtitan.dataloader import ParallelAwareDataloader

from torchtitan.datasets.tokenizer import Tokenizer
from torchtitan.logging import logger

from datasets import Dataset, load_dataset
from datasets.distributed import split_dataset_by_node


def _load_c4_dataset(dataset_path: str):
"""Load C4 dataset with default configuration."""
return load_dataset(dataset_path, name="en", split="train", streaming=True)


def _process_c4_text(sample: Dict[str, Any]) -> str:
def _process_c4_text(sample: dict[str, Any]) -> str:
"""Process C4 dataset sample text."""
return sample["text"]

@@ -75,8 +75,8 @@ def __init__(
dataset_path: Optional[str],
tokenizer: Tokenizer,
seq_len: int = 2048,
world_size: int = 1,
rank: int = 0,
dp_rank: int = 0,
dp_world_size: int = 1,
infinite: bool = False,
) -> None:
# Force lowercase for consistent comparison
@@ -88,15 +88,15 @@ ds = dataset_loader(path)
ds = dataset_loader(path)

self.dataset_name = dataset_name
self._data = split_dataset_by_node(ds, rank, world_size)
self._data = split_dataset_by_node(ds, dp_rank, dp_world_size)
self._tokenizer = tokenizer
self.seq_len = seq_len
self.infinite = infinite
self._text_processor = text_processor

# Variables for checkpointing
self._sample_idx = 0
self._all_tokens: List[int] = []
self._all_tokens: list[int] = []

def _get_data_iter(self):
if self._sample_idx == 0:
@@ -142,56 +142,31 @@ def state_dict(self):
return {"token_buffer": self._all_tokens, "sample_idx": self._sample_idx}


class DPAwareDataLoader(StatefulDataLoader, Stateful):
"""
A wrapper around the StatefulDataLoader that ensures that the state is stored only once per DP rank.
"""

def __init__(
self, dp_rank: int, hf_ds: IterableDataset, batch_size: int, world_size: int
):
super().__init__(hf_ds, batch_size)
self._dp_rank = dp_rank
self._rank_id = f"dp_rank_{dp_rank}"
# Data loader resharding is not yet supported, so we need to store the world size to compare during loading
# raise error if dp_word_size does not match.
self._world_size = world_size

def state_dict(self) -> Dict[str, Any]:
# Store state only for dp rank to avoid replicating the same state across other dimensions
return {
self._rank_id: pickle.dumps(super().state_dict()),
"world_size": self._world_size,
}

def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
# State being empty is valid
if not state_dict:
return

if self._rank_id not in state_dict:
logger.warning(
f"DataLoader state is empty for dp rank {self._dp_rank}, expected key {self._rank_id}"
)
return
assert (
self._world_size == state_dict["world_size"]
), "dp_degree is inconsistent before and after checkpoint, dataloader resharding is not supported yet."
super().load_state_dict(pickle.loads(state_dict[self._rank_id]))


def build_hf_data_loader(
def build_hf_dataloader(
dataset_name: str,
dataset_path: Optional[str],
tokenizer: Tokenizer,
batch_size: int,
seq_len: int,
world_size: int,
rank: int,
dp_rank: int,
dp_world_size: int,
infinite: bool = True,
):
) -> ParallelAwareDataloader:
"""Build a data loader for HuggingFace datasets."""

hf_ds = HuggingFaceDataset(
dataset_name, dataset_path, tokenizer, seq_len, world_size, rank, infinite
dataset_name=dataset_name,
dataset_path=dataset_path,
tokenizer=tokenizer,
seq_len=seq_len,
dp_rank=dp_rank,
dp_world_size=dp_world_size,
infinite=infinite,
)

return ParallelAwareDataloader(
dataset=hf_ds,
dp_rank=dp_rank,
dp_world_size=dp_world_size,
batch_size=batch_size,
)
return DPAwareDataLoader(rank, hf_ds, batch_size=batch_size, world_size=world_size)
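
As a usage sketch of the renamed builder, assuming the ``build_hf_dataloader`` signature shown in this hunk (tokenizer constructed separately and passed in); the tokenizer path, dataset name, and batch sizes below are placeholders, not values mandated by the commit.

# Illustrative call of build_hf_dataloader as defined in this diff.
from torchtitan.datasets import build_hf_dataloader
from torchtitan.datasets.tokenizer import TikTokenizer

tokenizer = TikTokenizer("./tests/assets/test_tiktoken.model")  # placeholder path
dataloader = build_hf_dataloader(
    dataset_name="c4_test",  # placeholder dataset name
    dataset_path="./tests/assets/c4_test",
    tokenizer=tokenizer,
    batch_size=8,
    seq_len=2048,
    dp_rank=0,
    dp_world_size=1,
)
input_ids, labels = next(iter(dataloader))  # each batch is an (input, label) pair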
3 changes: 3 additions & 0 deletions torchtitan/datasets/tokenizer/__init__.py
@@ -10,6 +10,9 @@
from torchtitan.logging import logger


__all__ = ["build_tokenizer", "TikTokenizer"]


def build_tokenizer(tokenizer_type: str, tokenizer_path: str) -> Tokenizer:
logger.info(f"Building {tokenizer_type} tokenizer locally from {tokenizer_path}")
if tokenizer_type == "tiktoken":
3 changes: 0 additions & 3 deletions torchtitan/models/__init__.py
@@ -8,6 +8,3 @@
# Import the built-in models here so that the corresponding register_model_spec()
# will be called.
import torchtitan.models.llama # noqa: F401


model_name_to_tokenizer = {"llama3": "tiktoken"}
4 changes: 4 additions & 0 deletions torchtitan/models/llama/__init__.py
@@ -6,6 +6,8 @@
#
# Copyright (c) Meta Platforms, Inc. All Rights Reserved.

from torchtitan.datasets import build_hf_dataloader
from torchtitan.datasets.tokenizer import TikTokenizer
from torchtitan.models.llama.model import Transformer, TransformerModelArgs
from torchtitan.optimizer import build_lr_schedulers, build_optimizers
from torchtitan.train_spec import register_train_spec, TrainSpec
@@ -65,5 +67,7 @@
pipelining_fn=pipeline_llama,
build_optimizers_fn=build_optimizers,
build_lr_schedulers_fn=build_lr_schedulers,
build_dataloader_fn=build_hf_dataloader,
tokenizer_cls=TikTokenizer,
)
)
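
To show the customization this commit enables, here is a hedged sketch of registering a TrainSpec whose dataloader builder is user-supplied instead of ``build_hf_dataloader``. The spec name, the model config values, and ``build_synthetic_dataloader`` (from the sketch after torchtitan/dataloader.py above) are illustrative, not part of this commit.

# Hypothetical registration of a TrainSpec with a custom dataloader builder.
from torchtitan.datasets.tokenizer import TikTokenizer
from torchtitan.models.llama import parallelize_llama, pipeline_llama
from torchtitan.models.llama.model import Transformer, TransformerModelArgs
from torchtitan.optimizer import build_lr_schedulers, build_optimizers
from torchtitan.train_spec import register_train_spec, TrainSpec

register_train_spec(
    TrainSpec(
        name="llama3_synthetic",  # illustrative spec name
        cls=Transformer,
        config={"debugmodel": TransformerModelArgs(dim=256, n_layers=2, n_heads=8)},
        parallelize_fn=parallelize_llama,
        pipelining_fn=pipeline_llama,
        build_optimizers_fn=build_optimizers,
        build_lr_schedulers_fn=build_lr_schedulers,
        build_dataloader_fn=build_synthetic_dataloader,  # user-supplied builder
        tokenizer_cls=TikTokenizer,
    )
)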
11 changes: 6 additions & 5 deletions torchtitan/train_spec.py
@@ -6,14 +6,15 @@
#
# Copyright (c) Meta Platforms, Inc. All Rights Reserved.


from dataclasses import dataclass
from typing import Callable, Dict, Protocol, Type, TypeAlias
from typing import Callable, Protocol, Type, TypeAlias

import torch.nn as nn
from torch.distributed.pipelining.schedules import _PipelineSchedule

from torchtitan.config_manager import JobConfig
from torchtitan.dataloader import DataLoaderBuilder
from torchtitan.datasets.tokenizer import Tokenizer
from torchtitan.optimizer import LRSchedulersContainer, OptimizersContainer


@@ -53,15 +54,15 @@ def from_model_args(args: BaseModelArgs) -> nn.Module:
class TrainSpec:
name: str
cls: Type[nn.Module]
config: Dict[str, BaseModelArgs]
config: dict[str, BaseModelArgs]
parallelize_fn: Callable[[nn.Module], None]
pipelining_fn: Callable[
[nn.Module], tuple[_PipelineSchedule, list[nn.Module], bool, bool]
]
build_optimizers_fn: OptimizersBuilder
build_lr_schedulers_fn: LRSchedulersBuilder

# TODO: Add a ``build_dataloader_fn``
build_dataloader_fn: DataLoaderBuilder
tokenizer_cls: Type[Tokenizer]

# TODO: Add a FQN convert fn to allow users to load checkpoints from
# HuggingFace or other sources that have different FQN conventions.
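
Finally, a hedged sketch of how a training script could consume the new ``build_dataloader_fn`` and ``tokenizer_cls`` fields. The helper name is hypothetical and the actual train.py wiring is not shown in this diff; the JobConfig fields and the dp rank/degree arguments are assumed to come from the existing config and ParallelDims setup.

# Hypothetical helper; the trainer integration itself is outside this diff.
from torchtitan.config_manager import JobConfig
from torchtitan.dataloader import BaseDataLoader
from torchtitan.train_spec import get_train_spec


def build_dataloader_from_spec(
    job_config: JobConfig, dp_rank: int, dp_degree: int
) -> BaseDataLoader:
    spec = get_train_spec(job_config.model.name)
    # Tokenizer class and dataloader builder now both come from the TrainSpec.
    tokenizer = spec.tokenizer_cls(job_config.model.tokenizer_path)
    return spec.build_dataloader_fn(
        dataset_name=job_config.training.dataset,
        dataset_path=job_config.training.dataset_path,
        tokenizer=tokenizer,
        batch_size=job_config.training.batch_size,
        seq_len=job_config.training.seq_len,
        dp_rank=dp_rank,
        dp_world_size=dp_degree,
    )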