Integrate TorchFT
**Summary**
This is a WIP TorchFT integration PR.

**Current Issues**

This doesn't work at the moment: groups hang when a new group joins.

**Issue 1:**
~Group 0 and group 1 will hang during the first `should_commit` after
group 1 applies the pending state_dict from group 0.~

Fixed with: pytorch/torchft#83

**Issue 2:**
~Group 0 and group 1 will pass `should_commit`, but group 0 is incorrectly
marked as needing healing, and the healing process causes another hang.~

Fixed with: pytorch/torchft#83

**Issue 3:**
~A byproduct of issues 1 and 2: group 1 keeps printing~
```
[rank0]:devgpu051:76838:80357 [0] misc/socket.cc:50 NCCL WARN
socketProgress: Connection closed by remote peer
devgpu051.cln3.svc.fbinfra.net<33618>
```

Fixed with pytorch/torchft#91 and several other
fixes.

**Issue 4:**
When there are 3 groups, every group requests the state dict on every step.
***How to reproduce?***
Use the `Reproduce steps` below to run 2 groups, then add another group by
modifying the command (see the example command after the reproduce steps).

This seems to be fixed but needs more testing.

**Issue 5:**
A hang occurs when using functional collectives.
***How to reproduce?***
Pull the latest version of this PR, then comment out line 41 and uncomment
line 42 in `torchtitan/utils.py`.
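
A minimal, hypothetical sketch of what the toggle conceptually looks like (the actual contents of lines 41-42 in `torchtitan/utils.py` are not reproduced here; the helper name, reduce op, and group argument are illustrative):
```
import torch
import torch.distributed as dist
import torch.distributed._functional_collectives as funcol


def dist_max(x: torch.Tensor, group: dist.ProcessGroup) -> float:
    # Functional variant: returns a new tensor instead of reducing in place.
    # Enabling this path is what reportedly reproduces the hang (Issue 5).
    # return funcol.all_reduce(x, reduceOp="max", group=group).item()

    # In-place c10d variant: the path that currently works.
    dist.all_reduce(x, op=dist.ReduceOp.MAX, group=group)
    return x.item()
```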

**Reproduce steps:**

1. Patch TorchFT with pytorch/torchft#82
2. Start the TorchFT lighthouse (the run script expects it at `http://localhost:29510`)
3. Execute the following command in one terminal:
```
TORCHFT_MANAGER_PORT=29520 REPLICA_GROUP_ID=0 CUDA_VISIBLE_DEVICES=0,1
NGPU=2 ./run_llama_train.sh --training.data_parallel_shard_degree=2
--experimental.enable_torchft --experimental.ft_replica_group_id=0
```
4. Wait 10 seconds, then execute the following command in another terminal:
```
TORCHFT_MANAGER_PORT=29522 REPLICA_GROUP_ID=1 CUDA_VISIBLE_DEVICES=2,3
NGPU=2 ./run_llama_train.sh --training.data_parallel_shard_degree=2
--experimental.enable_torchft --experimental.ft_replica_group_id=1
```
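
For Issue 4, a third group can be added along the same lines; the port and GPU indices below are illustrative and simply extend the pattern of the two commands above:
```
TORCHFT_MANAGER_PORT=29524 REPLICA_GROUP_ID=2 CUDA_VISIBLE_DEVICES=4,5
NGPU=2 ./run_llama_train.sh --training.data_parallel_shard_degree=2
--experimental.enable_torchft --experimental.ft_replica_group_id=2
```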

ghstack-source-id: bf6f0c51100cf2a1c2fb25a405b6c7592694b323
Pull Request resolved: #834
fegin committed Feb 12, 2025
1 parent 437b5e7 commit 330744f
Showing 8 changed files with 270 additions and 97 deletions.
4 changes: 4 additions & 0 deletions run_llama_train.sh
@@ -19,7 +19,11 @@ if [ $# -ne 0 ]; then
overrides="$*"
fi

TORCHFT_MANAGER_PORT=${TORCHFT_MANAGER_PORT:-"29512"}

PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True" \
TORCHFT_LIGHTHOUSE=http://localhost:29510 \
TORCHFT_MANAGER_PORT=${TORCHFT_MANAGER_PORT} \
torchrun --nproc_per_node=${NGPU} --rdzv_backend c10d --rdzv_endpoint="localhost:0" \
--local-ranks-filter ${LOG_RANK} --role rank --tee 3 \
train.py --job.config_file ${CONFIG_FILE} $overrides
202 changes: 118 additions & 84 deletions torchtitan/checkpoint.py
@@ -19,6 +19,7 @@
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
import torch.nn as nn
from torch.distributed._state_dict_utils import _copy_state_dict, _create_cpu_state_dict
from torch.distributed.checkpoint.state_dict import (
get_model_state_dict,
set_model_state_dict,
@@ -144,50 +145,29 @@ def __init__(
lr_schedulers: LRSchedulersContainer,
states: Dict[str, Any],
job_config: JobConfig,
ft_manager: Optional[Any] = None,
) -> None:
ckpt_config = job_config.checkpoint
self.enable_checkpoint = ckpt_config.enable_checkpoint
self.keep_latest_k = ckpt_config.keep_latest_k
self.ft_manager = ft_manager
self.enable_staging = (
self.enable_checkpoint and async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM
) or self.ft_manager

if not self.enable_checkpoint:
if not self.enable_checkpoint and self.ft_manager is None:
return
"""
Note: Pipeline Parallelism and Virtual Stages
1. even for simple PP schedules, there is a separate optimizer each PP rank.
rank0's optimizer would have a param_group[0] which refers to layers.0 in the original model.
rank1's would _also_ have a param_group[0], since it's index based, but referring to layers.1.
When saving, these collide and one of them is lost. Then when reloading, only one stage can
restore its optimizer states, others will error.
The solution to this problem is optimizer flattening: it landed in #127071 and is enabled in TorchTitan
by passing the 'flatten_optimizer_state_dict' kwarg to DCP functions called in the OptimizerContainer.
2. With complex PP schedules, we have multiple model chunks per pp rank. This compounds challenge (1) by also
requiring us to reason about multiple 'optim' objects locally.
We solve this in the Model and Optimizer wrapper classes by flattening the state dicts from each object
into one state dict before saving/loading. We rely on the individual state_dicts to not collide,
which is gauranteed for the model by correct pipeline splitting and for the optimizer by the flattening
support described in (1).
3. LR schedulers also index model states like optimizers and would need to be flattened properly to support
resharding. Unfortunately, the implementations of different lr_schedulers do not follow a clear pattern like
optimizers do, so it's hard to write a generic 'flattener' utility.
TODO: This is currently unsolved and needs a fix.
"""
self.states = states

self.states.update(
{
"model": ModelWrapper(model_parts),
"optimizer": optimizers,
"dataloader": dataloader,
"lr_scheduler": lr_schedulers,
}
self._initialize_states(
states, dataloader, model_parts, optimizers, lr_schedulers
)

async_mode = ckpt_config.async_mode.lower()
self.staging = False
self.sending_to_checkpoint_mp = False
self.staging_id = None
self.cpu_offload_state_dict = None
self.staging_stream = torch.cuda.Stream() if self.enable_staging else None

self.folder = os.path.join(job_config.job.dump_folder, ckpt_config.folder)
self.interval_type = (
IntervalType.SECONDS
@@ -202,6 +182,7 @@ def __init__(
if async_mode == AsyncMode.ASYNC or self.interval_type == IntervalType.SECONDS:
self.pg = dist.new_group(backend="gloo")

self.keep_latest_k = ckpt_config.keep_latest_k
self.model_weights_only = ckpt_config.model_weights_only
self.export_dtype = TORCH_DTYPE_MAP[ckpt_config.export_dtype]

@@ -225,10 +206,6 @@ def __init__(
daemon=True,
)
self.mp.start()
self.cpu_offload_state_dict = None
self.staging = False
self.staging_id = None
self.staging_stream = torch.cuda.Stream()
else:
raise ValueError(f"Unknown checkpoint async_mode {ckpt_config.async_mode}")

@@ -242,8 +219,61 @@ def __del__(self):
self.mp.join()

def reset(self) -> None:
# We need to stage the local state in case another replica joins during the
# first step.
if self.ft_manager:
self.cpu_staging(None)
self.begin_time = time.monotonic()

def _initialize_states(
self,
states: Dict[str, Any],
dataloader: DataLoader,
model_parts: List[nn.Module],
optimizers: OptimizersContainer,
lr_schedulers: LRSchedulersContainer,
) -> None:
"""
Note: Pipeline Parallelism and Virtual Stages
1. Even for simple PP schedules, there is a separate optimizer for each PP rank.
rank0's optimizer would have a param_group[0] which refers to layers.0 in the
original model. rank1's would _also_ have a param_group[0], since it's index based,
but referring to layers.1.
When saving, these collide and one of them is lost. Then when reloading, only one
stage can restore its optimizer states, others will error.
The solution to this problem is optimizer flattening: it landed in #127071
and is enabled in TorchTitan by passing the 'flatten_optimizer_state_dict'
kwarg to DCP functions called in the OptimizerContainer.
2. With complex PP schedules, we have multiple model chunks per pp rank. This
compounds challenge (1) by also requiring us to reason about multiple 'optim'
objects locally.
We solve this in the Model and Optimizer wrapper classes by flattening the
state dicts from each object into one state dict before saving/loading.
We rely on the individual state_dicts to not collide, which is guaranteed for
the model by correct pipeline splitting and for the optimizer by the flattening
support described in (1).
3. LR schedulers also index model states like optimizers and would need to be
flattened properly to support resharding. Unfortunately, the implementations of
different lr_schedulers do not follow a clear pattern like optimizers do, so it's
hard to write a generic 'flattener' utility.
TODO: This is currently unsolved and needs a fix.
"""
self.states = states
self.states.update(
{
"model": ModelWrapper(model_parts),
"optimizer": optimizers,
"dataloader": dataloader,
"lr_scheduler": lr_schedulers,
}
)

def _create_checkpoint_id(self, step: int) -> str:
return os.path.join(self.folder, f"step-{step}")

@@ -326,31 +356,8 @@ def _async_wait(self) -> None:
self.async_future.result()

def _async_with_pinned_memory(self, checkpoint_id: str) -> None:
try:
from torch.distributed._state_dict_utils import (
_copy_state_dict,
_create_cpu_state_dict,
)
except ImportError as e:
raise ImportError(
"Please install the latest PyTorch nightly to use async checkpointing with pinned memory."
) from e
state_dict = dcp.state_dict_saver._stateful_to_state_dict(self.states)
if self.cpu_offload_state_dict is None:
logger.debug(f"Preparing the CPU memory, {time.monotonic()=}.:.2f")
self.cpu_offload_state_dict = _create_cpu_state_dict(
state_dict, pin_memory=True, share_memory=True
)

logger.debug(f"Staging the state_dict, {time.monotonic()=}.:.2f")
with torch.cuda.stream(self.staging_stream):
self.cpu_offload_state_dict = _copy_state_dict(
state_dict,
self.cpu_offload_state_dict,
non_blocking=True,
)
self.staging = True
self.staging_id = checkpoint_id
self.cpu_staging(checkpoint_id)
self.sending_to_checkpoint_mp = True

def save(self, curr_step: int, force: bool = False) -> None:
"""
@@ -360,6 +367,8 @@ def save(self, curr_step: int, force: bool = False) -> None:
for initial seed checkpoint.
"""
if not self._should_save(curr_step, force):
if self.ft_manager:
self.cpu_staging(None)
return

begin = time.monotonic()
@@ -383,26 +392,51 @@
f"in {time.monotonic() - begin:.2f} seconds."
)

def cpu_staging(self, checkpoint_id: Optional[str]) -> None:
"""Offload state_dict to CPU memory"""
state_dict = dcp.state_dict_saver._stateful_to_state_dict(self.states)
if self.cpu_offload_state_dict is None:
logger.debug(f"Preparing the CPU memory, {time.monotonic()=:.2f}")
self.cpu_offload_state_dict = _create_cpu_state_dict(
state_dict, pin_memory=True, share_memory=True
)

logger.debug(f"Staging the state_dict, {time.monotonic()=:.2f}")
with torch.cuda.stream(self.staging_stream):
self.cpu_offload_state_dict = _copy_state_dict(
state_dict,
self.cpu_offload_state_dict,
non_blocking=True,
)
self.staging = True
self.staging_id = checkpoint_id

def wait_for_staging(self) -> None:
if not self.staging_stream.query():
self.staging_stream.synchronize()
self.staging = False

def staging_results(self) -> Dict[str, Any]:
self.maybe_wait_for_staging()
return self.cpu_offload_state_dict

def maybe_wait_for_staging(self) -> None:
if (
self.enable_checkpoint
and self.async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM
and self.staging
):
if not self.staging_stream.query():
self.staging_stream.synchronize()

def sync_func():
self.mp_queue_send.put_nowait(
(self.cpu_offload_state_dict, self.staging_id)
)

# This may be a faster way to do zero-overhead checkpointing staging
# checkpointing but we need more thorough investigation before
# swithing to this method.
# self.my_thread = threading.Thread(target=func).start()
sync_func()
self.staging = False
if self.enable_staging and self.staging:
self.wait_for_staging()

if self.sending_to_checkpoint_mp:
# Copy the sync staging result to another process.
def sync_func():
self.mp_queue_send.put_nowait(
(self.cpu_offload_state_dict, self.staging_id)
)

# This may be a faster way to do zero-overhead checkpoint staging,
# but we need more thorough investigation before switching to this method.
# self.my_thread = threading.Thread(target=sync_func).start()
sync_func()
self.sending_to_checkpoint_mp = False

def load(self, step: int = -1) -> bool:
if not self.enable_checkpoint:
13 changes: 13 additions & 0 deletions torchtitan/config_manager.py
@@ -620,6 +620,19 @@ def __init__(self):
action="store_true",
)

self.parser.add_argument(
"--experimental.enable_torchft",
action="store_true",
help="Enable TorchFT integration.",
)

self.parser.add_argument(
"--experimental.ft_replica_group_id",
type=int,
default=-1,
help="The FT replica group ID of this run.",
)

def to_dict(self):
return self.args_dict

58 changes: 58 additions & 0 deletions torchtitan/ft.py
@@ -0,0 +1,58 @@
import importlib
from typing import Any, Callable, Optional

from torch.distributed._state_dict_utils import _copy_state_dict, _create_cpu_state_dict

from torchtitan.config_manager import JobConfig

if importlib.util.find_spec("torchft") is not None:
import torchft as ft

has_torchft = True
else:
has_torchft = False


def init_ft_manager(job: JobConfig) -> Optional["ft.Manager"]:
"""
Initialize the FT manager for the given job.
"""
if not job.experimental.enable_torchft:
return None

if not has_torchft:
raise ImportError("torchft is not installed. Please install it.")

pg = ft.ProcessGroupBabyNCCL()
manager = ft.Manager(
pg=pg,
min_replica_size=1,
load_state_dict=None,
state_dict=None,
use_async_quorum=True,
replica_id=f"torchtitan_ft_{job.experimental.ft_replica_group_id}",
)

return manager


def set_ft_state_dict_fns(manager: Optional["ft.Manager"], ckpt_manager) -> None:
"""
Set the state dict for the given manager.
"""
if manager is None:
return

def state_dict():
ret = {}
for k, v in ckpt_manager.staging_results().items():
if k in {"model", "optimizer", "lr_schedulers"}:
ret[k] = v
return ret

def load_state_dict(state_dict):
assert state_dict is not None
for k, v in state_dict.items():
ckpt_manager.states[k].load_state_dict(v)

manager.set_state_dict_fns(load_state_dict, state_dict)
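
For context, a minimal sketch of how these helpers are presumably wired into the training script (the train.py changes are not shown in this excerpt; the class name `CheckpointManager`, the `build_checkpointing` wrapper, and the `train_state` key are assumptions based on the signatures in this diff):
```
from torchtitan.checkpoint import CheckpointManager
from torchtitan.ft import init_ft_manager, set_ft_state_dict_fns


def build_checkpointing(
    job_config, dataloader, model_parts, optimizers, lr_schedulers, train_state
):
    # Returns None unless --experimental.enable_torchft is set.
    ft_manager = init_ft_manager(job_config)

    # With ft_manager set, the checkpoint manager keeps a staged CPU copy of the
    # training state even when regular checkpointing is disabled.
    checkpoint_manager = CheckpointManager(
        dataloader=dataloader,
        model_parts=model_parts,
        optimizers=optimizers,
        lr_schedulers=lr_schedulers,
        states={"train_state": train_state},
        job_config=job_config,
        ft_manager=ft_manager,
    )

    # Let TorchFT serve the staged state_dict to a joining replica and load a
    # received state_dict back into the local training states.
    set_ft_state_dict_fns(ft_manager, checkpoint_manager)
    return ft_manager, checkpoint_manager
```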