Integrate TorchFT
**Summary**
This is a WIP TorchFT integration PR.

**Current Issues**

This doesn't currently work: groups hang when a new group joins.

**Issue 1:**
~Group 0 and group 1 will hang during the first `should_commit` after
group 1 applies the pending state_dict from group 0.~

Fixed with: pytorch/torchft#83

**Issue 2:**
~Group 0 and group 1 will pass `should_commit`, but group 0 is incorrectly
marked as needing healing, and the healing process causes another hang.~

Fixed with: pytorch/torchft#83

**Issue 3:**
~A byproduct of issues 1 and 2: group 1 will keep printing~
```
[rank0]:devgpu051:76838:80357 [0] misc/socket.cc:50 NCCL WARN
socketProgress: Connection closed by remote peer
devgpu051.cln3.svc.fbinfra.net<33618>
```

Fixed with pytorch/torchft#91 and several other
fixes.

**Issue 4:**
When there are 3 groups, every group requests the state dict on every step.
***How to reproduce?***
Use the `Reproduce steps` below to run 2 groups, then add a third group by
modifying the command (a hypothetical third-group command is sketched after the reproduce steps).

This seems to be fixed but needs more testing.

**Issue 5:**
A hang occurs when using functional collectives.
***How to reproduce?***
Pull the latest version of this PR, then comment out line 41 and uncomment
line 42 in `torchtitan/utils.py`.

**Reproduce steps:**

1. Patch TorchFT with pytorch/torchft#82
2. Start the TorchFT lighthouse server
3. Execute the following command in one terminal:
```
TORCHFT_MANAGER_PORT=29520 REPLICA_GROUP_ID=0 CUDA_VISIBLE_DEVICES=0,1 \
NGPU=2 ./run_llama_train.sh --training.data_parallel_shard_degree=2 \
--experimental.enable_torchft --experimental.ft_replica_group_id=0
```
4. Wait 10 seconds, then execute the following command in another terminal:
```
TORCHFT_MANAGER_PORT=29522 REPLICA_GROUP_ID=1 CUDA_VISIBLE_DEVICES=2,3 \
NGPU=2 ./run_llama_train.sh --training.data_parallel_shard_degree=2 \
--experimental.enable_torchft --experimental.ft_replica_group_id=1
```
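
For issue 4, a third replica group can presumably be launched the same way; the port, GPU indices, and group ID below are assumptions that simply extend the pattern above:
```
TORCHFT_MANAGER_PORT=29524 REPLICA_GROUP_ID=2 CUDA_VISIBLE_DEVICES=4,5 \
NGPU=2 ./run_llama_train.sh --training.data_parallel_shard_degree=2 \
--experimental.enable_torchft --experimental.ft_replica_group_id=2
```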

ghstack-source-id: 1c5c3953695085b0fe9d02458c074bd63666b2e2
Pull Request resolved: #834
fegin committed Feb 13, 2025
1 parent 22d0e70 commit 5b94fcd
Showing 9 changed files with 372 additions and 206 deletions.
4 changes: 4 additions & 0 deletions run_llama_train.sh
@@ -19,7 +19,11 @@ if [ $# -ne 0 ]; then
    overrides="$*"
fi

TORCHFT_MANAGER_PORT=${TORCHFT_MANAGER_PORT:-"29512"}

PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True" \
TORCHFT_LIGHTHOUSE=http://localhost:29510 \
TORCHFT_MANAGER_PORT=${TORCHFT_MANAGER_PORT} \
torchrun --nproc_per_node=${NGPU} --rdzv_backend c10d --rdzv_endpoint="localhost:0" \
--local-ranks-filter ${LOG_RANK} --role rank --tee 3 \
train.py --job.config_file ${CONFIG_FILE} $overrides
395 changes: 206 additions & 189 deletions torchtitan/checkpoint.py

Large diffs are not rendered by default.

24 changes: 24 additions & 0 deletions torchtitan/config_manager.py
@@ -631,6 +631,30 @@ def __init__(self):
action="store_true",
)

self.parser.add_argument(
"--experimental.enable_torchft",
action="store_true",
help="Enable TorchFT integration.",
)

self.parser.add_argument(
"--experimental.ft_replica_group_id",
type=int,
default=-1,
help="The TorchFT replicate group ID of this run.",
)

self.parser.add_argument(
"--experimental.ft_replica_group_size",
type=int,
default=-1,
help="""
The number of TorchFT replicate groups. This number will be used for
dataloader to split the dataset across the replicate groups and FSDP
dimension.
""",
)

def to_dict(self):
return self.args_dict

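For reference, a minimal sketch of the new options in TOML form, assuming the usual torchtitan mapping of `--section.option` flags to `[section]` tables (the values are illustrative, not from the PR):
```
[experimental]
enable_torchft = true
ft_replica_group_id = 0
ft_replica_group_size = 2
```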
52 changes: 52 additions & 0 deletions torchtitan/ft.py
@@ -0,0 +1,52 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import importlib
from dataclasses import dataclass
from typing import Optional

from torchtitan.config_manager import JobConfig

if importlib.util.find_spec("torchft") is not None:
    import torchft as ft

    has_torchft = True
else:
    has_torchft = False


@dataclass
class FTManager:
    manager: ft.Manager
    replicate_group_size: int


def init_ft_manager(job: JobConfig) -> Optional[FTManager]:
    """Initialize the FT manager if TorchFT is enabled.

    Args:
        job (JobConfig): The job configuration.

    Returns:
        Optional[FTManager]: The FT manager if TorchFT is enabled, otherwise None.
    """
    if not job.experimental.enable_torchft:
        return None

    if not has_torchft:
        raise ImportError("torchft is not installed. Please install it.")

    pg = ft.ProcessGroupBabyNCCL()
    manager = ft.Manager(
        pg=pg,
        min_replica_size=1,
        load_state_dict=None,
        state_dict=None,
        use_async_quorum=True,
        replica_id=f"torchtitan_ft_{job.experimental.ft_replica_group_id}",
    )

    return FTManager(manager, job.experimental.ft_replica_group_size)
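
A minimal usage sketch of the new helper, assuming a working torchft install and the `JobConfig()` / `parse_args()` entry-point pattern that `train.py` already uses:
```
from torchtitan.config_manager import JobConfig
from torchtitan.ft import init_ft_manager

# Parse the usual torchtitan CLI/TOML options, including the new
# --experimental.enable_torchft and related flags.
config = JobConfig()
config.parse_args()

# Returns an FTManager (a torchft.Manager plus the replicate group size), or
# None when --experimental.enable_torchft is not set; raises ImportError if the
# flag is set but torchft is not installed.
ft_manager = init_ft_manager(config)
```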
50 changes: 43 additions & 7 deletions torchtitan/optimizer.py
@@ -6,7 +6,7 @@

import copy
import functools
from typing import Any, Callable, Dict, Iterable, List
from typing import Any, Callable, Dict, Iterable, List, Optional

import torch
import torch.nn as nn
@@ -177,8 +177,41 @@ def zero_grad(self) -> None:
        pass


class FTOptimizersContainer(OptimizersContainer):
    def __init__(
        self,
        model_parts: List[nn.Module],
        optimizer_kwargs: Dict[str, Any],
        name: str,
        ft_manager: Any,
    ) -> None:
        import torchft as ft

        super().__init__(model_parts, optimizer_kwargs, name)

        # Force the optimizer state to be initialized so that `optim.step()`
        # won't be called by state_dict() and load_state_dict().
        _ = {
            k: v
            for sd in map(get_optimizer_state_dict, model_parts, self.optimizers)
            for k, v in sd.items()
        }
        self.optimizers = [
            ft.Optimizer(ft_manager.manager, optim) for optim in self.optimizers
        ]
        self.cache_state_dict: Dict[str, Any] = {}

    def init_cache_state_dict(self) -> None:
        self.cache_state_dict = super().state_dict()

    def state_dict(self) -> Dict[str, Any]:
        return self.cache_state_dict


def build_optimizers(
    model_parts: List[nn.Module], job_config: JobConfig
    model_parts: List[nn.Module],
    job_config: JobConfig,
    ft_manager: Optional[Any] = None,
) -> OptimizersContainer:
    """Create a OptimizersContainer for the given model parts and job config.
@@ -213,11 +246,14 @@ def build_optimizers(
        "foreach": not fused,
    }

    return (
        OptimizersContainer(model_parts, optimizer_kwargs, name)
        if not optim_in_bwd
        else OptimizersInBackwardContainer(model_parts, optimizer_kwargs, name)
    )
    if optim_in_bwd and ft_manager:
        raise ValueError("TorchFT is not supported with optimizers in backward.")
    elif optim_in_bwd:
        return OptimizersInBackwardContainer(model_parts, optimizer_kwargs, name)
    elif ft_manager:
        return FTOptimizersContainer(model_parts, optimizer_kwargs, name, ft_manager)
    else:
        return OptimizersContainer(model_parts, optimizer_kwargs, name)


class LRSchedulersContainer(Stateful):
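A hedged sketch of how the new `ft_manager` argument selects the optimizer container; the toy `model_parts` and the use of default `JobConfig` optimizer settings are assumptions for illustration:
```
import torch.nn as nn

from torchtitan.config_manager import JobConfig
from torchtitan.ft import init_ft_manager
from torchtitan.optimizer import build_optimizers

config = JobConfig()
config.parse_args()

model_parts = [nn.Linear(8, 8)]  # stand-in for the parallelized model chunks
ft_manager = init_ft_manager(config)

# ft_manager is None     -> plain OptimizersContainer (unchanged behavior)
# ft_manager is set      -> FTOptimizersContainer, which wraps each optimizer in
#                           torchft.Optimizer so steps are coordinated by the
#                           torchft Manager
# optim-in-backward + FT -> ValueError, the combination is rejected
optimizers = build_optimizers(model_parts, config, ft_manager)
```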
21 changes: 17 additions & 4 deletions torchtitan/parallelisms/parallel_dims.py
@@ -6,6 +6,7 @@

from dataclasses import dataclass
from functools import cached_property
from typing import Any, Optional

from torch.distributed.device_mesh import init_device_mesh

@@ -24,6 +25,7 @@ class ParallelDims:
    pp: int
    world_size: int
    enable_loss_parallel: bool
    ft_manager: Optional[Any]

    def __post_init__(self):
        self._validate()
@@ -56,13 +58,24 @@ def build_mesh(self, device_type):
            [self.pp, self.dp_replicate, self.dp_shard, self.cp, self.tp],
            ["pp", "dp_replicate", "dp_shard", "cp", "tp"],
        ):
            if d > 1:
            if d > 1 or (name == "dp_replicate" and self.ft_manager is not None):
                dims.append(d)
                names.append(name)

        logger.info(f"Building {len(dims)}-D device mesh with {names}, {dims}")
        names = tuple(names)
        mesh = init_device_mesh(device_type, dims, mesh_dim_names=names)
        if self.ft_manager is None:
            mesh = init_device_mesh(device_type, dims, mesh_dim_names=names)
        else:
            from torchft.process_group import ft_init_device_mesh

            mesh = ft_init_device_mesh(
                device_type=device_type,
                mesh_shape=dims,
                mesh_dim_names=names,
                replicate_dim=names.index("dp_replicate"),
                manager=self.ft_manager.manager,
            )

        # Create all the submesh here to ensure all required process groups are
        # initialized:
@@ -73,7 +86,7 @@ def build_mesh(self, device_type):
        # Mesh for loss all-reduce
        dp_cp_mesh_dim_names = []

        if self.dp_replicate_enabled:
        if self.dp_replicate_enabled or self.ft_manager is not None:
            dp_mesh_dim_names.append("dp_replicate")
            dp_cp_mesh_dim_names.append("dp_replicate")
        if self.dp_shard_enabled:
@@ -101,7 +114,7 @@ def dp_enabled(self):

    @property
    def dp_replicate_enabled(self):
        return self.dp_replicate > 1
        return self.dp_replicate > 1 or self.ft_manager is not None

    @property
    def dp_shard_enabled(self):
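A hedged sketch of the FT-aware mesh construction; the degree values are illustrative and the snippet assumes it runs under torchrun with a matching world size:
```
from torchtitan.config_manager import JobConfig
from torchtitan.ft import init_ft_manager
from torchtitan.parallelisms import ParallelDims

config = JobConfig()
config.parse_args()
ft_manager = init_ft_manager(config)  # None unless --experimental.enable_torchft

parallel_dims = ParallelDims(
    dp_replicate=1,
    dp_shard=2,
    cp=1,
    tp=1,
    pp=1,
    world_size=2,
    enable_loss_parallel=True,
    ft_manager=ft_manager,
)
# With ft_manager set, the "dp_replicate" dim is kept even at size 1 and the
# mesh is created via torchft's ft_init_device_mesh; otherwise the regular
# init_device_mesh path is used.
mesh = parallel_dims.build_mesh(device_type="cuda")
```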
18 changes: 18 additions & 0 deletions torchtitan/utils.py
@@ -5,6 +5,7 @@
# LICENSE file in the root directory of this source tree.

import contextlib
import copy
import gc
import importlib
import math
@@ -18,6 +19,7 @@
import torch
import torch.distributed._functional_collectives as funcol
import torch.distributed.distributed_c10d as c10d
import torchft as ft
from torch import distributed as dist
from torch._utils import _get_available_device_type, _get_device_module
from torch.distributed.device_mesh import DeviceMesh
@@ -38,6 +40,12 @@ def get_device_info():


def dist_reduce(x: torch.Tensor, reduceOp: str, mesh: DeviceMesh) -> float:
    if isinstance(mesh, ft.process_group._FlattenDeviceMesh):
        x = funcol.all_reduce(
            x, reduceOp=reduceOp, group=mesh.managed_mesh.replicate_pg
        )
        mesh = mesh.managed_mesh.mesh

    if isinstance(x, DTensor):
        # functional collectives do not support DTensor inputs
        x = x.full_tensor()
@@ -407,6 +415,16 @@ def clip_grad_norm_(
    if isinstance(total_norm, DTensor):
        # Will reach here if any non-PP parallelism is used.
        # If only using PP, total_norm will be a local tensor.
        mesh = total_norm._spec.mesh
        if isinstance(mesh, ft.process_group.ManagedDeviceMesh):
            # The gradients along the replicated dim have already been reduced,
            # so we don't need another reduction before removing the
            # replicate dimension.
            local_tensor = total_norm.to_local()
            placements = list(copy.copy(total_norm._spec.placements))
            placements.pop(mesh.replicate_dim)
            total_norm = DTensor.from_local(local_tensor, mesh.mesh, placements)

        total_norm = total_norm.full_tensor()

    if pp_mesh is not None:
13 changes: 8 additions & 5 deletions train.py
@@ -16,6 +16,7 @@
from torchtitan.checkpoint import CheckpointManager, TrainState
from torchtitan.config_manager import JobConfig
from torchtitan.float8 import Float8Handler
from torchtitan.ft import init_ft_manager
from torchtitan.logging import init_logger, logger
from torchtitan.metrics import build_device_memory_monitor, build_metric_logger
from torchtitan.parallelisms import ParallelDims
@@ -42,6 +43,10 @@ def main(job_config: JobConfig):
    # take control of garbage collection to avoid stragglers
    gc_handler = utils.GarbageCollection(gc_freq=job_config.training.gc_freq)

    device = torch.device(f"{device_type}:{int(os.environ['LOCAL_RANK'])}")
    device_module.set_device(device)
    ft_manager = init_ft_manager(job_config)

    # init distributed
    world_size = int(os.environ["WORLD_SIZE"])
    parallel_dims = ParallelDims(
@@ -52,9 +57,8 @@
        pp=job_config.experimental.pipeline_parallel_degree,
        world_size=world_size,
        enable_loss_parallel=not job_config.training.disable_loss_parallel,
        ft_manager=ft_manager,
    )
    device = torch.device(f"{device_type}:{int(os.environ['LOCAL_RANK'])}")
    device_module.set_device(device)
    utils.init_distributed(job_config)
    # initialize device memory monitor and get peak flops for MFU calculation
    device_memory_monitor = build_device_memory_monitor()
@@ -187,7 +191,7 @@ def loss_fn(pred, labels):
    )

    # build optimizer after applying parallelisms to the model
    optimizers = train_spec.build_optimizers_fn(model_parts, job_config)
    optimizers = train_spec.build_optimizers_fn(model_parts, job_config, ft_manager)
    lr_schedulers = train_spec.build_lr_schedulers_fn(optimizers, job_config)

    train_state = TrainState()
@@ -200,6 +204,7 @@ def loss_fn(pred, labels):
        lr_schedulers=lr_schedulers,
        states={"train_state": train_state},
        job_config=job_config,
        ft_manager=ft_manager,
    )

    if job_config.checkpoint.create_seed_checkpoint:
@@ -240,8 +245,6 @@ def loss_fn(pred, labels):
    time_last_log = time.perf_counter()
    device_memory_monitor.reset_peak_stats()

    checkpoint.reset()

    # train loop
    logger.info(
        f"Training starts at step {train_state.step + 1}, "
1 change: 0 additions & 1 deletion train_configs/llama3_8b.toml
@@ -44,7 +44,6 @@ pipeline_parallel_degree = 1
[checkpoint]
enable_checkpoint = false
folder = "checkpoint"
interval_type = "steps"
interval = 500
model_weights_only = false
export_dtype = "float32"
