[sharding_in_types] Set the sharding_in_types config to True. This is a purely internal change and shouldn't affect any public APIs.

PiperOrigin-RevId: 721081589
yashk2810 authored and Google-ML-Automation committed Feb 13, 2025
1 parent f0cd168 commit c2179cd
Showing 4 changed files with 32 additions and 32 deletions.
jax/_src/config.py (1 addition, 1 deletion)
@@ -1045,7 +1045,7 @@ def enum_flag(name, default, *args, **kwargs) -> Flag[str]:
 
 sharding_in_types = bool_state(
     name='jax_sharding_in_types',
-    default=False,
+    default=True,
     help=('When True, enables forward only sharding propagation in JAX and '
           'avals have sharding on them.'),
     include_in_jit_key=True)
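For reference, a minimal sketch of overriding the flag explicitly, assuming the usual jax.config.update mechanism applies to this (internal) bool_state; the flag name or the flag itself may change or disappear in later releases:

    import jax

    # jax_sharding_in_types now defaults to True; it can still be toggled
    # explicitly if an internal workflow needs the old behavior.
    jax.config.update('jax_sharding_in_types', False)
    jax.config.update('jax_sharding_in_types', True)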
jax/_src/sharding_impls.py (6 additions, 0 deletions)
@@ -1271,6 +1271,12 @@ def canonicalize_sharding(sharding: NamedSharding | PartitionSpec | None,
           f' `jax.sharding.use_mesh`. Got {sharding}')
     sharding = NamedSharding(cur_mesh, sharding)  # type: ignore
   else:
+    # There are cases when you have multiple meshes set. Allow that for full
+    # auto mode because of existing use cases.
+    # TODO(yashkatariya): Remove this once we disallow different meshes.
+    if (sharding.mesh.abstract_mesh._are_all_axes_auto and
+        cur_mesh._are_all_axes_auto):
+      return sharding
     if (check_mesh_consistency and not cur_mesh.empty and
         sharding.mesh.abstract_mesh != cur_mesh):
       raise ValueError(
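The added branch lets canonicalize_sharding accept a sharding whose fully-auto mesh differs from the current mesh instead of hitting the consistency error below it. For context, a rough sketch of the public path that reaches this code, assuming a recent JAX where jax.sharding.use_mesh is available, with_sharding_constraint accepts a bare PartitionSpec under it, and at least two devices are present:

    import jax
    import numpy as np
    from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

    mesh = Mesh(np.asarray(jax.devices()[:2]), ('x',))

    @jax.jit
    def f(x):
      # Bare PartitionSpec: canonicalized against the mesh installed by use_mesh.
      return jax.lax.with_sharding_constraint(x * 2, P('x'))

    with jax.sharding.use_mesh(mesh):
      x = jax.device_put(np.arange(8.0), NamedSharding(mesh, P('x')))
      print(f(x).sharding)  # NamedSharding over `mesh`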
tests/memories_test.py (3 additions, 2 deletions)
@@ -607,6 +607,7 @@ def f(x):
         out_host, np_inp * 2, s_host, 'pinned_host')
 
   def test_output_streaming_inside_scan(self):
+    self.skipTest("b/393371838")
     if xb.backend_xla_version() is not None and xb.backend_xla_version() < 2:
       self.skipTest("This test requires an xla_version >= 2.")
     mesh = jtu.create_mesh((1, 1, 2), ("x", "y", "z"))
@@ -1330,11 +1331,11 @@ def test_jit_compilation_cache_hit(self):
     f = jax.jit(lambda x: x @ x.T)
 
     with (jtu.count_pjit_cpp_cache_miss() as cpp_count,
-          jtu.count_jit_and_pmap_lowerings() as compile_count):
+          jtu.count_jit_and_pmap_lowerings() as lowering_count):
       f(inp)
       f(inp2)
     self.assertEqual(cpp_count(), 2)
-    self.assertEqual(compile_count(), 1)
+    self.assertEqual(lowering_count(), 2)
 
   def test_jit_cpp_cache_output_hit(self):
     _, _, _, inp = _create_inputs((8, 2), P("x"), mem_kind="device")
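The counter rename matches what the test now asserts: with sharding information carried on avals, the two differently sharded inputs appear to lower separately, so the expected count moves from 1 to 2. A sketch of the counting pattern these tests use, assuming the internal jax._src.test_util helpers shown above (internal names, subject to change):

    import jax
    import numpy as np
    from jax._src import test_util as jtu  # internal test helpers

    f = jax.jit(lambda x: x @ x.T)
    x = np.ones((8, 2), np.float32)

    with (jtu.count_pjit_cpp_cache_miss() as cpp_count,
          jtu.count_jit_and_pmap_lowerings() as lowering_count):
      f(x)  # miss: traced, lowered, compiled
      f(x)  # hit on the C++ fast path
    print(cpp_count(), lowering_count())  # expect 1 and 1 for identical inputs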
tests/pjit_test.py (22 additions, 29 deletions)
@@ -1727,7 +1727,7 @@ def test_jit_different_mesh_in_auto(self):
     mesh2 = jax.sharding.Mesh([dev[0], dev[3], dev[2], dev[1]], 'x')
     f = jax.jit(lambda x, y: (x, y),
                 in_shardings=(NamedSharding(mesh2, P('x')), AUTO(mesh1)))
-    inp = core.ShapedArray((8, 2), np.float32)
+    inp = jax.ShapeDtypeStruct((8, 2), np.float32)
     with self.assertRaisesRegex(
         ValueError,
         "Received incompatible devices for jitted computation"):
@@ -3518,16 +3518,13 @@ def mul(x):
     self.assertEqual(cache_info2.hits, cache_info1.hits + 1)
     self.assertEqual(cache_info2.misses, cache_info1.misses)
 
-    out3 = jnp.squeeze(arr, axis=-1)
-    cache_info3 = pxla._cached_compilation.cache_info()
-    self.assertIsInstance(out3.sharding, NamedSharding)
-
-    out4 = jnp.squeeze(arr2, axis=-1)
-    cache_info4 = pxla._cached_compilation.cache_info()
-    self.assertIsInstance(out4.sharding, PositionalSharding)
+    with jtu.count_jit_tracing_cache_miss() as tracing_count:
+      out3 = jnp.squeeze(arr, axis=-1)
+      self.assertIsInstance(out3.sharding, NamedSharding)
 
-    self.assertEqual(cache_info4.hits, cache_info3.hits + 1)
-    self.assertEqual(cache_info4.misses, cache_info3.misses)
+      out4 = jnp.squeeze(arr2, axis=-1)
+      self.assertIsInstance(out4.sharding, PositionalSharding)
+    self.assertEqual(tracing_count(), 2)
 
   @jtu.thread_unsafe_test()  # cache_info isn't thread-safe
   def test_cache_hit_pjit_lower_with_cpp_cache_miss(self):
@@ -3551,7 +3548,7 @@ def mul(x):
 
     # Drops out of C++ cache i.e. cache miss
     self.assertEqual(count(), 2)
-    self.assertEqual(lowering_count(), 1)
+    self.assertEqual(lowering_count(), 2)
 
   def test_list_in_pspec(self):
     mesh = jtu.create_mesh((2,), ('x',))
@@ -3632,22 +3629,21 @@ def test_jit_mul_sum_sharding_preserved(self):
     arr2 = jax.device_put(np.arange(8).reshape(8, 1), ps)
 
     f = jax.jit(lambda x: x * 2)
-    out = f(arr)
-    cache_info1 = pxla._cached_compilation.cache_info()
-    self.assertIsInstance(out.sharding, NamedSharding)
 
-    with jtu.count_pjit_cpp_cache_miss() as count:
-      out2 = f(arr2)
-    cache_info2 = pxla._cached_compilation.cache_info()
-    self.assertIsInstance(out2.sharding, PositionalSharding)
+    with jtu.count_jit_compilation_cache_miss() as compilation_count:
+      out = f(arr)
+      self.assertIsInstance(out.sharding, NamedSharding)
 
-    # This will hit the cpp cache.
-    out3 = f(out2)
-    self.assertIsInstance(out3.sharding, PositionalSharding)
-    self.assertEqual(count(), 1)
+      with jtu.count_pjit_cpp_cache_miss() as cpp_count:
+        out2 = f(arr2)
+        self.assertIsInstance(out2.sharding, PositionalSharding)
 
-    self.assertEqual(cache_info2.hits, cache_info1.hits + 1)
-    self.assertEqual(cache_info2.misses, cache_info1.misses)
+        # This will hit the cpp cache.
+        out3 = f(out2)
+        self.assertIsInstance(out3.sharding, PositionalSharding)
+
+    self.assertEqual(compilation_count(), 2)
+    self.assertEqual(cpp_count(), 1)
 
     out4 = jnp.sum(arr)
     self.assertIsInstance(out4.sharding, NamedSharding)
@@ -3966,14 +3962,11 @@ def f():
     f()  # doesn't crash
 
   def test_lowering_cache_hit_different_devices(self):
-    if config.use_shardy_partitioner.value:
-      self.skipTest('b/358322664: different axis names results in '
-                    'a cache miss with Shardy.')
     if jax.device_count() < 4:
       self.skipTest('Requires >=4 devices')
 
     mesh1 = jax.sharding.Mesh(jax.devices()[:2], 'x')
-    mesh2 = jax.sharding.Mesh(jax.devices()[2:4], 'y')
+    mesh2 = jax.sharding.Mesh(jax.devices()[2:4], 'x')
 
     @jax.jit
     def f(x):
@@ -3984,7 +3977,7 @@ def g(a):
       out_a = f(a)  # lowering cached
 
       # same num_devices but different devices.
-      b = jax.device_put(out_a, NamedSharding(mesh2, P('y')))
+      b = jax.device_put(out_a, NamedSharding(mesh2, P('x')))
       f(b)  # lowering cache *hit*
 
     with jtu.count_jit_and_pmap_lowerings() as count:
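For context, a sketch of what the updated test exercises, assuming at least four devices: after the change both meshes use the same axis name 'x', so jitting over the second mesh can reuse the cached lowering even though the devices differ.

    import jax
    import numpy as np
    from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

    mesh1 = Mesh(jax.devices()[:2], 'x')
    mesh2 = Mesh(jax.devices()[2:4], 'x')  # same axis name, different devices

    @jax.jit
    def f(x):
      return x * 2

    a = jax.device_put(np.ones((8,), np.float32), NamedSharding(mesh1, P('x')))
    f(a)  # lowering cached
    b = jax.device_put(np.ones((8,), np.float32), NamedSharding(mesh2, P('x')))
    f(b)  # lowering cache hit; only compilation for the new devices remains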
