Update on "[WIP][RFC] TorchFT integration"
**Summary**
This is a WIP TorchFT integration PR.

**Current Issues**

This doesn't work at the moment, as groups hang when a new group joins.

**Issue 1:**
~Group 0 and group 1 hang during the first `should_commit` after group 1 applies the pending state_dict from group 0.~

Fixed with: pytorch/torchft#83

**Issue 2:**
~Group 0 and group 1 pass `should_commit`, but group 0 is incorrectly marked as needing healing, and the healing process causes another hang.~

Fixed with: pytorch/torchft#83

**Issue 3:**
~A byproduct of issues 1 and 2: group 1 keeps printing~
```
[rank0]:devgpu051:76838:80357 [0] misc/socket.cc:50 NCCL WARN socketProgress: Connection closed by remote peer devgpu051.cln3.svc.fbinfra.net<33618>
```

Fixed with pytorch/torchft#91 and several other fixes.

**Issue 4:**
When there are 3 groups, every group requests the state dict on every step.
***How to reproduce?***
Use the `Reproduce steps` below to run 2 groups, then add another group by modifying the command (see the example command after the reproduce steps).

This seems to be fixed but needs more testing.

**Issue 5:**
A hang occurs when functional collectives are used.
***How to reproduce?***
Pull the latest version of this PR, then comment out line 41 and uncomment line 42 in `torchtitan/utils.py`.
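
For context on what "functional collective" refers to here, a minimal sketch of the two all-reduce styles is shown below. This does not reproduce the actual lines toggled in `torchtitan/utils.py`; the helper names are illustrative only.
```python
# Illustrative sketch only; not the exact code in torchtitan/utils.py.
import torch
import torch.distributed as dist
import torch.distributed._functional_collectives as funcol


def all_reduce_avg_inplace(x: torch.Tensor, group) -> torch.Tensor:
    # c10d collective: reduces x in place across the group.
    dist.all_reduce(x, op=dist.ReduceOp.AVG, group=group)
    return x


def all_reduce_avg_functional(x: torch.Tensor, group) -> torch.Tensor:
    # Functional collective: returns a new tensor instead of mutating x.
    return funcol.all_reduce(x, reduceOp="avg", group=group)
```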


**Reproduce steps:**

1. Patch TorchFT with pytorch/torchft#82
2. Start the TorchFT lighthouse
3. Execute the following command in one terminal:
```
TORCHFT_MANAGER_PORT=29520 REPLICA_GROUP_ID=0 CUDA_VISIBLE_DEVICES=0,1 NGPU=2 ./run_llama_train.sh --training.data_parallel_shard_degree=2 --experimental.enable_torchft --experimental.ft_replica_group_id=0
```
4. Wait 10 seconds, then execute the following command in another terminal:
```
TORCHFT_MANAGER_PORT=29522 REPLICA_GROUP_ID=1 CUDA_VISIBLE_DEVICES=2,3 NGPU=2 ./run_llama_train.sh --training.data_parallel_shard_degree=2 --experimental.enable_torchft --experimental.ft_replica_group_id=1
```
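
For Issue 4 above, a third group can presumably be added by running the same command again with the manager port, GPU list, and replica group id bumped. The exact values below are assumptions, not tested settings:
```
TORCHFT_MANAGER_PORT=29524 REPLICA_GROUP_ID=2 CUDA_VISIBLE_DEVICES=4,5 NGPU=2 ./run_llama_train.sh --training.data_parallel_shard_degree=2 --experimental.enable_torchft --experimental.ft_replica_group_id=2
```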



[ghstack-poisoned]
fegin committed Feb 3, 2025
2 parents ae425b8 + d95dff9 commit dfcb08d
Showing 2 changed files with 4 additions and 39 deletions.
38 changes: 1 addition & 37 deletions torchtitan/checkpoint.py
```diff
@@ -156,41 +156,6 @@ def __init__(
         if not self.enable_checkpoint and self.ft_manager is None:
             return

-<<<<<<< HEAD
-        1. even for simple PP schedules, there is a separate optimizer for each PP rank.
-        rank0's optimizer would have a param_group[0] which refers to layers.0 in the original model.
-        rank1's would _also_ have a param_group[0], since it's index based, but referring to layers.1.
-        When saving, these collide and one of them is lost. Then when reloading, only one stage can
-        restore its optimizer states, others will error.
-
-        The solution to this problem is optimizer flattening: it landed in #127071 and is enabled in TorchTitan
-        by passing the 'flatten_optimizer_state_dict' kwarg to DCP functions called in the OptimizerContainer.
-
-        2. With complex PP schedules, we have multiple model chunks per pp rank. This compounds challenge (1) by also
-        requiring us to reason about multiple 'optim' objects locally.
-
-        We solve this in the Model and Optimizer wrapper classes by flattening the state dicts from each object
-        into one state dict before saving/loading. We rely on the individual state_dicts to not collide,
-        which is guaranteed for the model by correct pipeline splitting and for the optimizer by the flattening
-        support described in (1).
-
-        3. LR schedulers also index model states like optimizers and would need to be flattened properly to support
-        resharding. Unfortunately, the implementations of different lr_schedulers do not follow a clear pattern like
-        optimizers do, so it's hard to write a generic 'flattener' utility.
-
-        TODO: This is currently unsolved and needs a fix.
-        """
-        self.states = states
-        self.states.update(
-            {
-                "model": ModelWrapper(model_parts),
-                "optimizer": optimizers,
-                "dataloader": dataloader,
-                "lr_scheduler": lr_schedulers,
-            }
-        )
-=======
         self._initialize_states(
             states, dataloader, model_parts, optimizers, lr_schedulers
         )
@@ -201,7 +166,6 @@ def __init__(
         self.staging_id = None
         self.cpu_offload_state_dict = None
         self.staging_stream = torch.cuda.Stream() if self.enable_staging else None
->>>>>>> 3430d99 ([WIP][RFC] TorchFT integration)

         self.folder = os.path.join(job_config.job.dump_folder, ckpt_config.folder)
         self.interval_type = (
@@ -305,9 +269,9 @@ def _initialize_states(
                 "model": ModelWrapper(model_parts),
                 "optimizer": optimizers,
                 "dataloader": dataloader,
-                "lr_scheduler": lr_schedulers,
             }
         )
+        self.states.update(lr_schedulers.get_lr_scheduler_state())

     def _create_checkpoint_id(self, step: int) -> str:
         return os.path.join(self.folder, f"step-{step}")
```
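For readability, the state registration in `_initialize_states` presumably ends up along the following lines after this change. This is a sketch reconstructed from the hunk above, not the verbatim contents of `torchtitan/checkpoint.py`:
```python
# Sketch reconstructed from the diff hunk above (assumption, not verbatim).
def _initialize_states(self, states, dataloader, model_parts, optimizers, lr_schedulers):
    self.states = states
    self.states.update(
        {
            "model": ModelWrapper(model_parts),
            "optimizer": optimizers,
            "dataloader": dataloader,
        }
    )
    # LR scheduler states are now merged in via get_lr_scheduler_state()
    # rather than stored under a single "lr_scheduler" key.
    self.states.update(lr_schedulers.get_lr_scheduler_state())
```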
5 changes: 3 additions & 2 deletions torchtitan/utils.py
```diff
@@ -408,9 +408,10 @@ def clip_grad_norm_(
         # If only using PP, total_norm will be a local tensor.
         mesh = total_norm._spec.mesh
         if isinstance(mesh, ft.process_group.ManagedDeviceMesh):
+            # The gradients along the replicated dim have already been reduced,
+            # so we don't need another reduction before removing the
+            # replicate dimension.
-            local_tensor = total_norm.to_local()
-            dist.all_reduce(local_tensor, op=dist.ReduceOp.AVG, group=mesh.replicate_pg)

             placements = list(copy.copy(total_norm._spec.placements))
             placements.pop(mesh.replicate_dim)
             mesh = mesh.mesh
```
