diff --git a/torchtitan/checkpoint.py b/torchtitan/checkpoint.py
index da6f2b397..15b3d0a87 100644
--- a/torchtitan/checkpoint.py
+++ b/torchtitan/checkpoint.py
@@ -156,41 +156,6 @@ def __init__(
         if not self.enable_checkpoint and self.ft_manager is None:
             return
 
-<<<<<<< HEAD
-        1. even for simple PP schedules, there is a separate optimizer each PP rank.
-        rank0's optimizer would have a param_group[0] which refers to layers.0 in the original model.
-        rank1's would _also_ have a param_group[0], since it's index based, but referring to layers.1.
-        When saving, these collide and one of them is lost. Then when reloading, only one stage can
-        restore its optimizer states, others will error.
-
-        The solution to this problem is optimizer flattening: it landed in #127071 and is enabled in TorchTitan
-        by passing the 'flatten_optimizer_state_dict' kwarg to DCP functions called in the OptimizerContainer.
-
-        2. With complex PP schedules, we have multiple model chunks per pp rank. This compounds challenge (1) by also
-        requiring us to reason about multiple 'optim' objects locally.
-
-        We solve this in the Model and Optimizer wrapper classes by flattening the state dicts from each object
-        into one state dict before saving/loading. We rely on the individual state_dicts to not collide,
-        which is gauranteed for the model by correct pipeline splitting and for the optimizer by the flattening
-        support described in (1).
-
-        3. LR schedulers also index model states like optimizers and would need to be flattened properly to support
-        resharding. Unfortunately, the implementations of different lr_schedulers do not follow a clear pattern like
-        optimizers do, so it's hard to write a generic 'flattener' utility.
-
-        TODO: This is currently unsolved and needs a fix.
-        """
-        self.states = states
-
-        self.states.update(
-            {
-                "model": ModelWrapper(model_parts),
-                "optimizer": optimizers,
-                "dataloader": dataloader,
-                "lr_scheduler": lr_schedulers,
-            }
-        )
-=======
         self._initialize_states(
             states, dataloader, model_parts, optimizers, lr_schedulers
         )
@@ -201,7 +166,6 @@ def __init__(
         self.staging_id = None
         self.cpu_offload_state_dict = None
         self.staging_stream = torch.cuda.Stream() if self.enable_staging else None
->>>>>>> 3430d99 ([WIP][RFC] TorchFT integration)
 
         self.folder = os.path.join(job_config.job.dump_folder, ckpt_config.folder)
         self.interval_type = (
@@ -305,9 +269,9 @@ def _initialize_states(
                 "model": ModelWrapper(model_parts),
                 "optimizer": optimizers,
                 "dataloader": dataloader,
+                "lr_scheduler": lr_schedulers,
             }
         )
-        self.states.update(lr_schedulers.get_lr_scheduler_state())
 
     def _create_checkpoint_id(self, step: int) -> str:
         return os.path.join(self.folder, f"step-{step}")
diff --git a/torchtitan/utils.py b/torchtitan/utils.py
index 5e4078326..bee91f0f1 100644
--- a/torchtitan/utils.py
+++ b/torchtitan/utils.py
@@ -408,9 +408,10 @@ def clip_grad_norm_(
         # If only using PP, total_norm will be a local tensor.
         mesh = total_norm._spec.mesh
        if isinstance(mesh, ft.process_group.ManagedDeviceMesh):
+            # The gradients along the replicated dim have already been reduced,
+            # so we don't need another reduction before removing the
+            # replicate dimension.
             local_tensor = total_norm.to_local()
-            dist.all_reduce(local_tensor, op=dist.ReduceOp.AVG, group=mesh.replicate_pg)
-
             placements = list(copy.copy(total_norm._spec.placements))
             placements.pop(mesh.replicate_dim)
             mesh = mesh.mesh
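
Note on the clip_grad_norm_ hunk: a minimal sketch of the intended flow, continuing past the truncated context above. The 'import torchft as ft' alias, the helper name '_finish_total_norm', and the 'DTensor.from_local' / 'full_tensor' continuation are illustrative assumptions, not necessarily the exact torchtitan code.

    import copy

    import torch
    import torchft as ft  # assumption: the `ft` alias referenced in the hunk
    from torch.distributed.tensor import DTensor


    def _finish_total_norm(total_norm: DTensor) -> torch.Tensor:
        # Hypothetical helper mirroring the logic shown in the hunk above.
        mesh = total_norm._spec.mesh
        if isinstance(mesh, ft.process_group.ManagedDeviceMesh):
            # TorchFT has already averaged gradients across the replicate
            # group, so every replica holds an identical partial norm; no
            # extra all_reduce over mesh.replicate_pg is needed.
            local_tensor = total_norm.to_local()
            placements = list(copy.copy(total_norm._spec.placements))
            placements.pop(mesh.replicate_dim)
            # Rebuild the norm as a DTensor over the inner (non-replicated) mesh.
            total_norm = DTensor.from_local(local_tensor, mesh.mesh, placements)
        # Materializing the full tensor reduces over the remaining mesh dims
        # (e.g. FSDP/TP shards).
        return total_norm.full_tensor()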
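
Note on the checkpoint.py hunks: registering 'lr_schedulers' directly under the "lr_scheduler" key (instead of merging 'lr_schedulers.get_lr_scheduler_state()' into 'self.states') works as long as the container satisfies DCP's Stateful protocol, i.e. exposes state_dict() and load_state_dict(), which DCP calls during save and load. Below is a minimal, hypothetical sketch of such a container; the class name 'SchedulerContainer' and its key scheme are assumptions, not the actual torchtitan implementation.

    from typing import Any, Dict, List

    from torch.distributed.checkpoint.stateful import Stateful
    from torch.optim.lr_scheduler import LRScheduler


    class SchedulerContainer(Stateful):
        """Hypothetical wrapper holding one scheduler per model chunk (e.g. per PP stage)."""

        def __init__(self, schedulers: List[LRScheduler]) -> None:
            self.schedulers = schedulers

        def state_dict(self) -> Dict[str, Any]:
            # Flatten per-chunk scheduler states under distinct keys so they
            # do not collide inside a single checkpoint.
            return {
                f"lr_scheduler_{i}": s.state_dict()
                for i, s in enumerate(self.schedulers)
            }

        def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
            for i, s in enumerate(self.schedulers):
                s.load_state_dict(state_dict[f"lr_scheduler_{i}"])

With a container like this, placing it under "lr_scheduler" in the states dict is all DCP needs to save and restore scheduler state alongside the model and optimizer.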