Skip to content

Commit

Permalink
Split NamedSharding into a separate file called named_sharding.py so …
Browse files Browse the repository at this point in the history
…that we can import it in core.py and break the cyclic dependency.

PiperOrigin-RevId: 726566863
  • Loading branch information
yashk2810 authored and Google-ML-Automation committed Feb 13, 2025
1 parent ea4e324 commit 229aa65
Show file tree
Hide file tree
Showing 6 changed files with 603 additions and 555 deletions.
15 changes: 15 additions & 0 deletions jax/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,7 @@ py_library_providing_imports_info(
":mesh",
":mlir",
":monitoring",
":named_sharding",
":op_shardings",
":partial_eval",
":partition_spec",
Expand Down Expand Up @@ -467,6 +468,7 @@ pytype_strict_library(
":dtypes",
":effects",
":mesh",
":named_sharding",
":partition_spec",
":pretty_printer",
":source_info_util",
Expand Down Expand Up @@ -897,6 +899,7 @@ pytype_strict_library(
":core",
":internal_mesh_utils",
":mesh",
":named_sharding",
":op_shardings",
":partition_spec",
":sharding",
Expand All @@ -909,6 +912,18 @@ pytype_strict_library(
] + py_deps("numpy"),
)

pytype_strict_library(
name = "named_sharding",
srcs = ["_src/named_sharding.py"],
deps = [
":mesh",
":partition_spec",
":sharding",
":util",
"//jax/_src/lib",
] + py_deps("numpy"),
)

pytype_strict_library(
name = "sharding_specs",
srcs = ["_src/sharding_specs.py"],
Expand Down
10 changes: 2 additions & 8 deletions jax/_src/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
HashableFunction, HashableWrapper, weakref_lru_cache,
partition_list, StrictABCMeta)
import jax._src.pretty_printer as pp
from jax._src.named_sharding import NamedSharding
from jax._src.lib import jax_jit
from jax._src.lib import xla_client
from jax._src import traceback_util
Expand Down Expand Up @@ -1491,7 +1492,6 @@ def check_valid_jaxtype(x):
f"Value {x!r} of type {type(x)} is not a valid JAX type")

def update_aval_with_sharding(aval, sharding):
from jax._src.sharding_impls import NamedSharding # type: ignore
if config.sharding_in_types.value and isinstance(sharding, NamedSharding):
aval = aval.update(sharding=NamedSharding(
sharding.mesh.abstract_mesh,
Expand Down Expand Up @@ -1757,8 +1757,6 @@ def canonicalize_value(val):
if not config.sharding_in_types.value:
return val

from jax._src.pjit import NamedSharding, mesh_cast # type: ignore

try:
aval = get_aval(val)
except TypeError:
Expand All @@ -1772,15 +1770,14 @@ def canonicalize_value(val):
if cur_mesh == aval.sharding.mesh: # type: ignore
return val
if cur_mesh._are_all_axes_manual and aval.sharding.mesh._are_all_axes_auto: # type: ignore
from jax._src.pjit import mesh_cast # type: ignore
return mesh_cast(val, NamedSharding(cur_mesh, P(*[None] * aval.ndim))) # type: ignore
return val


def get_cur_mesh_sharding(spec=None):
if not config.sharding_in_types.value:
return None

from jax._src.sharding_impls import NamedSharding # type: ignore
spec = P() if spec is None else spec
return NamedSharding(mesh_lib.get_abstract_mesh(), spec)

Expand Down Expand Up @@ -1819,8 +1816,6 @@ def _maybe_modify_sharding(sharding, ndim):


def get_sharding(sharding, ndim):
from jax._src.sharding_impls import NamedSharding # type: ignore

if sharding is None:
return NamedSharding(mesh_lib.empty_abstract_mesh, P(*[None] * ndim))

Expand Down Expand Up @@ -1972,7 +1967,6 @@ def update(self, shape=None, dtype=None, weak_type=None):

@property
def sharding(self):
from jax._src.sharding_impls import NamedSharding # type: ignore
return NamedSharding(mesh_lib.empty_abstract_mesh, P())

def _len(self, tracer):
Expand Down
21 changes: 20 additions & 1 deletion jax/_src/lax/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from jax._src import mesh as mesh_lib
from jax._src.util import safe_zip
from jax._src.partition_spec import PartitionSpec as P
from jax._src.named_sharding import NamedSharding

zip, unsafe_zip = safe_zip, zip

Expand All @@ -48,9 +49,27 @@ def standard_primitive(shape_rule, dtype_rule, name,

def _get_array_abstraction_level(a): return a.array_abstraction_level

def _get_abstract_mesh_from_avals(in_avals) -> mesh_lib.AbstractMesh:
if not config.sharding_in_types.value:
return None # type: ignore
m = None
for a in in_avals:
if a is core.abstract_token:
continue
if a.sharding.mesh.empty: # type: ignore
continue
if m is not None and m != a.sharding.mesh:
if m._are_all_axes_auto and a.sharding.mesh._are_all_axes_auto:
return mesh_lib.empty_abstract_mesh
raise ValueError(
f'Mesh for all inputs should be equal. Got one mesh: {m} and'
f' another mesh: {a.sharding.mesh}')
m = a.sharding.mesh # type: ignore
return mesh_lib.empty_abstract_mesh if m is None else m


def call_sharding_rule(prim, rule, num_out, *avals, **kwargs):
if config.sharding_in_types.value:
from jax._src.pjit import _get_abstract_mesh_from_avals, NamedSharding
cur_mesh = mesh_lib.get_abstract_mesh()
aval_mesh = _get_abstract_mesh_from_avals(avals)
if ((cur_mesh.empty or cur_mesh._are_all_axes_auto or cur_mesh._are_all_axes_manual) and
Expand Down
Loading

0 comments on commit 229aa65

Please sign in to comment.