From b652849e86deb920c32db18fe73dd72110c177eb Mon Sep 17 00:00:00 2001 From: Ayaka Date: Tue, 11 Feb 2025 08:03:05 -0800 Subject: [PATCH] Add JAX error checking support In this PR, only jit and control flows are supported. Support for vmap and multi-device environments will be added in subsequent PRs. PiperOrigin-RevId: 725633039 --- jax/BUILD | 1 + jax/_src/error_check.py | 95 +++++++++++++++++++++ tests/BUILD | 5 ++ tests/error_check_test.py | 170 ++++++++++++++++++++++++++++++++++++++ 4 files changed, 271 insertions(+) create mode 100644 jax/_src/error_check.py create mode 100644 tests/error_check_test.py diff --git a/jax/BUILD b/jax/BUILD index 4c57dafc3cbb..5ec0d5dd213e 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -209,6 +209,7 @@ py_library_providing_imports_info( "_src/dispatch.py", "_src/dlpack.py", "_src/earray.py", + "_src/error_check.py", "_src/ffi.py", "_src/flatten_util.py", "_src/interpreters/__init__.py", diff --git a/jax/_src/error_check.py b/jax/_src/error_check.py new file mode 100644 index 000000000000..74326e582e12 --- /dev/null +++ b/jax/_src/error_check.py @@ -0,0 +1,95 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import threading + +import jax +from jax._src import core +from jax._src import source_info_util +from jax._src import traceback_util +import jax.numpy as jnp + + +Traceback = source_info_util.Traceback + + +traceback_util.register_exclusion(__file__) + + +class JaxValueError(ValueError): + """Exception raised for failed runtime error checks in JAX.""" + + +_NO_ERROR = jnp.iinfo(jnp.uint32).max +"""The default error code for no error. + +We choose this value because when performing reductions, we can use `min` to +obtain the smallest error code. +""" + + +_error_code_ref: core.MutableArray | None = None +_error_list_lock = threading.Lock() +_error_list: list[tuple[str, Traceback]] = [] # (error_message, traceback) pair + + +def _initialize_error_code_ref() -> None: + with core.eval_context(): + global _error_code_ref + error_code = jnp.uint32(_NO_ERROR) + _error_code_ref = core.mutable_array(error_code) + + +def set_error_if(pred: jax.Array, msg: str) -> None: + """Set error if pred is true. + + If the error is already set, the new error will be ignored. It will not + override the existing error. + """ + if _error_code_ref is None: + _initialize_error_code_ref() + assert _error_code_ref is not None + + traceback = source_info_util.current().traceback + assert traceback is not None + with _error_list_lock: + new_error_code = len(_error_list) + _error_list.append((msg, traceback)) + + pred = pred.any() + error_code = _error_code_ref[...] + should_update = jnp.logical_and(pred, error_code == jnp.uint32(_NO_ERROR)) + error_code = jnp.where(should_update, new_error_code, error_code) + # TODO(ayx): support vmap and shard_map. + _error_code_ref[...] = error_code + + +def raise_if_error() -> None: + """Raise error if an error is set.""" + if _error_code_ref is None: # if not initialized, do nothing + return + + error_code = _error_code_ref[...] + if error_code == jnp.uint32(_NO_ERROR): + return + try: + msg, traceback = _error_list[error_code] + exc = JaxValueError(msg) + traceback = traceback.as_python_traceback() + filtered_traceback = traceback_util.filter_traceback(traceback) + raise exc.with_traceback(filtered_traceback) + finally: + _error_code_ref[...] = jnp.uint32(_NO_ERROR) diff --git a/tests/BUILD b/tests/BUILD index c957bc94c97d..e56c5493d877 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -1117,6 +1117,11 @@ jax_multiplatform_test( }, ) +jax_multiplatform_test( + name = "error_check_test", + srcs = ["error_check_test.py"], +) + jax_multiplatform_test( name = "stax_test", srcs = ["stax_test.py"], diff --git a/tests/error_check_test.py b/tests/error_check_test.py new file mode 100644 index 000000000000..8ac435cbb351 --- /dev/null +++ b/tests/error_check_test.py @@ -0,0 +1,170 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from absl.testing import absltest +from absl.testing import parameterized +import jax +from jax._src import config +from jax._src import error_check +from jax._src import test_util as jtu +import jax.numpy as jnp + + +JaxValueError = error_check.JaxValueError + + +config.parse_flags_with_absl() + + +@jtu.with_config(jax_check_tracer_leaks=True) +class ErrorCheckTests(jtu.JaxTestCase): + + @parameterized.product(jit=[True, False]) + def test_error_check(self, jit): + def f(x): + error_check.set_error_if(x <= 0, "x must be greater than 0") + return x + 1 + + if jit: + f = jax.jit(f) + + x = jnp.full((4,), -1, dtype=jnp.int32) + f(x) + with self.assertRaisesRegex(JaxValueError, "x must be greater than 0"): + error_check.raise_if_error() + + @parameterized.product(jit=[True, False]) + def test_error_check_no_error(self, jit): + def f(x): + error_check.set_error_if(x <= 0, "x must be greater than 0") + return x + 1 + + if jit: + f = jax.jit(f) + + x = jnp.full((4,), 1, dtype=jnp.int32) + f(x) + error_check.raise_if_error() # should not raise error + + @parameterized.product(jit=[True, False]) + def test_error_check_should_report_the_first_error(self, jit): + def f(x): + error_check.set_error_if(x >= 1, "x must be less than 1 in f") + return x + 1 + + def g(x): + error_check.set_error_if(x >= 1, "x must be less than 1 in g") + return x + 1 + + if jit: + f = jax.jit(f) + g = jax.jit(g) + + x = jnp.full((4,), 0, dtype=jnp.int32) + + x = f(x) # check passes, so it should not set error + x = g(x) # check fails. so it should set error + _ = f(x) # check fails, but should not override the error + with self.assertRaisesRegex(JaxValueError, "x must be less than 1 in g"): + error_check.raise_if_error() + + @parameterized.product(jit=[True, False]) + def test_raise_if_error_clears_error(self, jit): + def f(x): + error_check.set_error_if(x <= 0, "x must be greater than 0 in f") + return x + 1 + + def g(x): + error_check.set_error_if(x <= 0, "x must be greater than 0 in g") + return x + 1 + + if jit: + f = jax.jit(f) + g = jax.jit(g) + + x = jnp.full((4,), -1, dtype=jnp.int32) + f(x) + with self.assertRaisesRegex(JaxValueError, "x must be greater than 0 in f"): + error_check.raise_if_error() + + error_check.raise_if_error() # should not raise error + + g(x) + with self.assertRaisesRegex(JaxValueError, "x must be greater than 0 in g"): + error_check.raise_if_error() + + @parameterized.product(jit=[True, False]) + def test_error_check_works_with_cond(self, jit): + def f(x): + error_check.set_error_if(x == 0, "x must be non-zero in f") + return x + 1 + + def g(x): + error_check.set_error_if(x == 0, "x must be non-zero in g") + return x + 1 + + def body(pred, x): + return jax.lax.cond(pred, f, g, x) + + if jit: + body = jax.jit(body) + + x = jnp.zeros((4,), dtype=jnp.int32) + + _ = body(jnp.bool_(True), x) + with self.assertRaisesRegex(JaxValueError, "x must be non-zero in f"): + error_check.raise_if_error() + + _ = body(jnp.bool_(False), x) + with self.assertRaisesRegex(JaxValueError, "x must be non-zero in g"): + error_check.raise_if_error() + + @parameterized.product(jit=[True, False]) + def test_error_check_works_with_while_loop(self, jit): + def f(x): + error_check.set_error_if(x >= 10, "x must be less than 10") + return x + 1 + + def body(x): + return jax.lax.while_loop(lambda x: (x < 10).any(), f, x) + + if jit: + body = jax.jit(body) + + x = jnp.arange(4, dtype=jnp.int32) + _ = body(x) + with self.assertRaisesRegex(JaxValueError, "x must be less than 10"): + error_check.raise_if_error() + + def test_error_check_works_with_scan(self): + def f(carry, x): + error_check.set_error_if(x >= 4, "x must be less than 4") + return carry + x, x + 1 + + def body(init, xs): + return jax.lax.scan(f, init=init, xs=xs) + + init = jnp.int32(0) + xs = jnp.arange(5, dtype=jnp.int32) + _ = body(init, xs) + with self.assertRaisesRegex(JaxValueError, "x must be less than 4"): + error_check.raise_if_error() + + xs = jnp.arange(4, dtype=jnp.int32) + _ = body(init, xs) + error_check.raise_if_error() # should not raise error + +if __name__ == "__main__": + absltest.main(testLoader=jtu.JaxTestLoader())