From 43a386d835e05a6513ea87fe15e69cbe6d2005ff Mon Sep 17 00:00:00 2001
From: Chien-Chin Huang
Date: Thu, 13 Feb 2025 15:11:57 -0800
Subject: [PATCH] Integrate TorchFT

**Summary**
This is a WIP TorchFT integration PR.

**Current Issues**

This doesn't work at the moment because groups hang when a new group joins.

**Issue 1:**
~Group 0 and group 1 hang during the first `should_commit` after group 1 applies the pending state_dict from group 0.~
Fixed with: https://github.com/pytorch/torchft/pull/83

**Issue 2:**
~Group 0 and group 1 pass `should_commit`, but group 0 is incorrectly flagged as needing healing, and the healing process causes another hang.~
Fixed with: https://github.com/pytorch/torchft/pull/83

**Issue 3:**
~A byproduct of issues 1 and 2: group 1 will continue to print out~
```
[rank0]:devgpu051:76838:80357 [0] misc/socket.cc:50 NCCL WARN socketProgress: Connection closed by remote peer devgpu051.cln3.svc.fbinfra.net<33618>
```
Fixed with https://github.com/pytorch/torchft/pull/91 and several other fixes.

**Issue 4:**
When there are 3 groups, every group requests the state dict on every step.
***How to reproduce?***
Use the `Reproduce steps` below to run 2 groups, then add another group by modifying the command.
This seems to be fixed but needs more testing.

**Issue 5:**
A hang occurs when using functional collectives.
***How to reproduce?***
Pull the latest version of this PR, comment out line 41, and uncomment line 42 in `torchtitan/utils.py`.

**Reproduce steps:**

1. Patch TorchFT with https://github.com/pytorch/torchft/pull/82
2. Start the TorchFT lighthouse.
3. Execute the following command in one terminal:
```
TORCHFT_MANAGER_PORT=29520 REPLICA_GROUP_ID=0 CUDA_VISIBLE_DEVICES=0,1 NGPU=2 ./run_llama_train.sh --training.data_parallel_shard_degree=2 --experimental.enable_torchft --experimental.ft_replica_group_id=0
```
4.
Wait 10 seconds, execute following command in another terminal: ``` TORCHFT_MANAGER_PORT=29522 REPLICA_GROUP_ID=1 CUDA_VISIBLE_DEVICES=2,3 NGPU=2 ./run_llama_train.sh --training.data_parallel_shard_degree=2 --experimental.enable_torchft --experimental.ft_replica_group_id=1 ``` ghstack-source-id: 11e74f8fafa225d64376e317de39b362a8013471 Pull Request resolved: https://github.com/pytorch/torchtitan/pull/834 --- run_llama_train.sh | 4 + torchtitan/checkpoint.py | 365 ++++++++++++----------- torchtitan/config_manager.py | 24 ++ torchtitan/ft.py | 41 +++ torchtitan/optimizer.py | 50 +++- torchtitan/parallelisms/parallel_dims.py | 21 +- torchtitan/utils.py | 18 ++ train.py | 13 +- train_configs/llama3_8b.toml | 1 - 9 files changed, 350 insertions(+), 187 deletions(-) create mode 100644 torchtitan/ft.py diff --git a/run_llama_train.sh b/run_llama_train.sh index a69c967a7..cbd98a9f8 100755 --- a/run_llama_train.sh +++ b/run_llama_train.sh @@ -19,7 +19,11 @@ if [ $# -ne 0 ]; then overrides="$*" fi +TORCHFT_MANAGER_PORT=${TORCHFT_MANAGER_PORT:-"29512"} + PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True" \ +TORCHFT_LIGHTHOUSE=http://localhost:29510 \ +TORCHFT_MANAGER_PORT=${TORCHFT_MANAGER_PORT} \ torchrun --nproc_per_node=${NGPU} --rdzv_backend c10d --rdzv_endpoint="localhost:0" \ --local-ranks-filter ${LOG_RANK} --role rank --tee 3 \ train.py --job.config_file ${CONFIG_FILE} $overrides diff --git a/torchtitan/checkpoint.py b/torchtitan/checkpoint.py index 5c35fb8d9..bbd5f56f7 100644 --- a/torchtitan/checkpoint.py +++ b/torchtitan/checkpoint.py @@ -13,12 +13,13 @@ from dataclasses import dataclass, field from io import BytesIO from multiprocessing import get_context -from typing import Any, Dict, List, Union +from typing import Any, Dict, List, Optional, Union import torch import torch.distributed as dist import torch.distributed.checkpoint as dcp import torch.nn as nn +from torch.distributed._state_dict_utils import _create_cpu_state_dict from torch.distributed.checkpoint.state_dict import ( get_model_state_dict, set_model_state_dict, @@ -28,16 +29,13 @@ from torch.utils.data import DataLoader from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP + +from torchtitan.ft import FTManager from torchtitan.logging import init_logger, logger from torchtitan.optimizer import LRSchedulersContainer, OptimizersContainer from torchtitan.utils import GarbageCollection -class IntervalType(enum.Enum): - SECONDS = enum.auto() - STEPS = enum.auto() - - class AsyncMode(str, enum.Enum): DISABLED = "disabled" ASYNC = "async" @@ -84,12 +82,13 @@ def load_state_dict(self, state_dict) -> None: class ModelWrapper(Stateful): def __init__(self, model: Union[nn.Module, List[nn.Module]]) -> None: self.model = [model] if isinstance(model, nn.Module) else model - - def state_dict(self) -> Dict[str, Any]: - return { + self.cache_state_dict = { k: v for sd in map(get_model_state_dict, self.model) for k, v in sd.items() } + def state_dict(self) -> Dict[str, Any]: + return self.cache_state_dict + def load_state_dict(self, state_dict: Dict[str, Any]) -> None: func = functools.partial( set_model_state_dict, @@ -151,13 +150,37 @@ def __init__( lr_schedulers: LRSchedulersContainer, states: Dict[str, Any], job_config: JobConfig, + ft_manager: Optional[FTManager] = None, ) -> None: ckpt_config = job_config.checkpoint self.enable_checkpoint = ckpt_config.enable_checkpoint - self.keep_latest_k = ckpt_config.keep_latest_k + self.ft_manager = ft_manager + if self.ft_manager: - if not self.enable_checkpoint: + 
optimizers.init_cache_state_dict() + + def state_dict(): + ret = {} + for k, v in self.states.items(): + if k in {"model", "optimizer", "lr_schedulers", "train_state"}: + ret[k] = v.state_dict() + return ret + + def load_state_dict(state_dict): + assert state_dict is not None + for k, v in state_dict.items(): + self.states[k].load_state_dict(v) + + ft_manager.manager.set_state_dict_fns(load_state_dict, state_dict) + + async_mode = ckpt_config.async_mode.lower() + self.enable_staging = ( + self.enable_checkpoint and async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM + ) or self.ft_manager + + if not self.enable_checkpoint and self.ft_manager is None: return + """ Note: Pipeline Parallelism and Virtual Stages @@ -192,20 +215,19 @@ def __init__( } ) + self.staging = False + self.sending_to_checkpoint_mp = False + self.staging_id = None + self.cpu_offload_state_dict = None + self.staging_stream = torch.cuda.Stream() if self.enable_staging else None + self.folder = os.path.join(job_config.job.dump_folder, ckpt_config.folder) - self.interval_type = ( - IntervalType.SECONDS - if ckpt_config.interval_type == "seconds" - else IntervalType.STEPS - ) self.interval = ckpt_config.interval - self.begin_time = 0 - self.time_sync_work = None - self.time_sync_result = None async_mode = ckpt_config.async_mode.lower() - if async_mode == AsyncMode.ASYNC or self.interval_type == IntervalType.SECONDS: + if async_mode == AsyncMode.ASYNC: self.pg = dist.new_group(backend="gloo") + self.keep_latest_k = ckpt_config.keep_latest_k self.model_weights_only = ckpt_config.model_weights_only self.export_dtype = TORCH_DTYPE_MAP[ckpt_config.export_dtype] self.exclude_from_loading = ckpt_config.exclude_from_loading @@ -230,10 +252,6 @@ def __init__( daemon=True, ) self.mp.start() - self.cpu_offload_state_dict = None - self.staging = False - self.staging_id = None - self.staging_stream = torch.cuda.Stream() else: raise ValueError(f"Unkown checkpoint async_mode {ckpt_config.async_mode}") @@ -246,8 +264,153 @@ def __del__(self): self.mp_queue_send.put(Terminate()) self.mp.join() - def reset(self) -> None: - self.begin_time = time.monotonic() + def save(self, curr_step: int, force: bool = False) -> None: + """ + force = True will force the checkpoint to be saved, even if the interval + has not been reached. + This only happens when train_state.step == job_config.training.steps, or + for initial seed checkpoint. 
+ """ + if not self._should_save(curr_step, force): + return + + logger.info("Saving the checkpoint (or staging if async is enabled).") + begin = time.monotonic() + + if not self.ft_manager or self.ft_manager.manager.participating_rank() == 0: + checkpoint_id = self._create_checkpoint_id(curr_step) + self._async_wait() + # This GC is called for async checkpoint as it is useless to do + # GC right after async_save -- the CPU memory is not able to be + # freed until _async_wait() + if force: + self._save_last_step(curr_step) + elif self.async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM: + GarbageCollection.collect("GC collection invoked by checkpointer.") + self._async_with_pinned_memory(checkpoint_id) + elif self.async_mode == AsyncMode.ASYNC: + GarbageCollection.collect("GC collection invoked by checkpointer.") + self.async_future = dcp.async_save( + self.states, checkpoint_id=checkpoint_id, process_group=self.pg + ) + else: + save_with_gc(self.states, checkpoint_id=checkpoint_id) + self._purge_stale_checkpoints() + + logger.info( + "Finished saving the checkpoint (or staging if async is enabled)" + f"in {time.monotonic() - begin:.2f} seconds." + ) + elif self.ft_manager: + logger.info("Waiting for TorchFT replicated group 0 to save the checkpoint") + time.sleep(1) + + def load(self, step: int = -1) -> bool: + if not self.enable_checkpoint or not os.path.isdir(self.folder): + return False + + if step == -1: + step_counts = [] + for filename in os.listdir(self.folder): + match = re.search(r"step-(\d+)", filename) + metadata_probe = os.path.join(self.folder, filename, ".metadata") + if match and os.path.isfile(metadata_probe): + step_counts.append(int(match.group(1))) + if not step_counts: + return False + step = max(step_counts) + + checkpoint_id = self._create_checkpoint_id(step) + if not os.path.isdir(checkpoint_id): + return False + + logger.info(f"Loading the checkpoint at step {step}.") + begin = time.monotonic() + + # For the first step, we will only load the model weights. + states = {"model": self.states["model"]} if step == 0 else self.states + states_to_load = { + k: v for k, v in states.items() if k not in self.exclude_from_loading + } + for exclude_key in self.exclude_from_loading: + if exclude_key not in states: + raise ValueError(f"{exclude_key} not found in state_dict.") + dcp.load(states_to_load, checkpoint_id=checkpoint_id) + GarbageCollection.collect("GC collection for checkpoint loading.") + + logger.info( + f"Finished loading the checkpoint in {time.monotonic() - begin:.2f} seconds." + ) + return True + + def maybe_wait_for_staging(self) -> None: + if self.enable_staging and self.staging: + if not self.staging_stream.query(): + self.staging_stream.synchronize() + self.staging = False + + if self.sending_to_checkpoint_mp: + # Copy the sync staging result to another process. + def sync_func(): + self.mp_queue_send.put_nowait( + (self.cpu_offload_state_dict, self.staging_id) + ) + + # This may be a faster way to do zero-overhead checkpointing staging + # checkpointing but we need more thorough investigation before + # swithing to this method. + # self.my_thread = threading.Thread(target=func).start() + sync_func() + self.sending_to_checkpoint_mp = False + + def _initialize_states( + self, + states: Dict[str, Any], + dataloader: DataLoader, + model_parts: List[nn.Module], + optimizers: OptimizersContainer, + lr_schedulers: LRSchedulersContainer, + ) -> None: + """ + Note: Pipeline Parallelism and Virtual Stages + + 1. 
Even for simple PP schedules, there is a separate optimizer each PP rank. + rank0's optimizer would have a param_group[0] which refers to layers.0 in the + original model. rank1's would _also_ have a param_group[0], since it's index based, + but referring to layers.1. + When saving, these collide and one of them is lost. Then when reloading, only one + stage can restore its optimizer states, others will error. + + The solution to this problem is optimizer flattening: it landed in #127071 + and is enabled in TorchTitan by passing the 'flatten_optimizer_state_dict' + kwarg to DCP functions called in the OptimizerContainer. + + 2. With complex PP schedules, we have multiple model chunks per pp rank. This + compounds challenge (1) by also requiring us to reason about multiple 'optim' + objects locally. + + We solve this in the Model and Optimizer wrapper classes by flattening the + state dicts from each object into one state dict before saving/loading. + We rely on the individual state_dicts to not collide, which is gauranteed for + the model by correct pipeline splitting and for the optimizer by the flattening + support described in (1). + + 3. LR schedulers also index model states like optimizers and would need to be + flattened properly to support resharding. Unfortunately, the implementations of + different lr_schedulers do not follow a clear pattern like optimizers do, so it's + hard to write a generic 'flattener' utility. + + TODO: This is currently unsolved and needs a fix. + """ + self.states = states + self.states.update( + { + "model": ModelWrapper(model_parts), + "optimizer": optimizers, + "dataloader": dataloader, + "lr_scheduler": lr_schedulers, + } + ) def _create_checkpoint_id(self, step: int) -> str: return os.path.join(self.folder, f"step-{step}") @@ -282,39 +445,14 @@ def _save_last_step(self, curr_step: int) -> None: logger.info(f"Saving a full checkpoint at last step, step {curr_step}.") save_with_gc(self.states, checkpoint_id=self._create_checkpoint_id(curr_step)) - self.reset() def _should_save(self, curr_step: int, force: bool = False) -> bool: if not self.enable_checkpoint: return False if not force: - if self.interval_type == IntervalType.STEPS and not ( - curr_step % self.interval == 0 - ): + if curr_step % self.interval == 0: return False - if self.interval_type == IntervalType.SECONDS: - time_sync_result = (time.monotonic() - self.begin_time) >= self.interval - self.time_sync_result = torch.tensor(int(time_sync_result)) - if self.time_sync_work is None: - self.time_sync_work = dist.all_reduce( - self.time_sync_result, group=self.pg, async_op=True - ) - return False - elif curr_step % 5 == 4: - self.time_sync_work.wait() - self.time_sync_work = None - time_sync_result = self.time_sync_result.item() - self.time_sync_result = None - if time_sync_result == 0: - return False - else: - return False - - if self.time_sync_work: - self.time_sync_work.wait() - self.time_sync_work = None - self.time_sync_result = None return True @@ -331,15 +469,11 @@ def _async_wait(self) -> None: self.async_future.result() def _async_with_pinned_memory(self, checkpoint_id: str) -> None: - try: - from torch.distributed._state_dict_utils import ( - _copy_state_dict, - _create_cpu_state_dict, - ) - except ImportError as e: - raise ImportError( - "Please install the latest PyTorch nightly to use async checkpointing with pinned memory." 
- ) from e + self._cpu_staging(checkpoint_id) + self.sending_to_checkpoint_mp = True + + def _cpu_staging(self, checkpoint_id: Optional[str]) -> None: + """Offload state_dict to CPU memory""" state_dict = dcp.state_dict_saver._stateful_to_state_dict(self.states) if self.cpu_offload_state_dict is None: logger.debug(f"Preparing the CPU memory, {time.monotonic()=}.:.2f") @@ -357,115 +491,6 @@ def _async_with_pinned_memory(self, checkpoint_id: str) -> None: self.staging = True self.staging_id = checkpoint_id - def save(self, curr_step: int, force: bool = False) -> None: - """ - force = True will force the checkpoint to be saved, even if the interval - has not been reached. - This only happens when train_state.step == job_config.training.steps, or - for initial seed checkpoint. - """ - if not self._should_save(curr_step, force): - return - - begin = time.monotonic() - checkpoint_id = self._create_checkpoint_id(curr_step) - self._async_wait() - # This GC is called for async checkpoint as it is useless to do - # GC right after async_save -- the CPU memory is not able to be - # freed until _async_wait() - if force: - self._save_last_step(curr_step) - elif self.async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM: - GarbageCollection.collect("GC collection invoked by checkpointer.") - self._async_with_pinned_memory(checkpoint_id) - elif self.async_mode == AsyncMode.ASYNC: - GarbageCollection.collect("GC collection invoked by checkpointer.") - self.async_future = dcp.async_save( - self.states, checkpoint_id=checkpoint_id, process_group=self.pg - ) - else: - save_with_gc(self.states, checkpoint_id=checkpoint_id) - self.reset() - self._purge_stale_checkpoints() - - logger.info( - "Finished saving the checkpoint (or staging if async is enabled)" - f"in {time.monotonic() - begin:.2f} seconds." - ) - - def maybe_wait_for_staging(self) -> None: - if ( - self.enable_checkpoint - and self.async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM - and self.staging - ): - if not self.staging_stream.query(): - self.staging_stream.synchronize() - - def sync_func(): - self.mp_queue_send.put_nowait( - (self.cpu_offload_state_dict, self.staging_id) - ) - - # This may be a faster way to do zero-overhead checkpointing staging - # checkpointing but we need more thorough investigation before - # swithing to this method. - # self.my_thread = threading.Thread(target=func).start() - sync_func() - self.staging = False - - def load(self, step: int = -1) -> bool: - if not self.enable_checkpoint: - return False - if not os.path.isdir(self.folder): - return False - if step != -1 and not os.path.isdir(self._create_checkpoint_id(step)): - return False - - if step == -1: - step_counts = [] - for filename in os.listdir(self.folder): - match = re.search(r"step-(\d+)", filename) - metadata_probe = os.path.join(self.folder, filename, ".metadata") - if match and os.path.isfile(metadata_probe): - step_counts.append(int(match.group(1))) - if not step_counts: - return False - step = max(step_counts) - - # We won't have optimizer states to load, if we are loading a seed checkpoint - states = {"model": self.states["model"]} if step == 0 else self.states - # PyTorch bug: (pytorch/pytorch#138575) - # dcp.load() replaces the values of stateful elements in `states` with new objects - # from loading the checkpoint, in addition to updating the states of the original - # objects from `states` in-place. 
This is a problem because the state_dict no longer - # refers to the objects being used in the train loop, meaning any future checkpoints - # will not include updates to these objects (such as updated optimizer states, etc.) - original_stateful_states = { - k: v for k, v in states.items() if isinstance(v, Stateful) - } - logger.info(f"Loading the checkpoint at step {step}.") - begin = time.monotonic() - states_to_load = { - k: v for k, v in states.items() if k not in self.exclude_from_loading - } - for exclude_key in self.exclude_from_loading: - if exclude_key not in states: - raise ValueError(f"{exclude_key} not found in state_dict.") - dcp.load( - states_to_load, - checkpoint_id=self._create_checkpoint_id(step), - ) - states.update(states_to_load) - logger.info( - f"Finished loading the checkpoint in {time.monotonic() - begin:.2f} seconds." - ) - # bugfix from above: restore the original stateful objects, - # whose states were already updated in-place by dcp.load() - states.update(original_stateful_states) - GarbageCollection.collect("GC collection for checkpoint loading.") - return True - def _purge_stale_checkpoints(self): if self.keep_latest_k > 0: discovered_checkpoints = [] diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py index db9e29003..273346b01 100644 --- a/torchtitan/config_manager.py +++ b/torchtitan/config_manager.py @@ -631,6 +631,30 @@ def __init__(self): action="store_true", ) + self.parser.add_argument( + "--experimental.enable_torchft", + action="store_true", + help="Enable TorchFT integration.", + ) + + self.parser.add_argument( + "--experimental.ft_replica_group_id", + type=int, + default=-1, + help="The TorchFT replicate group ID of this run.", + ) + + self.parser.add_argument( + "--experimental.ft_replica_group_size", + type=int, + default=-1, + help=""" + The number of TorchFT replicate groups. This number will be used for + dataloader to split the dataset across the replicate groups and FSDP + dimension. + """, + ) + def to_dict(self): return self.args_dict diff --git a/torchtitan/ft.py b/torchtitan/ft.py new file mode 100644 index 000000000..4a8e0ded3 --- /dev/null +++ b/torchtitan/ft.py @@ -0,0 +1,41 @@ +import importlib +from dataclasses import dataclass +from typing import Optional + +from torchtitan.config_manager import JobConfig + +if importlib.util.find_spec("torchft") is not None: + import torchft as ft + + has_torchft = True +else: + has_torchft = False + + +@dataclass +class FTManager: + manager: ft.Manager + replicate_group_size: int + + +def init_ft_manager(job: JobConfig) -> Optional[FTManager]: + """ + Initialize the FT manager for the given job. + """ + if not job.experimental.enable_torchft: + return None + + if not has_torchft: + raise ImportError("torchft is not installed. 
Please install it.") + + pg = ft.ProcessGroupBabyNCCL() + manager = ft.Manager( + pg=pg, + min_replica_size=1, + load_state_dict=None, + state_dict=None, + use_async_quorum=True, + replica_id=f"torchtitan_ft_{job.experimental.ft_replica_group_id}", + ) + + return FTManager(manager, job.experimental.ft_replica_group_size) diff --git a/torchtitan/optimizer.py b/torchtitan/optimizer.py index e351fd132..759826d9f 100644 --- a/torchtitan/optimizer.py +++ b/torchtitan/optimizer.py @@ -6,7 +6,7 @@ import copy import functools -from typing import Any, Callable, Dict, Iterable, List +from typing import Any, Callable, Dict, Iterable, List, Optional import torch import torch.nn as nn @@ -177,8 +177,41 @@ def zero_grad(self) -> None: pass +class FTOptimizersContainer(OptimizersContainer): + def __init__( + self, + model_parts: List[nn.Module], + optimizer_kwargs: Dict[str, Any], + name: str, + ft_manager: Any, + ) -> None: + import torchft as ft + + super().__init__(model_parts, optimizer_kwargs, name) + + # Force to initialize the optimizer state so that `optim.step()` + # won't be called by state_dict() and load_state_dict(). + _ = { + k: v + for sd in map(get_optimizer_state_dict, model_parts, self.optimizers) + for k, v in sd.items() + } + self.optimizers = [ + ft.Optimizer(ft_manager.manager, optim) for optim in self.optimizers + ] + self.cache_state_dict: Dict[str, Any] = {} + + def init_cache_state_dict(self) -> None: + self.cache_state_dict = super().state_dict() + + def state_dict(self) -> Dict[str, Any]: + return self.cache_state_dict + + def build_optimizers( - model_parts: List[nn.Module], job_config: JobConfig + model_parts: List[nn.Module], + job_config: JobConfig, + ft_manager: Optional[Any] = None, ) -> OptimizersContainer: """Create a OptimizersContainer for the given model parts and job config. 
@@ -213,11 +246,14 @@ def build_optimizers( "foreach": not fused, } - return ( - OptimizersContainer(model_parts, optimizer_kwargs, name) - if not optim_in_bwd - else OptimizersInBackwardContainer(model_parts, optimizer_kwargs, name) - ) + if optim_in_bwd and ft_manager: + raise ValueError("TorchFT is not supported with optimizers in backward.") + elif optim_in_bwd: + return OptimizersInBackwardContainer(model_parts, optimizer_kwargs, name) + elif ft_manager: + return FTOptimizersContainer(model_parts, optimizer_kwargs, name, ft_manager) + else: + return OptimizersContainer(model_parts, optimizer_kwargs, name) class LRSchedulersContainer(Stateful): diff --git a/torchtitan/parallelisms/parallel_dims.py b/torchtitan/parallelisms/parallel_dims.py index f5e6a0e4c..f85f6226a 100644 --- a/torchtitan/parallelisms/parallel_dims.py +++ b/torchtitan/parallelisms/parallel_dims.py @@ -6,6 +6,7 @@ from dataclasses import dataclass from functools import cached_property +from typing import Any, Optional from torch.distributed.device_mesh import init_device_mesh @@ -24,6 +25,7 @@ class ParallelDims: pp: int world_size: int enable_loss_parallel: bool + ft_manager: Optional[Any] def __post_init__(self): self._validate() @@ -56,13 +58,24 @@ def build_mesh(self, device_type): [self.pp, self.dp_replicate, self.dp_shard, self.cp, self.tp], ["pp", "dp_replicate", "dp_shard", "cp", "tp"], ): - if d > 1: + if d > 1 or (name == "dp_replicate" and self.ft_manager is not None): dims.append(d) names.append(name) logger.info(f"Building {len(dims)}-D device mesh with {names}, {dims}") names = tuple(names) - mesh = init_device_mesh(device_type, dims, mesh_dim_names=names) + if self.ft_manager is None: + mesh = init_device_mesh(device_type, dims, mesh_dim_names=names) + else: + from torchft.process_group import ft_init_device_mesh + + mesh = ft_init_device_mesh( + device_type=device_type, + mesh_shape=dims, + mesh_dim_names=names, + replicate_dim=names.index("dp_replicate"), + manager=self.ft_manager.manager, + ) # Create all the submesh here to ensure all required process groups are # initialized: @@ -73,7 +86,7 @@ def build_mesh(self, device_type): # Mesh for loss all-reduce dp_cp_mesh_dim_names = [] - if self.dp_replicate_enabled: + if self.dp_replicate_enabled or ft_manager is not None: dp_mesh_dim_names.append("dp_replicate") dp_cp_mesh_dim_names.append("dp_replicate") if self.dp_shard_enabled: @@ -101,7 +114,7 @@ def dp_enabled(self): @property def dp_replicate_enabled(self): - return self.dp_replicate > 1 + return self.dp_replicate > 1 or self.ft_manager is not None @property def dp_shard_enabled(self): diff --git a/torchtitan/utils.py b/torchtitan/utils.py index 024d1ac57..b7ebbffd1 100644 --- a/torchtitan/utils.py +++ b/torchtitan/utils.py @@ -5,6 +5,7 @@ # LICENSE file in the root directory of this source tree. 
import contextlib +import copy import gc import importlib import math @@ -18,6 +19,7 @@ import torch import torch.distributed._functional_collectives as funcol import torch.distributed.distributed_c10d as c10d +import torchft as ft from torch import distributed as dist from torch._utils import _get_available_device_type, _get_device_module from torch.distributed.device_mesh import DeviceMesh @@ -38,6 +40,12 @@ def get_device_info(): def dist_reduce(x: torch.Tensor, reduceOp: str, mesh: DeviceMesh) -> float: + if isinstance(mesh, ft.process_group._FlattenDeviceMesh): + x = funcol.all_reduce( + x, reduceOp=reduceOp, group=mesh.managed_mesh.replicate_pg + ) + mesh = mesh.managed_mesh.mesh + if isinstance(x, DTensor): # functional collectives do not support DTensor inputs x = x.full_tensor() @@ -407,6 +415,16 @@ def clip_grad_norm_( if isinstance(total_norm, DTensor): # Will reach here if any non-PP parallelism is used. # If only using PP, total_norm will be a local tensor. + mesh = total_norm._spec.mesh + if isinstance(mesh, ft.process_group.ManagedDeviceMesh): + # The gradients along the replicated dim has already been reduced. + # So we don't need another reducution beforing removing the + # replicate dimension + local_tensor = total_norm.to_local() + placements = list(copy.copy(total_norm._spec.placements)) + placements.pop(mesh.replicate_dim) + total_norm = DTensor.from_local(local_tensor, mesh.mesh, placements) + total_norm = total_norm.full_tensor() if pp_mesh is not None: diff --git a/train.py b/train.py index eeb3705f9..6b1c090a0 100644 --- a/train.py +++ b/train.py @@ -16,6 +16,7 @@ from torchtitan.checkpoint import CheckpointManager, TrainState from torchtitan.config_manager import JobConfig from torchtitan.float8 import Float8Handler +from torchtitan.ft import init_ft_manager from torchtitan.logging import init_logger, logger from torchtitan.metrics import build_device_memory_monitor, build_metric_logger from torchtitan.parallelisms import ParallelDims @@ -42,6 +43,10 @@ def main(job_config: JobConfig): # take control of garbage collection to avoid stragglers gc_handler = utils.GarbageCollection(gc_freq=job_config.training.gc_freq) + device = torch.device(f"{device_type}:{int(os.environ['LOCAL_RANK'])}") + device_module.set_device(device) + ft_manager = init_ft_manager(job_config) + # init distributed world_size = int(os.environ["WORLD_SIZE"]) parallel_dims = ParallelDims( @@ -52,9 +57,8 @@ def main(job_config: JobConfig): pp=job_config.experimental.pipeline_parallel_degree, world_size=world_size, enable_loss_parallel=not job_config.training.disable_loss_parallel, + ft_manager=ft_manager, ) - device = torch.device(f"{device_type}:{int(os.environ['LOCAL_RANK'])}") - device_module.set_device(device) utils.init_distributed(job_config) # initialize device memory monitor and get peak flops for MFU calculation device_memory_monitor = build_device_memory_monitor() @@ -187,7 +191,7 @@ def loss_fn(pred, labels): ) # build optimizer after applying parallelisms to the model - optimizers = train_spec.build_optimizers_fn(model_parts, job_config) + optimizers = train_spec.build_optimizers_fn(model_parts, job_config, ft_manager) lr_schedulers = train_spec.build_lr_schedulers_fn(optimizers, job_config) train_state = TrainState() @@ -200,6 +204,7 @@ def loss_fn(pred, labels): lr_schedulers=lr_schedulers, states={"train_state": train_state}, job_config=job_config, + ft_manager=ft_manager, ) if job_config.checkpoint.create_seed_checkpoint: @@ -240,8 +245,6 @@ def loss_fn(pred, labels): time_last_log 
= time.perf_counter() device_memory_monitor.reset_peak_stats() - checkpoint.reset() - # train loop logger.info( f"Training starts at step {train_state.step + 1}, " diff --git a/train_configs/llama3_8b.toml b/train_configs/llama3_8b.toml index c61640362..8b1387717 100644 --- a/train_configs/llama3_8b.toml +++ b/train_configs/llama3_8b.toml @@ -44,7 +44,6 @@ pipeline_parallel_degree = 1 [checkpoint] enable_checkpoint = false folder = "checkpoint" -interval_type = "steps" interval = 500 model_weights_only = false export_dtype = "float32"