[WIP][RFC] Required changes for integration with TorchTitan
Summary:
We are not going to land this PR; it may be further divided into several PRs.

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
fegin committed Jan 27, 2025
1 parent bed29d2 commit 768d014
Showing 4 changed files with 50 additions and 13 deletions.
6 changes: 4 additions & 2 deletions torchft/checkpointing.py
@@ -75,8 +75,10 @@ def do_GET(self):
self.end_headers()

sd = state_dict()

logger.warning("After state_dict ===================.")
torch.save(sd, self.wfile)
logger.warning("After torch.save ===================.")

except Exception as e:
logger.exception(
f"Exception in checkpoint server when handling {self.path=}: {e}",
@@ -113,7 +115,7 @@ def load_from_address(cls, address: str, timeout: timedelta) -> T:
data = f.read()

reader = io.BytesIO(data)
return torch.load(reader, weights_only=True)
return torch.load(reader, weights_only=False)

def address(self) -> str:
"""
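The weights_only change above matters because torch.load(..., weights_only=True) only unpickles an allow-listed set of tensor and container types; a checkpoint whose state dict carries arbitrary Python objects (as a TorchTitan-style training state presumably does) can only be restored with weights_only=False. A minimal sketch of the difference, with TrainState as a hypothetical stand-in for such an object:

```python
# Minimal sketch (not from this PR): a state dict holding a plain Python object
# round-trips through torch.save/torch.load only when weights_only=False.
import io

import torch


class TrainState:  # hypothetical stand-in for non-tensor checkpoint state
    def __init__(self, step: int) -> None:
        self.step = step


sd = {"model": {"w": torch.zeros(2)}, "train_state": TrainState(step=10)}

buf = io.BytesIO()
torch.save(sd, buf)
buf.seek(0)

# weights_only=True would reject the TrainState global during unpickling;
# weights_only=False restores the full object graph (only do this for trusted data).
loaded = torch.load(buf, weights_only=False)
assert loaded["train_state"].step == 10
```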
13 changes: 10 additions & 3 deletions torchft/manager.py
@@ -87,8 +87,8 @@ class Manager:
def __init__(
self,
pg: "ProcessGroup",
load_state_dict: Callable[[T], None],
state_dict: Callable[[], T],
load_state_dict: Optional[Callable[[T], None]],
state_dict: Optional[Callable[[], T]],
min_replica_size: int,
use_async_quorum: bool = True,
timeout: timedelta = timedelta(seconds=60),
@@ -158,7 +158,7 @@ def __init__(

def _manager_state_dict() -> Dict[str, T]:
return {
"user": state_dict(),
"user": self._state_dict(),
"torchft": cast(T, self.state_dict()),
}

@@ -223,6 +223,12 @@ def _manager_state_dict() -> Dict[str, T]:
self._participating_rank: Optional[int] = None
self._participating_world_size: int = 0

def set_state_dict_fns(
self, load_state_dict: Callable[[T], None], state_dict: Callable[[], T]
) -> None:
self._load_state_dict = load_state_dict
self._state_dict = state_dict

def shutdown(self) -> None:
"""
Shutdown the manager and checkpoint server.
@@ -506,6 +512,7 @@ def _apply_pending_state_dict(self) -> None:
assert self._pending_state_dict is not None, "checkpoint was not staged"
self._load_state_dict(self._pending_state_dict["user"])
self._pending_state_dict = None
self._logger.info("Loaded state dict.")

def should_commit(self, timeout: Optional[timedelta] = None) -> bool:
"""
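With load_state_dict and state_dict now Optional and registrable later via set_state_dict_fns, the Manager can be constructed before the model and optimizer exist. A usage sketch of the assumed calling pattern; `manager` stands for a torchft Manager created elsewhere with load_state_dict=None and state_dict=None, and the model/optimizer here are illustrative:

```python
# Usage sketch (assumed calling pattern, not code from this PR).
import torch
from torch import nn

model = nn.Linear(8, 8)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)


def state_dict():
    # What the checkpoint server will serve to recovering replicas.
    return {"model": model.state_dict(), "optim": optimizer.state_dict()}


def load_state_dict(sd) -> None:
    # How a recovering replica applies the state it fetched.
    model.load_state_dict(sd["model"])
    optimizer.load_state_dict(sd["optim"])


# Register the callbacks after construction instead of passing them to __init__.
manager.set_state_dict_fns(load_state_dict, state_dict)
```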
10 changes: 9 additions & 1 deletion torchft/optim.py
@@ -12,7 +12,7 @@
"""

from typing import TYPE_CHECKING, Optional
from typing import Any, TYPE_CHECKING, Optional

from torch.optim import Optimizer

@@ -52,3 +52,11 @@ def step(self, closure: Optional[object] = None) -> None:
assert closure is None, "optimizers that use closures are not supported"
if self.manager.should_commit():
self.optim.step()

@property
def param_groups(self) -> Any:
return self.optim.param_groups

@property
def state(self) -> Any:
return self.optim.state
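Forwarding param_groups and state lets the wrapper duck-type as a plain torch.optim.Optimizer for callers that only read or adjust those attributes (LR schedulers, optimizer containers). A toy sketch of the same passthrough idea, not the torchft class itself:

```python
# Illustrative passthrough wrapper (toy stand-in, not torchft's optimizer wrapper).
import torch
from torch import nn


class PassthroughOptimizer:
    def __init__(self, optim: torch.optim.Optimizer) -> None:
        self.optim = optim

    @property
    def param_groups(self):
        return self.optim.param_groups

    @property
    def state(self):
        return self.optim.state


model = nn.Linear(4, 4)
inner = torch.optim.AdamW(model.parameters(), lr=1e-3)
wrapped = PassthroughOptimizer(inner)

# Callers that only touch param_groups (e.g. LR schedules) work through the wrapper.
for group in wrapped.param_groups:
    group["lr"] *= 0.5
assert inner.param_groups[0]["lr"] == 5e-4
```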
34 changes: 27 additions & 7 deletions torchft/process_group.py
@@ -21,7 +21,7 @@
import threading
from abc import ABC
from datetime import timedelta
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Type, Union
from typing import Any, TYPE_CHECKING, Dict, List, Optional, Tuple, Type, Union

import torch
import torch.distributed as dist
@@ -852,6 +852,8 @@ def extend_device_mesh(


class ManagedDeviceMesh(DeviceMesh):
replicate_pg_singleton: Optional["ManagedProcessGroup"]

def __init__(
self,
mesh: Optional[DeviceMesh],
@@ -880,6 +882,15 @@ def __init__(
self._flatten_mesh_list: Tuple[DeviceMesh, ...] = tuple()
self._thread_id: Optional[int] = None

def __getstate__(self) -> Dict[str, Any]:
state = self.__dict__.copy()
state["replicate_pg"] = None
return state

def __setstate__(self, state: Dict[str, Any]) -> None:
self.__dict__.update(state)
self.replicate_pg = self.replicate_pg_singleton

def __getitem__(self, mesh_dim_names: Union[str, Tuple[str, ...]]) -> DeviceMesh:
if isinstance(mesh_dim_names, str):
if mesh_dim_names == self.replicate_dim_name:
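The __getstate__/__setstate__ pair added above follows a common pattern for objects holding an unpicklable handle: drop the live ProcessGroup when the mesh is pickled, then reattach it from the class-level replicate_pg_singleton (set in ft_init_device_mesh further down) when it is unpickled. A self-contained sketch of the pattern with generic names:

```python
# Generic illustration of the pickle pattern (not torchft code): strip the handle on
# __getstate__ and reattach it from a process-wide singleton on __setstate__.
import pickle
from typing import Any, Dict, Optional


class Handle:
    """Stand-in for an unpicklable resource such as a ProcessGroup."""


class MeshLike:
    handle_singleton: Optional[Handle] = None  # set once per process

    def __init__(self, name: str, handle: Handle) -> None:
        self.name = name
        self.handle = handle

    def __getstate__(self) -> Dict[str, Any]:
        state = self.__dict__.copy()
        state["handle"] = None  # the live handle cannot be pickled
        return state

    def __setstate__(self, state: Dict[str, Any]) -> None:
        self.__dict__.update(state)
        self.handle = MeshLike.handle_singleton  # reattach the live handle


MeshLike.handle_singleton = Handle()
m = MeshLike("replicate", MeshLike.handle_singleton)
restored = pickle.loads(pickle.dumps(m))
assert restored.handle is MeshLike.handle_singleton
```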
@@ -897,13 +908,14 @@ def __getitem__(self, mesh_dim_names: Union[str, Tuple[str, ...]]) -> DeviceMesh
return self.mesh[mesh_dim_names]
else:
assert isinstance(mesh_dim_names, tuple)
if self.replicate_dim_name in mesh_dim_names:
if self.replicate_dim_name not in mesh_dim_names:
assert self.mesh is not None
return self.mesh[mesh_dim_names]
else:
assert self.mesh is not None
mesh_dim_names_wo_replicate = tuple(n for n in mesh_dim_names if n != self.replicate_dim_name)
return ManagedDeviceMesh(
self.mesh[mesh_dim_names],
self.mesh[mesh_dim_names_wo_replicate],
mesh_dim_names,
self.replicate_pg,
mesh_dim_names.index(self.replicate_dim_name),
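The branch above was inverted, and the inner mesh is now indexed only by the non-replicate dim names, since the wrapped DeviceMesh has no replicate dimension of its own. A small illustration with hypothetical dim names:

```python
# Illustration (not torchft code) of selecting the names used to index the inner mesh.
replicate_dim_name = "dp_replicate"          # hypothetical dim name
requested = ("dp_replicate", "dp_shard")     # dims the caller asked for

if replicate_dim_name not in requested:
    inner_names = requested                  # inner mesh can be indexed directly
else:
    # strip the replicate dim before indexing the inner mesh
    inner_names = tuple(n for n in requested if n != replicate_dim_name)

assert inner_names == ("dp_shard",)
```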
@@ -938,14 +950,16 @@ def _flatten(self, mesh_dim_name: Optional[str]) -> "DeviceMesh":
return flatten_mesh

def size(self, mesh_dim: Optional[int] = None) -> int:
replicate_pg_size = self.replicate_pg.size()
replicate_pg_size = 1 if replicate_pg_size == 0 else replicate_pg_size
if mesh_dim is None:
if self.mesh is None:
return self.replicate_pg.size()
return replicate_pg_size
else:
assert self.mesh is not None
return self.mesh.size() * self.replicate_pg.size()
return self.mesh.size() * replicate_pg_size
elif mesh_dim == self.replicate_dim:
return self.replicate_pg.size()
return replicate_pg_size
else:
assert self.mesh is not None
return self.mesh.size(self._real_mesh_dim(mesh_dim))
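size() above now clamps a zero-sized replicate group to 1, presumably so that a managed replicate group that has not yet been configured does not zero out the overall mesh size. The rule in isolation:

```python
# Sketch of the sizing rule (illustrative): treat an unconfigured replicate group
# (reported size 0) as size 1 so the product over mesh dims stays meaningful.
def managed_mesh_size(inner_mesh_size: int, replicate_pg_size: int) -> int:
    replicate = replicate_pg_size if replicate_pg_size > 0 else 1
    return inner_mesh_size * replicate


assert managed_mesh_size(8, 0) == 8   # replicate group not yet initialized
assert managed_mesh_size(8, 2) == 16  # two replica groups participating
```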
@@ -995,7 +1009,11 @@ def get_coordinate(self) -> Optional[List[int]]:
dimensions of the mesh. If this rank is not part of the mesh, return None.
"""
assert self.mesh is not None
return self.mesh._coordinate_on_dim if self.mesh._coordinate_on_dim else None
ret = self.mesh._coordinate_on_dim if self.mesh._coordinate_on_dim else None
if ret:
ret = ret.copy()
ret.insert(get_rank(self.replicate_pg), self.replicate_dim)
return ret

def get_all_groups(self) -> List[BaseProcessGroup]:
raise NotImplementedError
@@ -1070,6 +1088,8 @@ def ft_init_device_mesh(
# the same backend has been registered.
replicate_pg.register(mesh_dim_names[replicate_dim])

ManagedDeviceMesh.replicate_pg_singleton = replicate_pg

return ManagedDeviceMesh(
mesh=mesh,
mesh_dim_names=mesh_dim_names,
