Skip to content

Commit

Permalink
Make sure that tests don't change the state of the compilation cache
Browse files Browse the repository at this point in the history
If it was initialized before the test, it should stay so after. And the other
way around too.

PiperOrigin-RevId: 726899671
  • Loading branch information
apaszke authored and Google-ML-Automation committed Feb 14, 2025
1 parent 49ad241 commit 5ab8c5a
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 10 deletions.
1 change: 1 addition & 0 deletions jax/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ py_library(
":internal",
] + jax_test_util_visibility,
deps = [
":compilation_cache_internal",
":jax",
] + py_deps("absl/testing") + py_deps("numpy"),
)
Expand Down
24 changes: 15 additions & 9 deletions jax/_src/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
import jax
from jax import lax
from jax._src import api
from jax._src import compilation_cache
from jax._src import config
from jax._src import core
from jax._src import deprecations
Expand All @@ -60,7 +61,6 @@
_assert_numpy_allclose, _check_dtypes_match, _default_tolerance, _dtype, check_close, check_grads,
check_jvp, check_vjp, default_gradient_tolerance, default_tolerance, rand_like, tolerance)
from jax._src.util import unzip2
from jax.experimental.compilation_cache import compilation_cache
from jax.tree_util import tree_all, tree_flatten, tree_map, tree_unflatten
import numpy as np
import numpy.random as npr
Expand Down Expand Up @@ -1265,17 +1265,23 @@ def __repr__(self):

@contextmanager
def assert_global_configs_unchanged():
starting_cache = compilation_cache._cache
starting_config = jax.config.values.copy()
yield
ending_config = jax.config.values

if starting_config == ending_config:
return
differing = {k: (starting_config.get(k, NotPresent()), ending_config.get(k, NotPresent()))
for k in (starting_config.keys() | ending_config.keys())
if (k not in starting_config or k not in ending_config
or starting_config[k] != ending_config[k])}
raise AssertionError(f"Test changed global config values. Differing values are: {differing}")
ending_cache = compilation_cache._cache

if starting_config != ending_config:
differing = {k: (starting_config.get(k, NotPresent()), ending_config.get(k, NotPresent()))
for k in (starting_config.keys() | ending_config.keys())
if (k not in starting_config or k not in ending_config
or starting_config[k] != ending_config[k])}
raise AssertionError(f"Test changed global config values. Differing values are: {differing}")
if starting_cache is not ending_cache:
raise AssertionError(
f"Test changed the compilation cache object: before test it was "
f"{starting_cache}, now it is {ending_cache}"
)


class JaxTestCase(parameterized.TestCase):
Expand Down
2 changes: 1 addition & 1 deletion tests/pgle_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,8 +350,8 @@ def f(x, y):
# Test pass fdo_profile as compiler_options API works.
f_lowered.compile(compiler_options={'fdo_profile': fdo_profile})


def testPersistentCachePopulatedWithAutoPgle(self):
self.skipTest('Test does not cleanly reset the compilation cache')
its = 50
mesh = jtu.create_mesh((2,), ('x',))

Expand Down

0 comments on commit 5ab8c5a

Please sign in to comment.