Skip to content

Commit

Permalink
#sdy unskip JAX Shardy tests that are already passing
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 718786627
  • Loading branch information
bartchr808 authored and Google-ML-Automation committed Jan 23, 2025
1 parent 4222c30 commit af8ee13
Show file tree
Hide file tree
Showing 3 changed files with 2 additions and 14 deletions.
4 changes: 0 additions & 4 deletions tests/compilation_cache_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -456,8 +456,6 @@ def test_persistent_cache_hit_no_logging(self):
self.assertFalse(msg_exists_in_logs(msg, log.records, logging.WARNING))

def test_persistent_cache_miss_logging_with_explain(self):
if config.use_shardy_partitioner.value:
self.skipTest("TODO(b/364547005): pure callbacks not supported by Shardy yet")
with (config.explain_cache_misses(True),
config.compilation_cache_dir("jax-cache")):

Expand Down Expand Up @@ -502,8 +500,6 @@ def test_persistent_cache_miss_logging_with_explain(self):

def test_persistent_cache_miss_logging_with_no_explain(self):
# test that cache failure messages do not get logged in WARNING
if config.use_shardy_partitioner.value:
self.skipTest("TODO(b/364547005): pure callbacks not supported by Shardy yet")
with (config.explain_cache_misses(False),
config.compilation_cache_dir("jax-cache")):
# omitting writing to cache because compilation is too fast
Expand Down
6 changes: 0 additions & 6 deletions tests/pallas/ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
import jax
from jax import lax
from jax import random
from jax._src import config
from jax._src import dtypes
from jax._src import linear_util as lu
from jax._src import state
Expand Down Expand Up @@ -1416,9 +1415,6 @@ def test_debug_print(self):
if jtu.test_device_matches(["tpu"]):
self.skipTest("Test for TPU is covered in tpu_pallas_test.py")

if config.use_shardy_partitioner.value:
self.skipTest("TODO(b/364547005): pure callbacks not supported by Shardy yet")

# TODO: this test flakes on gpu
if jtu.test_device_matches(["gpu"]):
self.skipTest("This test flakes on gpu")
Expand Down Expand Up @@ -2254,8 +2250,6 @@ class OpsInterpretTest(OpsTest):
INTERPRET = True

def test_debug_print(self):
if config.use_shardy_partitioner.value:
self.skipTest("TODO(b/364547005): pure callbacks not supported by Shardy yet")
@functools.partial(
self.pallas_call,
out_shape=jax.ShapeDtypeStruct((2,), jnp.float32),
Expand Down
6 changes: 2 additions & 4 deletions tests/pjit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,8 +398,6 @@ def f(inp1, inp2, inp3):

@jtu.run_on_devices('tpu')
def testBufferDonationWithOutputShardingInferenceAndTokens(self):
if config.use_shardy_partitioner.value:
self.skipTest('b/355263220: Shardy does not support callbacks yet.')
mesh = jtu.create_mesh((2,), 'x')
s = NamedSharding(mesh, P('x'))

Expand Down Expand Up @@ -4312,7 +4310,7 @@ def f(x):

def test_empty_io_callback_under_shard_map(self):
if config.use_shardy_partitioner.value:
self.skipTest("Shardy errors out on empty callbacks.")
self.skipTest("TODO(b/384938613): Failing under shardy.")
mesh = jtu.create_mesh((4,), 'i')

def empty_callback(x):
Expand All @@ -4330,7 +4328,7 @@ def _f(x, y):

def test_empty_io_callback_under_shard_map_reshard_to_singledev(self):
if config.use_shardy_partitioner.value:
self.skipTest("Shardy errors out on empty callbacks.")
self.skipTest("TODO(b/384938613): Failing under shardy.")
mesh = jtu.create_mesh((4,), 'i')

def empty_callback(x):
Expand Down

0 comments on commit af8ee13

Please sign in to comment.