-
Notifications
You must be signed in to change notification settings - Fork 276
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
**Summary** This is a WIP TorchFT integration PR. **Current Issues** This doesn't work at this moment as there are hanged groups when a new group joins. **Issue 1:** ~Group 0 and group 1 will hang during the first `should_commit` after group 1 applying the pending state_dict from group 0.~ Fixed with: pytorch/torchft#83 **Issue 2:** ~Group 0 and group 1 will pass the `should_commit` but group 0 needs healing which is wrong and the healing process will cause another hang.~ Fixed with: pytorch/torchft#83 **Issue 3:** ~The byproduct of issue 1 and issue 2: group 1 will continue to print out~ ``` [rank0]:devgpu051:76838:80357 [0] misc/socket.cc:50 NCCL WARN socketProgress: Connection closed by remote peer devgpu051.cln3.svc.fbinfra.net<33618> ``` Fixed with pytorch/torchft#91 and several other fixes. **Issue 4:** When there are 3 groups, everyone requests the state dict every step. ***How to reproduce?*** Using the `Reproduce steps` to run 2 groups, then add another group by modifying the command. Seems to be fixed, will need more tests. **Issue 5:** Hang will happen if using functional collective. ***How to reproduce?*** Pull the latest version of this PR and comment out line 41 and uncomment line 42 in `torchtitan/utils.py` **Reproduce steps:** 1. Patch TorchFT with pytorch/torchft#82 2. Execute lighthouse 3. Execute the following command in one terminal: ``` TORCHFT_MANAGER_PORT=29520 REPLICA_GROUP_ID=0 CUDA_VISIBLE_DEVICES=0,1 NGPU=2 ./run_llama_train.sh --training.data_parallel_shard_degree=2 --experimental.enable_torchft --experimental.ft_replica_group_id=0 ``` 4. Wait 10 seconds, execute following command in another terminal: ``` TORCHFT_MANAGER_PORT=29522 REPLICA_GROUP_ID=1 CUDA_VISIBLE_DEVICES=2,3 NGPU=2 ./run_llama_train.sh --training.data_parallel_shard_degree=2 --experimental.enable_torchft --experimental.ft_replica_group_id=1 ``` ghstack-source-id: 11e74f8fafa225d64376e317de39b362a8013471 Pull Request resolved: #834
- Loading branch information
Showing
9 changed files
with
350 additions
and
187 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
import importlib | ||
from dataclasses import dataclass | ||
from typing import Optional | ||
|
||
from torchtitan.config_manager import JobConfig | ||
|
||
if importlib.util.find_spec("torchft") is not None: | ||
import torchft as ft | ||
|
||
has_torchft = True | ||
else: | ||
has_torchft = False | ||
|
||
|
||
@dataclass | ||
class FTManager: | ||
manager: ft.Manager | ||
replicate_group_size: int | ||
|
||
|
||
def init_ft_manager(job: JobConfig) -> Optional[FTManager]: | ||
""" | ||
Initialize the FT manager for the given job. | ||
""" | ||
if not job.experimental.enable_torchft: | ||
return None | ||
|
||
if not has_torchft: | ||
raise ImportError("torchft is not installed. Please install it.") | ||
|
||
pg = ft.ProcessGroupBabyNCCL() | ||
manager = ft.Manager( | ||
pg=pg, | ||
min_replica_size=1, | ||
load_state_dict=None, | ||
state_dict=None, | ||
use_async_quorum=True, | ||
replica_id=f"torchtitan_ft_{job.experimental.ft_replica_group_id}", | ||
) | ||
|
||
return FTManager(manager, job.experimental.ft_replica_group_size) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters