Skip to content

Commit

Permalink
Add JAX error checking support
Browse files Browse the repository at this point in the history
In this PR, only jit and control flows are supported. Support for vmap and multi-device environments will be added in subsequent PRs.

PiperOrigin-RevId: 725633039
  • Loading branch information
ayaka14732 authored and Google-ML-Automation committed Feb 14, 2025
1 parent 4df5961 commit b652849
Show file tree
Hide file tree
Showing 4 changed files with 271 additions and 0 deletions.
1 change: 1 addition & 0 deletions jax/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,7 @@ py_library_providing_imports_info(
"_src/dispatch.py",
"_src/dlpack.py",
"_src/earray.py",
"_src/error_check.py",
"_src/ffi.py",
"_src/flatten_util.py",
"_src/interpreters/__init__.py",
Expand Down
95 changes: 95 additions & 0 deletions jax/_src/error_check.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import threading

import jax
from jax._src import core
from jax._src import source_info_util
from jax._src import traceback_util
import jax.numpy as jnp


Traceback = source_info_util.Traceback


traceback_util.register_exclusion(__file__)


class JaxValueError(ValueError):
"""Exception raised for failed runtime error checks in JAX."""


_NO_ERROR = jnp.iinfo(jnp.uint32).max
"""The default error code for no error.
We choose this value because when performing reductions, we can use `min` to
obtain the smallest error code.
"""


_error_code_ref: core.MutableArray | None = None
_error_list_lock = threading.Lock()
_error_list: list[tuple[str, Traceback]] = [] # (error_message, traceback) pair


def _initialize_error_code_ref() -> None:
with core.eval_context():
global _error_code_ref
error_code = jnp.uint32(_NO_ERROR)
_error_code_ref = core.mutable_array(error_code)


def set_error_if(pred: jax.Array, msg: str) -> None:
"""Set error if pred is true.
If the error is already set, the new error will be ignored. It will not
override the existing error.
"""
if _error_code_ref is None:
_initialize_error_code_ref()
assert _error_code_ref is not None

traceback = source_info_util.current().traceback
assert traceback is not None
with _error_list_lock:
new_error_code = len(_error_list)
_error_list.append((msg, traceback))

pred = pred.any()
error_code = _error_code_ref[...]
should_update = jnp.logical_and(pred, error_code == jnp.uint32(_NO_ERROR))
error_code = jnp.where(should_update, new_error_code, error_code)
# TODO(ayx): support vmap and shard_map.
_error_code_ref[...] = error_code


def raise_if_error() -> None:
"""Raise error if an error is set."""
if _error_code_ref is None: # if not initialized, do nothing
return

error_code = _error_code_ref[...]
if error_code == jnp.uint32(_NO_ERROR):
return
try:
msg, traceback = _error_list[error_code]
exc = JaxValueError(msg)
traceback = traceback.as_python_traceback()
filtered_traceback = traceback_util.filter_traceback(traceback)
raise exc.with_traceback(filtered_traceback)
finally:
_error_code_ref[...] = jnp.uint32(_NO_ERROR)
5 changes: 5 additions & 0 deletions tests/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1117,6 +1117,11 @@ jax_multiplatform_test(
},
)

jax_multiplatform_test(
name = "error_check_test",
srcs = ["error_check_test.py"],
)

jax_multiplatform_test(
name = "stax_test",
srcs = ["stax_test.py"],
Expand Down
170 changes: 170 additions & 0 deletions tests/error_check_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from absl.testing import absltest
from absl.testing import parameterized
import jax
from jax._src import config
from jax._src import error_check
from jax._src import test_util as jtu
import jax.numpy as jnp


JaxValueError = error_check.JaxValueError


config.parse_flags_with_absl()


@jtu.with_config(jax_check_tracer_leaks=True)
class ErrorCheckTests(jtu.JaxTestCase):

@parameterized.product(jit=[True, False])
def test_error_check(self, jit):
def f(x):
error_check.set_error_if(x <= 0, "x must be greater than 0")
return x + 1

if jit:
f = jax.jit(f)

x = jnp.full((4,), -1, dtype=jnp.int32)
f(x)
with self.assertRaisesRegex(JaxValueError, "x must be greater than 0"):
error_check.raise_if_error()

@parameterized.product(jit=[True, False])
def test_error_check_no_error(self, jit):
def f(x):
error_check.set_error_if(x <= 0, "x must be greater than 0")
return x + 1

if jit:
f = jax.jit(f)

x = jnp.full((4,), 1, dtype=jnp.int32)
f(x)
error_check.raise_if_error() # should not raise error

@parameterized.product(jit=[True, False])
def test_error_check_should_report_the_first_error(self, jit):
def f(x):
error_check.set_error_if(x >= 1, "x must be less than 1 in f")
return x + 1

def g(x):
error_check.set_error_if(x >= 1, "x must be less than 1 in g")
return x + 1

if jit:
f = jax.jit(f)
g = jax.jit(g)

x = jnp.full((4,), 0, dtype=jnp.int32)

x = f(x) # check passes, so it should not set error
x = g(x) # check fails. so it should set error
_ = f(x) # check fails, but should not override the error
with self.assertRaisesRegex(JaxValueError, "x must be less than 1 in g"):
error_check.raise_if_error()

@parameterized.product(jit=[True, False])
def test_raise_if_error_clears_error(self, jit):
def f(x):
error_check.set_error_if(x <= 0, "x must be greater than 0 in f")
return x + 1

def g(x):
error_check.set_error_if(x <= 0, "x must be greater than 0 in g")
return x + 1

if jit:
f = jax.jit(f)
g = jax.jit(g)

x = jnp.full((4,), -1, dtype=jnp.int32)
f(x)
with self.assertRaisesRegex(JaxValueError, "x must be greater than 0 in f"):
error_check.raise_if_error()

error_check.raise_if_error() # should not raise error

g(x)
with self.assertRaisesRegex(JaxValueError, "x must be greater than 0 in g"):
error_check.raise_if_error()

@parameterized.product(jit=[True, False])
def test_error_check_works_with_cond(self, jit):
def f(x):
error_check.set_error_if(x == 0, "x must be non-zero in f")
return x + 1

def g(x):
error_check.set_error_if(x == 0, "x must be non-zero in g")
return x + 1

def body(pred, x):
return jax.lax.cond(pred, f, g, x)

if jit:
body = jax.jit(body)

x = jnp.zeros((4,), dtype=jnp.int32)

_ = body(jnp.bool_(True), x)
with self.assertRaisesRegex(JaxValueError, "x must be non-zero in f"):
error_check.raise_if_error()

_ = body(jnp.bool_(False), x)
with self.assertRaisesRegex(JaxValueError, "x must be non-zero in g"):
error_check.raise_if_error()

@parameterized.product(jit=[True, False])
def test_error_check_works_with_while_loop(self, jit):
def f(x):
error_check.set_error_if(x >= 10, "x must be less than 10")
return x + 1

def body(x):
return jax.lax.while_loop(lambda x: (x < 10).any(), f, x)

if jit:
body = jax.jit(body)

x = jnp.arange(4, dtype=jnp.int32)
_ = body(x)
with self.assertRaisesRegex(JaxValueError, "x must be less than 10"):
error_check.raise_if_error()

def test_error_check_works_with_scan(self):
def f(carry, x):
error_check.set_error_if(x >= 4, "x must be less than 4")
return carry + x, x + 1

def body(init, xs):
return jax.lax.scan(f, init=init, xs=xs)

init = jnp.int32(0)
xs = jnp.arange(5, dtype=jnp.int32)
_ = body(init, xs)
with self.assertRaisesRegex(JaxValueError, "x must be less than 4"):
error_check.raise_if_error()

xs = jnp.arange(4, dtype=jnp.int32)
_ = body(init, xs)
error_check.raise_if_error() # should not raise error

if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())

0 comments on commit b652849

Please sign in to comment.