Skip to content

Commit

Permalink
[better_errors] Add debug info to the Jaxprs formed for AD
Browse files Browse the repository at this point in the history
Following jax-ml#26078, we add debug info to more calls of lu.wrap_init.
  • Loading branch information
gnecula committed Feb 5, 2025
1 parent 414449e commit 5b9b852
Show file tree
Hide file tree
Showing 22 changed files with 481 additions and 169 deletions.
1 change: 1 addition & 0 deletions jax/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -564,6 +564,7 @@ pytype_strict_library(
srcs = ["_src/interpreters/mlir.py"],
deps = [
":ad_util",
":api_util",
":config",
":core",
":dtypes",
Expand Down
10 changes: 7 additions & 3 deletions jax/_src/ad_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -682,12 +682,13 @@ def transpose_jaxpr(jaxpr: core.ClosedJaxpr, in_linear: bool | Sequence[bool],
return _transpose_jaxpr(jaxpr, tuple(in_linear), tuple(out_zeros))

@weakref_lru_cache
def _transpose_jaxpr(jaxpr, in_lin, out_zeros):
def _transpose_jaxpr(jaxpr: core.ClosedJaxpr,
in_lin: Sequence[bool],
out_zeros: Sequence[bool]):
in_avals = ([a for a, lin in zip(jaxpr.in_avals, in_lin ) if not lin] +
[a for a, zero in zip(jaxpr.out_avals, out_zeros) if not zero])
cell = lambda: None

@lu.wrap_init
def transposed(*args_flat):
ins_flat, out_cts_flat = split_list(args_flat, [len(in_lin) - sum(in_lin)])

Expand Down Expand Up @@ -715,7 +716,10 @@ def transposed(*args_flat):
in_cts_nz, _ = partition_list(in_zeros, in_cts)
return in_cts_nz

transposed_jaxpr_, _, consts, () = pe.trace_to_jaxpr_dynamic(transposed, in_avals)
transposed_wrapped = lu.wrap_init(transposed,
debug_info=jaxpr.jaxpr.debug_info)
transposed_jaxpr_, _, consts, () = pe.trace_to_jaxpr_dynamic(
transposed_wrapped, in_avals)
transposed_jaxpr = core.ClosedJaxpr(transposed_jaxpr_, consts)
return transposed_jaxpr, cell.in_cts_zero # pytype: disable=attribute-error

Expand Down
22 changes: 13 additions & 9 deletions jax/_src/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -983,7 +983,7 @@ def vmap_f(*args, **kwargs):
"to the positional arguments passed to the function, "
f"but got {len(in_axes)=}, {len(args)=}")
args_flat, in_tree = tree_flatten((args, kwargs), is_leaf=batching.is_vmappable)
f = lu.wrap_init(fun)
f = lu.wrap_init(fun, debug_info=debug_info("vmap", fun, args, kwargs))
flat_fun, out_tree = batching.flatten_fun_for_vmap(f, in_tree)
in_axes_flat = flatten_axes("vmap in_axes", in_tree, (in_axes, 0), kws=True)
axis_size_ = (axis_size if axis_size is not None else
Expand Down Expand Up @@ -1715,15 +1715,15 @@ def jvp(
0.19900084
"""
check_callable(fun)
return _jvp(lu.wrap_init(fun), primals, tangents, has_aux=has_aux)

def _jvp(fun: lu.WrappedFun, primals, tangents, has_aux=False):
"""Variant of jvp() that takes an lu.WrappedFun."""
if (not isinstance(primals, (tuple, list)) or
not isinstance(tangents, (tuple, list))):
raise TypeError("primal and tangent arguments to jax.jvp must be tuples or lists; "
f"found {type(primals).__name__} and {type(tangents).__name__}.")
return _jvp(lu.wrap_init(fun, debug_info=debug_info("jvp", fun, primals, {})),
primals, tangents, has_aux=has_aux)

def _jvp(fun: lu.WrappedFun, primals, tangents, has_aux=False):
"""Variant of jvp() that takes an lu.WrappedFun."""
ps_flat, tree_def = tree_flatten(primals)
ts_flat, tree_def_2 = tree_flatten(tangents)
if tree_def != tree_def_2:
Expand Down Expand Up @@ -1835,7 +1835,7 @@ def linearize(fun: Callable, *primals, has_aux: bool = False
-6.676704
"""
check_callable(fun)
f = lu.wrap_init(fun)
f = lu.wrap_init(fun, debug_info=debug_info("linearize", fun, primals, {}))
primals_flat, in_tree = tree_flatten(primals)
if has_aux:
jaxtree_fun, out_tree = flatten_fun_nokwargs2(f, in_tree)
Expand Down Expand Up @@ -1983,8 +1983,9 @@ def vjp(
raise NotImplementedError("reduce_axes argument to vjp is deprecated")
del reduce_axes
check_callable(fun)
return _vjp(
lu.wrap_init(fun), *primals, has_aux=has_aux)
wrapped_fun = lu.wrap_init(fun,
debug_info=debug_info("vjp", fun, primals, {}))
return _vjp(wrapped_fun, *primals, has_aux=has_aux)

def _vjp(fun: lu.WrappedFun, *primals, has_aux=False):
"""Variant of vjp() that takes an lu.WrappedFun."""
Expand Down Expand Up @@ -2049,7 +2050,10 @@ def linear_transpose(fun: Callable, *primals, reduce_axes=()) -> Callable:
raise NotImplementedError("reduce_axes argument to transpose is deprecated")
del reduce_axes
primals_flat, in_tree = tree_flatten(primals)
flat_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
flat_fun, out_tree = flatten_fun_nokwargs(
lu.wrap_init(fun,
debug_info=debug_info("linear_transpose", fun, primals, {})),
in_tree)
in_avals = map(shaped_abstractify, primals_flat)
in_dtypes = map(dtypes.dtype, in_avals)

Expand Down
7 changes: 4 additions & 3 deletions jax/_src/api_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,8 @@ def _ensure_str_tuple(x: str | Iterable[str]) -> tuple[str, ...]:
return tuple(map(_ensure_str, x))

@lu.transformation_with_aux2
def flatten_fun(f, store, in_tree, *args_flat):
def flatten_fun(f: Callable, store: lu.Store,
in_tree: PyTreeDef, *args_flat):
py_args, py_kwargs = tree_unflatten(in_tree, args_flat)
ans = f(*py_args, **py_kwargs)
ans, out_tree = tree_flatten(ans)
Expand Down Expand Up @@ -587,8 +588,8 @@ def debug_info(
args: Sequence[Any],
kwargs: dict[str, Any],
*,
static_argnums: tuple[int, ...] = (),
static_argnames: tuple[str, ...] = (),
static_argnums: Sequence[int] = (),
static_argnames: Sequence[str] = (),
result_paths_thunk: Callable[[], tuple[str, ...]] | None = None,
# TODO(necula): check if we really need this, e.g., to speed up tracing?
sourceinfo: str | None = None,
Expand Down
7 changes: 4 additions & 3 deletions jax/_src/checkify.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,8 +361,9 @@ def default_checkify_rule(primitive: core.Primitive, error: Error,
else:
jaxpr, consts = call_jaxpr, ()
consts_ = tuple(HashableWrapper(c) for c in consts)
partial_checkify = lu.hashable_partial(lu.wrap_init(
checkify_jaxpr_flat_hashable), jaxpr, consts_, enabled_errors, err_tree)
partial_checkify = lu.hashable_partial(
lu.wrap_init(checkify_jaxpr_flat_hashable, debug_info=jaxpr.debug_info),
jaxpr, consts_, enabled_errors, err_tree)
partial_checkify, metadata = _flatten_and_get_error_metadata_thunk(
partial_checkify)

Expand Down Expand Up @@ -746,7 +747,7 @@ def jaxpr_to_checkify_jaxpr(
checkify_jaxpr_partial = functools.partial(checkify_jaxpr_flat, jaxpr.jaxpr,
jaxpr.consts, enabled_errors,
err_tree)
fun = lu.wrap_init(checkify_jaxpr_partial)
fun = lu.wrap_init(checkify_jaxpr_partial, debug_info=jaxpr.jaxpr.debug_info)
fun, metadata = _flatten_and_get_error_metadata_thunk(fun)

new_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(fun, flat_err_and_in_vals)
Expand Down
5 changes: 3 additions & 2 deletions jax/_src/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2416,8 +2416,9 @@ def call_impl(f: lu.WrappedFun, *args, **params):
class ClosedCallPrimitive(CallPrimitive):
def get_bind_params(self, params):
new_params = dict(params)
jaxpr = new_params.pop('call_jaxpr')
subfun = lu.wrap_init(partial(eval_jaxpr, jaxpr.jaxpr, jaxpr.consts))
jaxpr: ClosedJaxpr = new_params.pop('call_jaxpr')
subfun = lu.wrap_init(partial(eval_jaxpr, jaxpr.jaxpr, jaxpr.consts),
debug_info=jaxpr.jaxpr.debug_info)
return [subfun], new_params

closed_call_p: ClosedCallPrimitive = ClosedCallPrimitive('closed_call')
Expand Down
Loading

0 comments on commit 5b9b852

Please sign in to comment.