-
Notifications
You must be signed in to change notification settings - Fork 2.9k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
In this PR, only jit and control flows are supported. Support for vmap and multi-device environments will be added in subsequent PRs. PiperOrigin-RevId: 725633039
- Loading branch information
1 parent
4df5961
commit b652849
Showing
4 changed files
with
271 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
# Copyright 2025 The JAX Authors. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from __future__ import annotations | ||
|
||
import threading | ||
|
||
import jax | ||
from jax._src import core | ||
from jax._src import source_info_util | ||
from jax._src import traceback_util | ||
import jax.numpy as jnp | ||
|
||
|
||
Traceback = source_info_util.Traceback | ||
|
||
|
||
traceback_util.register_exclusion(__file__) | ||
|
||
|
||
class JaxValueError(ValueError): | ||
"""Exception raised for failed runtime error checks in JAX.""" | ||
|
||
|
||
_NO_ERROR = jnp.iinfo(jnp.uint32).max | ||
"""The default error code for no error. | ||
We choose this value because when performing reductions, we can use `min` to | ||
obtain the smallest error code. | ||
""" | ||
|
||
|
||
_error_code_ref: core.MutableArray | None = None | ||
_error_list_lock = threading.Lock() | ||
_error_list: list[tuple[str, Traceback]] = [] # (error_message, traceback) pair | ||
|
||
|
||
def _initialize_error_code_ref() -> None: | ||
with core.eval_context(): | ||
global _error_code_ref | ||
error_code = jnp.uint32(_NO_ERROR) | ||
_error_code_ref = core.mutable_array(error_code) | ||
|
||
|
||
def set_error_if(pred: jax.Array, msg: str) -> None: | ||
"""Set error if pred is true. | ||
If the error is already set, the new error will be ignored. It will not | ||
override the existing error. | ||
""" | ||
if _error_code_ref is None: | ||
_initialize_error_code_ref() | ||
assert _error_code_ref is not None | ||
|
||
traceback = source_info_util.current().traceback | ||
assert traceback is not None | ||
with _error_list_lock: | ||
new_error_code = len(_error_list) | ||
_error_list.append((msg, traceback)) | ||
|
||
pred = pred.any() | ||
error_code = _error_code_ref[...] | ||
should_update = jnp.logical_and(pred, error_code == jnp.uint32(_NO_ERROR)) | ||
error_code = jnp.where(should_update, new_error_code, error_code) | ||
# TODO(ayx): support vmap and shard_map. | ||
_error_code_ref[...] = error_code | ||
|
||
|
||
def raise_if_error() -> None: | ||
"""Raise error if an error is set.""" | ||
if _error_code_ref is None: # if not initialized, do nothing | ||
return | ||
|
||
error_code = _error_code_ref[...] | ||
if error_code == jnp.uint32(_NO_ERROR): | ||
return | ||
try: | ||
msg, traceback = _error_list[error_code] | ||
exc = JaxValueError(msg) | ||
traceback = traceback.as_python_traceback() | ||
filtered_traceback = traceback_util.filter_traceback(traceback) | ||
raise exc.with_traceback(filtered_traceback) | ||
finally: | ||
_error_code_ref[...] = jnp.uint32(_NO_ERROR) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,170 @@ | ||
# Copyright 2025 The JAX Authors. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
|
||
from absl.testing import absltest | ||
from absl.testing import parameterized | ||
import jax | ||
from jax._src import config | ||
from jax._src import error_check | ||
from jax._src import test_util as jtu | ||
import jax.numpy as jnp | ||
|
||
|
||
JaxValueError = error_check.JaxValueError | ||
|
||
|
||
config.parse_flags_with_absl() | ||
|
||
|
||
@jtu.with_config(jax_check_tracer_leaks=True) | ||
class ErrorCheckTests(jtu.JaxTestCase): | ||
|
||
@parameterized.product(jit=[True, False]) | ||
def test_error_check(self, jit): | ||
def f(x): | ||
error_check.set_error_if(x <= 0, "x must be greater than 0") | ||
return x + 1 | ||
|
||
if jit: | ||
f = jax.jit(f) | ||
|
||
x = jnp.full((4,), -1, dtype=jnp.int32) | ||
f(x) | ||
with self.assertRaisesRegex(JaxValueError, "x must be greater than 0"): | ||
error_check.raise_if_error() | ||
|
||
@parameterized.product(jit=[True, False]) | ||
def test_error_check_no_error(self, jit): | ||
def f(x): | ||
error_check.set_error_if(x <= 0, "x must be greater than 0") | ||
return x + 1 | ||
|
||
if jit: | ||
f = jax.jit(f) | ||
|
||
x = jnp.full((4,), 1, dtype=jnp.int32) | ||
f(x) | ||
error_check.raise_if_error() # should not raise error | ||
|
||
@parameterized.product(jit=[True, False]) | ||
def test_error_check_should_report_the_first_error(self, jit): | ||
def f(x): | ||
error_check.set_error_if(x >= 1, "x must be less than 1 in f") | ||
return x + 1 | ||
|
||
def g(x): | ||
error_check.set_error_if(x >= 1, "x must be less than 1 in g") | ||
return x + 1 | ||
|
||
if jit: | ||
f = jax.jit(f) | ||
g = jax.jit(g) | ||
|
||
x = jnp.full((4,), 0, dtype=jnp.int32) | ||
|
||
x = f(x) # check passes, so it should not set error | ||
x = g(x) # check fails. so it should set error | ||
_ = f(x) # check fails, but should not override the error | ||
with self.assertRaisesRegex(JaxValueError, "x must be less than 1 in g"): | ||
error_check.raise_if_error() | ||
|
||
@parameterized.product(jit=[True, False]) | ||
def test_raise_if_error_clears_error(self, jit): | ||
def f(x): | ||
error_check.set_error_if(x <= 0, "x must be greater than 0 in f") | ||
return x + 1 | ||
|
||
def g(x): | ||
error_check.set_error_if(x <= 0, "x must be greater than 0 in g") | ||
return x + 1 | ||
|
||
if jit: | ||
f = jax.jit(f) | ||
g = jax.jit(g) | ||
|
||
x = jnp.full((4,), -1, dtype=jnp.int32) | ||
f(x) | ||
with self.assertRaisesRegex(JaxValueError, "x must be greater than 0 in f"): | ||
error_check.raise_if_error() | ||
|
||
error_check.raise_if_error() # should not raise error | ||
|
||
g(x) | ||
with self.assertRaisesRegex(JaxValueError, "x must be greater than 0 in g"): | ||
error_check.raise_if_error() | ||
|
||
@parameterized.product(jit=[True, False]) | ||
def test_error_check_works_with_cond(self, jit): | ||
def f(x): | ||
error_check.set_error_if(x == 0, "x must be non-zero in f") | ||
return x + 1 | ||
|
||
def g(x): | ||
error_check.set_error_if(x == 0, "x must be non-zero in g") | ||
return x + 1 | ||
|
||
def body(pred, x): | ||
return jax.lax.cond(pred, f, g, x) | ||
|
||
if jit: | ||
body = jax.jit(body) | ||
|
||
x = jnp.zeros((4,), dtype=jnp.int32) | ||
|
||
_ = body(jnp.bool_(True), x) | ||
with self.assertRaisesRegex(JaxValueError, "x must be non-zero in f"): | ||
error_check.raise_if_error() | ||
|
||
_ = body(jnp.bool_(False), x) | ||
with self.assertRaisesRegex(JaxValueError, "x must be non-zero in g"): | ||
error_check.raise_if_error() | ||
|
||
@parameterized.product(jit=[True, False]) | ||
def test_error_check_works_with_while_loop(self, jit): | ||
def f(x): | ||
error_check.set_error_if(x >= 10, "x must be less than 10") | ||
return x + 1 | ||
|
||
def body(x): | ||
return jax.lax.while_loop(lambda x: (x < 10).any(), f, x) | ||
|
||
if jit: | ||
body = jax.jit(body) | ||
|
||
x = jnp.arange(4, dtype=jnp.int32) | ||
_ = body(x) | ||
with self.assertRaisesRegex(JaxValueError, "x must be less than 10"): | ||
error_check.raise_if_error() | ||
|
||
def test_error_check_works_with_scan(self): | ||
def f(carry, x): | ||
error_check.set_error_if(x >= 4, "x must be less than 4") | ||
return carry + x, x + 1 | ||
|
||
def body(init, xs): | ||
return jax.lax.scan(f, init=init, xs=xs) | ||
|
||
init = jnp.int32(0) | ||
xs = jnp.arange(5, dtype=jnp.int32) | ||
_ = body(init, xs) | ||
with self.assertRaisesRegex(JaxValueError, "x must be less than 4"): | ||
error_check.raise_if_error() | ||
|
||
xs = jnp.arange(4, dtype=jnp.int32) | ||
_ = body(init, xs) | ||
error_check.raise_if_error() # should not raise error | ||
|
||
if __name__ == "__main__": | ||
absltest.main(testLoader=jtu.JaxTestLoader()) |