[better_errors] Add debug info to the Jaxprs formed for AD
Following jax-ml#26078, we add debug info to more calls of lu.wrap_init.
gnecula committed Feb 5, 2025
1 parent 414449e commit 85ca946
Showing 21 changed files with 415 additions and 142 deletions.
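For context, the recurring pattern applied across these files is roughly the following (an illustrative sketch, not a line from this diff; the transform name and toy function are made up, and the internal import paths are assumptions):

    # Sketch of the pattern: attach a DebugInfo when wrapping a traceable
    # function, instead of wrapping it anonymously.
    from jax._src import linear_util as lu
    from jax._src.api_util import debug_info

    def fun(x):
      return x * 2.0

    # Before: no provenance is recorded for the traced function.
    f_plain = lu.wrap_init(fun)

    # After: the DebugInfo records the transform name, the user function,
    # and its argument structure, so jaxprs traced from f_dbg can refer
    # to `fun` by name in error messages.
    dbg = debug_info("example_transform", fun, (1.0,), {})
    f_dbg = lu.wrap_init(fun, debug_info=dbg)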
1 change: 1 addition & 0 deletions jax/BUILD
@@ -564,6 +564,7 @@ pytype_strict_library(
srcs = ["_src/interpreters/mlir.py"],
deps = [
":ad_util",
":api_util",
":config",
":core",
":dtypes",
10 changes: 7 additions & 3 deletions jax/_src/ad_checkpoint.py
@@ -682,12 +682,13 @@ def transpose_jaxpr(jaxpr: core.ClosedJaxpr, in_linear: bool | Sequence[bool],
return _transpose_jaxpr(jaxpr, tuple(in_linear), tuple(out_zeros))

@weakref_lru_cache
def _transpose_jaxpr(jaxpr, in_lin, out_zeros):
def _transpose_jaxpr(jaxpr: core.ClosedJaxpr,
in_lin: Sequence[bool],
out_zeros: Sequence[bool]):
in_avals = ([a for a, lin in zip(jaxpr.in_avals, in_lin ) if not lin] +
[a for a, zero in zip(jaxpr.out_avals, out_zeros) if not zero])
cell = lambda: None

@lu.wrap_init
def transposed(*args_flat):
ins_flat, out_cts_flat = split_list(args_flat, [len(in_lin) - sum(in_lin)])

@@ -715,7 +716,10 @@ def transposed(*args_flat):
in_cts_nz, _ = partition_list(in_zeros, in_cts)
return in_cts_nz

transposed_jaxpr_, _, consts, () = pe.trace_to_jaxpr_dynamic(transposed, in_avals)
transposed_wrapped = lu.wrap_init(transposed,
debug_info=jaxpr.jaxpr.debug_info)
transposed_jaxpr_, _, consts, () = pe.trace_to_jaxpr_dynamic(
transposed_wrapped, in_avals)
transposed_jaxpr = core.ClosedJaxpr(transposed_jaxpr_, consts)
return transposed_jaxpr, cell.in_cts_zero # pytype: disable=attribute-error
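The hunk above replaces the bare @lu.wrap_init decorator with an explicit wrap that inherits the primal jaxpr's debug info. In isolation the recipe looks like this (a sketch; make_derived_jaxpr and derived_fn are hypothetical names):

    # Trace a function derived from a ClosedJaxpr, attributing the new
    # jaxpr to the original one via its debug info.
    def make_derived_jaxpr(closed_jaxpr, derived_fn, in_avals):
      wrapped = lu.wrap_init(derived_fn,
                             debug_info=closed_jaxpr.jaxpr.debug_info)
      new_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped, in_avals)
      return core.ClosedJaxpr(new_jaxpr, consts)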

22 changes: 13 additions & 9 deletions jax/_src/api.py
@@ -983,7 +983,7 @@ def vmap_f(*args, **kwargs):
"to the positional arguments passed to the function, "
f"but got {len(in_axes)=}, {len(args)=}")
args_flat, in_tree = tree_flatten((args, kwargs), is_leaf=batching.is_vmappable)
f = lu.wrap_init(fun)
f = lu.wrap_init(fun, debug_info=debug_info("vmap", fun, args, kwargs))
flat_fun, out_tree = batching.flatten_fun_for_vmap(f, in_tree)
in_axes_flat = flatten_axes("vmap in_axes", in_tree, (in_axes, 0), kws=True)
axis_size_ = (axis_size if axis_size is not None else
@@ -1715,15 +1715,15 @@ def jvp(
0.19900084
"""
check_callable(fun)
return _jvp(lu.wrap_init(fun), primals, tangents, has_aux=has_aux)

def _jvp(fun: lu.WrappedFun, primals, tangents, has_aux=False):
"""Variant of jvp() that takes an lu.WrappedFun."""
if (not isinstance(primals, (tuple, list)) or
not isinstance(tangents, (tuple, list))):
raise TypeError("primal and tangent arguments to jax.jvp must be tuples or lists; "
f"found {type(primals).__name__} and {type(tangents).__name__}.")
return _jvp(lu.wrap_init(fun, debug_info=debug_info("jvp", fun, primals, {})),
primals, tangents, has_aux=has_aux)

def _jvp(fun: lu.WrappedFun, primals, tangents, has_aux=False):
"""Variant of jvp() that takes an lu.WrappedFun."""
ps_flat, tree_def = tree_flatten(primals)
ts_flat, tree_def_2 = tree_flatten(tangents)
if tree_def != tree_def_2:
@@ -1835,7 +1835,7 @@ def linearize(fun: Callable, *primals, has_aux: bool = False
-6.676704
"""
check_callable(fun)
f = lu.wrap_init(fun)
f = lu.wrap_init(fun, debug_info=debug_info("linearize", fun, primals, {}))
primals_flat, in_tree = tree_flatten(primals)
if has_aux:
jaxtree_fun, out_tree = flatten_fun_nokwargs2(f, in_tree)
@@ -1983,8 +1983,9 @@ def vjp(
raise NotImplementedError("reduce_axes argument to vjp is deprecated")
del reduce_axes
check_callable(fun)
return _vjp(
lu.wrap_init(fun), *primals, has_aux=has_aux)
wrapped_fun = lu.wrap_init(fun,
debug_info=debug_info("vjp", fun, primals, {}))
return _vjp(wrapped_fun, *primals, has_aux=has_aux)

def _vjp(fun: lu.WrappedFun, *primals, has_aux=False):
"""Variant of vjp() that takes an lu.WrappedFun."""
@@ -2049,7 +2050,10 @@ def linear_transpose(fun: Callable, *primals, reduce_axes=()) -> Callable:
raise NotImplementedError("reduce_axes argument to transpose is deprecated")
del reduce_axes
primals_flat, in_tree = tree_flatten(primals)
flat_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
flat_fun, out_tree = flatten_fun_nokwargs(
lu.wrap_init(fun,
debug_info=debug_info("linear_transpose", fun, primals, {})),
in_tree)
in_avals = map(shaped_abstractify, primals_flat)
in_dtypes = map(dtypes.dtype, in_avals)
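For reference, a small runnable example of one of the entry points touched above (standard public API; the values shown are what jvp computes for this toy function):

    import jax

    def g(x, y):
      return x * y

    # jax.jvp now wraps `g` with debug_info("jvp", g, primals, {}) before
    # tracing, so the resulting jaxpr can name `g` and its arguments.
    primals, tangents = (2.0, 3.0), (1.0, 0.0)
    out, out_tangent = jax.jvp(g, primals, tangents)
    # out == 6.0, out_tangent == 3.0 (d(x*y) = y*dx + x*dy)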

7 changes: 4 additions & 3 deletions jax/_src/api_util.py
@@ -67,7 +67,8 @@ def _ensure_str_tuple(x: str | Iterable[str]) -> tuple[str, ...]:
return tuple(map(_ensure_str, x))

@lu.transformation_with_aux2
def flatten_fun(f, store, in_tree, *args_flat):
def flatten_fun(f: Callable, store: lu.Store,
in_tree: PyTreeDef, *args_flat):
py_args, py_kwargs = tree_unflatten(in_tree, args_flat)
ans = f(*py_args, **py_kwargs)
ans, out_tree = tree_flatten(ans)
@@ -587,8 +588,8 @@ def debug_info(
args: Sequence[Any],
kwargs: dict[str, Any],
*,
static_argnums: tuple[int, ...] = (),
static_argnames: tuple[str, ...] = (),
static_argnums: Sequence[int] = (),
static_argnames: Sequence[str] = (),
result_paths_thunk: Callable[[], tuple[str, ...]] | None = None,
# TODO(necula): check if we really need this, e.g., to speed up tracing?
sourceinfo: str | None = None,
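A minimal sketch of calling this helper directly (it is an internal API; the transform label and function here are made up):

    from jax._src.api_util import debug_info

    def f(n, x):
      return x + n

    # static_argnums/static_argnames now accept any Sequence, not only
    # tuples, per the relaxed annotations above.
    dbg = debug_info("my_transform", f, (3, 1.0), {}, static_argnums=[0])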
7 changes: 4 additions & 3 deletions jax/_src/checkify.py
@@ -361,8 +361,9 @@ def default_checkify_rule(primitive: core.Primitive, error: Error,
else:
jaxpr, consts = call_jaxpr, ()
consts_ = tuple(HashableWrapper(c) for c in consts)
partial_checkify = lu.hashable_partial(lu.wrap_init(
checkify_jaxpr_flat_hashable), jaxpr, consts_, enabled_errors, err_tree)
partial_checkify = lu.hashable_partial(
lu.wrap_init(checkify_jaxpr_flat_hashable, debug_info=jaxpr.debug_info),
jaxpr, consts_, enabled_errors, err_tree)
partial_checkify, metadata = _flatten_and_get_error_metadata_thunk(
partial_checkify)

@@ -746,7 +747,7 @@ def jaxpr_to_checkify_jaxpr(
checkify_jaxpr_partial = functools.partial(checkify_jaxpr_flat, jaxpr.jaxpr,
jaxpr.consts, enabled_errors,
err_tree)
fun = lu.wrap_init(checkify_jaxpr_partial)
fun = lu.wrap_init(checkify_jaxpr_partial, debug_info=jaxpr.jaxpr.debug_info)
fun, metadata = _flatten_and_get_error_metadata_thunk(fun)

new_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(fun, flat_err_and_in_vals)
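For a user-level view of what this plumbing feeds, a sketch using the public checkify API (the error reporting behavior in the comment is an assumption, not shown in this diff):

    import jax
    from jax.experimental import checkify

    def h(x):
      return 1.0 / x

    checked = checkify.checkify(h, errors=checkify.float_checks)
    err, out = checked(0.0)
    # The jaxpr checkify builds for `h` now carries `h`'s debug info,
    # which err.get() / err.throw() can use to point back at `h`.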
5 changes: 3 additions & 2 deletions jax/_src/core.py
@@ -2416,8 +2416,9 @@ def call_impl(f: lu.WrappedFun, *args, **params):
class ClosedCallPrimitive(CallPrimitive):
def get_bind_params(self, params):
new_params = dict(params)
jaxpr = new_params.pop('call_jaxpr')
subfun = lu.wrap_init(partial(eval_jaxpr, jaxpr.jaxpr, jaxpr.consts))
jaxpr: ClosedJaxpr = new_params.pop('call_jaxpr')
subfun = lu.wrap_init(partial(eval_jaxpr, jaxpr.jaxpr, jaxpr.consts),
debug_info=jaxpr.jaxpr.debug_info)
return [subfun], new_params

closed_call_p: ClosedCallPrimitive = ClosedCallPrimitive('closed_call')
75 changes: 49 additions & 26 deletions jax/_src/custom_derivatives.py
@@ -31,7 +31,7 @@
stop_gradient_p, SymbolicZero, Zero, zeros_like_aval)
from jax._src.api_util import (
argnums_partial, flatten_fun_nokwargs, resolve_kwargs, fun_signature,
_non_static_arg_names, prepend_static_args)
_non_static_arg_names, prepend_static_args, debug_info)
from jax._src.errors import UnexpectedTracerError
from jax._src.state.types import AbstractRef
from jax._src.interpreters import ad
@@ -44,7 +44,7 @@
from jax._src.tree_util import (
tree_flatten, tree_unflatten, tree_map, treedef_is_leaf, treedef_tuple,
register_pytree_node_class, tree_leaves, tree_flatten_with_path,
tree_leaves_with_path, keystr, treedef_children)
tree_leaves_with_path, keystr, treedef_children, PyTreeDef)
from jax._src.util import (cache, safe_zip, safe_map, split_list, unzip2,
weakref_lru_cache)

@@ -78,7 +78,9 @@ def _zeros_like_pytree(x):

# like the api_util.py function, but also grabs output avals for error checking
@lu.transformation_with_aux2
def _flatten_fun_nokwargs(f, store, in_tree, *args_flat):
def _flatten_fun_nokwargs(f: Callable,
store: lu.Store, in_tree: PyTreeDef,
*args_flat):
py_args = tree_unflatten(in_tree, args_flat)
ans = f(*py_args)
ans_flat, ans_tree = tree_flatten(ans)
@@ -204,7 +206,7 @@ def defjvps(self, *jvps: Callable[..., ReturnValue] | None) -> None:
*jvps: a sequence of functions, one for each positional argument of the
:class:`~jax.custom_jvp` function. Each function takes as arguments
the tangent value for the corresponding primal input, the primal
output, and the primal inputs. See the example below.
output, and the primal inputs. See the example below.
Returns:
None.
@@ -239,28 +241,41 @@ def jvp(primals, tangents):

@traceback_util.api_boundary
def __call__(self, *args: Any, **kwargs: Any) -> ReturnValue: # pytype: disable=invalid-annotation
primal_name = getattr(self.fun, '__name__', str(self.fun))
debug = debug_info("custom_jvp fun", self.fun, args, kwargs,
static_argnums=self.nondiff_argnums)
primal_name = debug.func_name
if not self.jvp:
msg = f"No JVP defined for custom_jvp function {primal_name} using defjvp."
raise AttributeError(msg)
jvp_name = getattr(self.jvp, '__name__', str(self.jvp))

args = resolve_kwargs(self.fun, args, kwargs)
if self.nondiff_argnums:
nondiff_argnums = set(self.nondiff_argnums)
args = tuple(_stop_gradient(x) if i in nondiff_argnums else x
for i, x in enumerate(args))
diff_argnums = [i for i in range(len(args)) if i not in nondiff_argnums]
f_, dyn_args = argnums_partial(lu.wrap_init(self.fun), diff_argnums, args,
f_, dyn_args = argnums_partial(lu.wrap_init(self.fun, debug_info=debug),
diff_argnums, args,
require_static_args_hashable=False)
static_args = [args[i] for i in self.nondiff_argnums]
jvp = prepend_static_args(lu.wrap_init(self.jvp), static_args)
diff_args = [args[i] for i, a in enumerate(args) if i not in self.nondiff_argnums]
debug_jvp = debug_info("custom_jvp jvp", self.jvp,
(*static_args, diff_args, diff_args),
{},
static_argnums=self.nondiff_argnums)
jvp = prepend_static_args(lu.wrap_init(self.jvp,
debug_info=debug_jvp), static_args)
else:
f_, dyn_args = lu.wrap_init(self.fun), args
jvp = lu.wrap_init(self.jvp)
f_, dyn_args = lu.wrap_init(self.fun, debug_info=debug), args
debug_jvp = debug_info("custom_jvp jvp", self.jvp,
(args, args),
{},
static_argnums=self.nondiff_argnums)
jvp = lu.wrap_init(self.jvp, debug_info=debug_jvp)
args_flat, in_tree = tree_flatten(dyn_args)
flat_fun, out_type1 = _flatten_fun_nokwargs(f_, in_tree)
flat_jvp, out_type2 = _flatten_jvp(jvp, primal_name, jvp_name, in_tree,
out_type1)
flat_jvp, out_type2 = _flatten_jvp(jvp, primal_name, debug_jvp.func_name,
in_tree, out_type1)
out_flat = custom_jvp_call_p.bind(flat_fun, flat_jvp, *args_flat,
symbolic_zeros=self.symbolic_zeros)
_, (out_tree, _) = lu.merge_linear_aux(out_type1, out_type2)
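A runnable usage example for the custom_jvp path above (standard public API): with this change both the primal function and the JVP rule get DebugInfos, labeled "custom_jvp fun" and "custom_jvp jvp", so either can be named in tracing errors.

    import jax

    @jax.custom_jvp
    def f(x):
      return x * x

    @f.defjvp
    def f_jvp(primals, tangents):
      (x,), (t,) = primals, tangents
      return x * x, 2.0 * x * t

    print(jax.jvp(f, (3.0,), (1.0,)))  # (9.0, 6.0)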
@@ -611,12 +626,14 @@ def defvjp(self,

@traceback_util.api_boundary
def __call__(self, *args: Any, **kwargs: Any) -> ReturnValue: # pytype: disable=invalid-annotation
primal_name = getattr(self.fun, '__name__', str(self.fun))
debug_fun = debug_info("custom_vjp fun", self.fun, args, kwargs,
static_argnums=self.nondiff_argnums)
if not self.fwd or not self.bwd:
msg = f"No VJP defined for custom_vjp function {primal_name} using defvjp."
msg = f"No VJP defined for custom_vjp function {debug_fun.func_name} using defvjp."
raise AttributeError(msg)
fwd_name = getattr(self.fwd, '__name__', str(self.fwd))
args = resolve_kwargs(self.fun, args, kwargs)
debug_fwd = debug_info("custom_vjp fwd", self.fwd, args, kwargs,
static_argnums=self.nondiff_argnums)
if self.optimize_remat:
fwd = optimize_remat_of_custom_vjp_fwd(
self.fun, self.fwd, nondiff_argnums=self.nondiff_argnums,
@@ -633,23 +650,29 @@ def __call__(self, *args: Any, **kwargs: Any) -> ReturnValue: # pytype: disable
for i in self.nondiff_argnums: _check_for_tracers(args[i])
nondiff_argnums = set(self.nondiff_argnums)
dyn_argnums = [i for i in range(len(args)) if i not in nondiff_argnums]
f_, dyn_args = argnums_partial(lu.wrap_init(self.fun), dyn_argnums,
args, require_static_args_hashable=False)
f_, dyn_args = argnums_partial(
lu.wrap_init(self.fun, debug_info=debug_fun), dyn_argnums,
args, require_static_args_hashable=False)
static_args = [args[i] for i in self.nondiff_argnums]
fwd_, _ = argnums_partial(lu.wrap_init(fwd), dyn_argnums, args,
require_static_args_hashable=False)
bwd = prepend_static_args(lu.wrap_init(self.bwd), static_args)
fwd_, _ = argnums_partial(lu.wrap_init(fwd, debug_info=debug_fwd),
dyn_argnums, args,
require_static_args_hashable=False)
# TODO(necula): can't construct yet the debug_bwd
bwd = prepend_static_args(lu.wrap_init(self.bwd, debug_info=debug_fwd),
static_args)
else:
f_, dyn_args = lu.wrap_init(self.fun), args
fwd_, bwd = lu.wrap_init(fwd), lu.wrap_init(self.bwd)
f_, dyn_args = lu.wrap_init(self.fun, debug_info=debug_fun), args
fwd_ = lu.wrap_init(fwd, debug_info=debug_fwd)
# TODO(necula): can't construct yet the debug_bwd
bwd = lu.wrap_init(self.bwd, debug_info=debug_fwd)
args_flat, in_tree = tree_flatten(dyn_args)
in_avals = [core.get_aval(x) for x in args_flat]
if config.mutable_array_checks.value:
f_ = _check_primal_refs(f_, self.nondiff_argnums)
flat_fun, out_type = _flatten_fun_nokwargs(f_, in_tree)
flat_fwd, out_trees = _flatten_fwd(
fwd_, self.nondiff_argnums, self.symbolic_zeros, primal_name,
fwd_name, in_tree, out_type)
fwd_, self.nondiff_argnums, self.symbolic_zeros, debug_fun.func_name,
debug_fwd.func_name, in_tree, out_type)
flat_bwd = _flatten_bwd(bwd, in_tree, in_avals, out_trees).call_wrapped
out_flat = custom_vjp_call_p.bind(flat_fun, flat_fwd, flat_bwd,
*args_flat, out_trees=out_trees,
Expand All @@ -658,13 +681,13 @@ def __call__(self, *args: Any, **kwargs: Any) -> ReturnValue: # pytype: disable
return tree_unflatten(out_tree, out_flat)
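The corresponding custom_vjp example (standard public API). Note that, per the TODOs above, the bwd function temporarily borrows the fwd's debug info until a dedicated debug_bwd can be constructed:

    import jax
    import jax.numpy as jnp

    @jax.custom_vjp
    def g(x):
      return jnp.sin(x)

    def g_fwd(x):
      return jnp.sin(x), jnp.cos(x)  # primal output plus residual

    def g_bwd(cos_x, ct):
      return (cos_x * ct,)

    g.defvjp(g_fwd, g_bwd)
    print(jax.grad(g)(0.0))  # 1.0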

@lu.transformation2
def _check_primal_refs(f, nondiff_argnums, *args):
def _check_primal_refs(f: Callable, nondiff_argnums: Sequence[int], *args):
_check_for_aliased_refs(f, nondiff_argnums, args)
out = f(*args)
_check_for_returned_refs(f, out, 'primal')
return out

def _check_for_aliased_refs(f, nondiff_argnums, args):
def _check_for_aliased_refs(f: Callable, nondiff_argnums: Sequence[int], args):
leaves = tree_leaves(args)
refs: dict[int, int] = {}
for i, x in enumerate(leaves):
19 changes: 11 additions & 8 deletions jax/_src/interpreters/ad.py
@@ -68,7 +68,7 @@ def jvp(fun: lu.WrappedFun, has_aux=False, instantiate=True,
return jvpfun(fun, instantiate, transform_stack), aux

@lu.transformation2
def jvpfun(f, instantiate, transform_stack, primals, tangents):
def jvpfun(f: Callable, instantiate, transform_stack, primals, tangents):
tag = core.TraceTag()
tangents = [Zero.from_primal_value(t) if not isinstance(t, Zero)
and dtype(t) == float0 else t for t in tangents]
@@ -106,7 +106,7 @@ def linearize_subtrace(_f: Callable, _store, _tag, nzs_in, *primals, **params):
return tuple(consts) + tuple(out_primals)

@lu.transformation2
def jvp_subtrace(f, tag, primals, tangents):
def jvp_subtrace(f: Callable, tag: core.TraceTag, primals, tangents):
with core.take_current_trace() as parent_trace:
trace = JVPTrace(parent_trace, tag)
in_tracers = [maybe_jvp_tracer(trace, x, t)
@@ -778,7 +778,8 @@ def make_zero(aval):
out_nz_tracers = [trace.to_jaxpr_tracer(r)
for (r, nz) in zip(out_tangents, out_nzs) if nz]
in_tracers = [t for t, nz in zip(tangent_args, nonzeros) if nz]
jaxpr, out_consts, _ = pe.tracers_to_jaxpr(in_tracers, out_nz_tracers)
# TODO(necula): pass debug_info here
jaxpr, out_consts, _ = pe.tracers_to_jaxpr(in_tracers, out_nz_tracers, None)

def linearized(residuals, *tangents):
nz_tangents_in = [t for (t, nz) in zip(tangents, nonzeros) if nz]
@@ -932,13 +933,15 @@ def traceable(f, store, in_tree, *primals_and_tangents):
return out_flat


def call_transpose(primitive, params, call_jaxpr, args, ct, _):
def call_transpose(primitive, params, call_jaxpr: core.Jaxpr, args, ct, _):
if isinstance(call_jaxpr, core.ClosedJaxpr):
call_jaxpr, consts = call_jaxpr.jaxpr, call_jaxpr.consts
else:
consts = ()
all_args, in_tree_def = tree_flatten((consts, args, ct))
fun = lu.hashable_partial(lu.wrap_init(backward_pass), call_jaxpr, False)
fun = lu.hashable_partial(lu.wrap_init(backward_pass,
debug_info=call_jaxpr.debug_info),
call_jaxpr, False)
fun, out_tree = flatten_fun_nokwargs(fun, in_tree_def)
update_params = call_transpose_param_updaters.get(primitive)
if update_params:
@@ -950,7 +953,7 @@ def call_transpose(primitive, params, call_jaxpr, args, ct, _):
res_invars, _ = partition_list(which_lin, call_jaxpr.invars)
new_invars = [*res_invars, *call_jaxpr.outvars]
dbidx_map = {v: core.DBIdx(i) for i, v in enumerate(new_invars)}
in_type = [(v.aval.update(shape=tuple(dbidx_map.get(d, d) for d in v.aval.shape))
in_type = [(v.aval.update(shape=tuple(dbidx_map.get(d, d) for d in v.aval.shape)) # type: ignore[arg-type]
if type(v.aval) is core.DShapedArray else v.aval, True) for v in new_invars]
fun = lu.annotate(fun, tuple(in_type))
out_flat = primitive.bind(fun, *all_args, **params)
@@ -1027,8 +1030,8 @@ def jvp_jaxpr(jaxpr: core.ClosedJaxpr, nonzeros: Sequence[bool],
def _jvp_jaxpr(jaxpr: core.ClosedJaxpr,
nonzeros: Sequence[bool], instantiate: Sequence[bool]):
assert len(jaxpr.in_avals) == len(nonzeros)
debug_info = jaxpr.jaxpr.debug_info
f = lu.wrap_init(core.jaxpr_as_fun(jaxpr), debug_info=debug_info)
f = lu.wrap_init(core.jaxpr_as_fun(jaxpr),
debug_info=jaxpr.jaxpr.debug_info)
f_jvp, out_nonzeros = f_jvp_traceable(jvp(f, instantiate=instantiate, transform_stack=False),
nonzeros)
tangent_avals = [aval.to_tangent_aval() for aval, nz in zip(jaxpr.in_avals, nonzeros) if nz]
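The two wrap sites changed in this file follow the same recipe. As a consolidated sketch (assuming closed_jaxpr is a core.ClosedJaxpr, and reusing the backward_pass name from call_transpose above):

    # Re-wrap a jaxpr-derived callable without losing its provenance.
    f = lu.wrap_init(core.jaxpr_as_fun(closed_jaxpr),
                     debug_info=closed_jaxpr.jaxpr.debug_info)

    # The same debug info survives partial application, as in
    # call_transpose above.
    bwd = lu.hashable_partial(
        lu.wrap_init(backward_pass, debug_info=closed_jaxpr.jaxpr.debug_info),
        closed_jaxpr.jaxpr, False)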
4 changes: 3 additions & 1 deletion jax/_src/interpreters/mlir.py
@@ -34,6 +34,7 @@
import numpy as np

from jax._src import ad_util
from jax._src import api_util
from jax._src import config
from jax._src import core
from jax._src import dtypes
@@ -2156,7 +2157,8 @@ def lower_fun(fun: Callable, multiple_results: bool = True) -> Callable:
as `avals_out`."""
def f_lowered(ctx: LoweringRuleContext, *args, **params):
f = fun if multiple_results else lambda *args, **kw: (fun(*args, **kw),)
wrapped_fun = lu.wrap_init(f, params)
wrapped_fun = lu.wrap_init(f, params,
debug_info=api_util.debug_info("lower_fun", fun, args, params))
manager = (contextlib.nullcontext() if ctx.jaxpr_eqn_ctx is None else
ctx.jaxpr_eqn_ctx.manager)

Expand Down