Skip to content

Commit

Permalink
[better_errors] Continue adding debug info to Jaxprs (step 6)
Browse files Browse the repository at this point in the history
This continues the series that started with jax-ml#26078 and jax-ml#26313, adding `debug_info` to more calls to `lu.wrap_init`.

Here I changed the `custom_jvp_call` primitive, replacing its parameter
`jvp_jaxpr_thunk` (a plain callable) with `jvp_jaxpr_fun` (a `lu.WrappedFun`,
which can carry debug info).
  • Loading branch information
gnecula committed Feb 8, 2025
1 parent fd1b7cc commit a8d60e3
Show file tree
Hide file tree
Showing 17 changed files with 129 additions and 74 deletions.
9 changes: 5 additions & 4 deletions jax/_src/api_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from jax._src.state.types import AbstractRef
from jax._src.tree_util import (
PyTreeDef, tree_flatten, tree_unflatten, tree_map,
treedef_children, generate_key_paths, keystr, broadcast_prefix,
treedef_children, generate_key_paths, broadcast_prefix,
prefix_errors)
from jax._src.tree_util import _replace_nones
from jax._src import linear_util as lu
Expand Down Expand Up @@ -664,12 +664,13 @@ def _non_static_arg_names(fn_signature: inspect.Signature | None,
except (ValueError, TypeError):
pass
else:
return tuple(f'{name}{keystr(path)}' for name, x in ba.arguments.items()
return tuple(f'{name}{lu._clean_keystr_arg_names(path)}'
for name, x in ba.arguments.items()
for path, l in generate_key_paths(x) if l is not static)
args_arg_names = tuple(f'args{keystr(path)}'
args_arg_names = tuple(f'args{lu._clean_keystr_arg_names(path)}'
for path, l in generate_key_paths(args_)
if l is not static)
kwargs_arg_names = tuple(f'kwargs{keystr(path)}'
kwargs_arg_names = tuple(f'kwargs{lu._clean_keystr_arg_names(path)}'
for path, l in generate_key_paths(kwargs_)
if l is not static)
arg_names = args_arg_names + kwargs_arg_names
Expand Down
34 changes: 20 additions & 14 deletions jax/_src/checkify.py
Original file line number Diff line number Diff line change
Expand Up @@ -833,7 +833,7 @@ def new_body_f(*c_consts_and_vals):
# This checks if the next cond application will error
_ = cond_f(*c_consts, *out)
return out
new_body_f_ = lu.wrap_init(new_body_f)
new_body_f_ = lu.wrap_init(new_body_f, debug_info=body_jaxpr.jaxpr.debug_info)
c_consts_avals = cond_jaxpr.in_avals[:c_consts_num]
jaxpr, _, (), () = pe.trace_to_jaxpr_dynamic(new_body_f_, [*c_consts_avals,
*body_jaxpr.in_avals])
Expand Down Expand Up @@ -952,7 +952,8 @@ def remat_error_check(error, enabled_errors, *vals_in, jaxpr, **params):


def shard_map_error_check(
error, enabled_errors, *vals_in, jaxpr, in_names, out_names, **kwargs
error: Error, enabled_errors, *vals_in,
jaxpr: core.Jaxpr, in_names, out_names, **kwargs
):
if (mesh := kwargs.get('mesh')) is None:
raise ValueError('Mesh must be provided for shard_map with checkify.')
Expand All @@ -976,7 +977,6 @@ def shard_map_error_check(
)
num_out_error_vals = out_tree.num_leaves - len(out_names)

@lu.wrap_init
def expand_errors_leading_dim(*xs):
outs = core.eval_jaxpr(checked_jaxpr.jaxpr, checked_jaxpr.consts, *xs)
errs, outs = split_list(outs, [num_out_error_vals])
Expand All @@ -985,15 +985,18 @@ def expand_errors_leading_dim(*xs):

with core.extend_axis_env_nd(mesh.shape.items()):
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(
expand_errors_leading_dim, checked_jaxpr.in_avals
lu.wrap_init(expand_errors_leading_dim,
debug_info=checked_jaxpr.jaxpr.debug_info),
checked_jaxpr.in_avals
)
checked_jaxpr = core.ClosedJaxpr(jaxpr, consts)

# Update shard_map params to account for extra error values.
# Use fully sharded partitioning for out errors.
new_out_names = (*([{0: mesh.axis_names}] * num_out_error_vals), *out_names)
subfun = lu.hashable_partial(
lu.wrap_init(core.eval_jaxpr), checked_jaxpr.jaxpr, checked_jaxpr.consts
lu.wrap_init(core.eval_jaxpr, debug_info=checked_jaxpr.jaxpr.debug_info),
checked_jaxpr.jaxpr, checked_jaxpr.consts
)
new_params = dict(
jaxpr=checked_jaxpr.jaxpr,
Expand All @@ -1007,8 +1010,10 @@ def expand_errors_leading_dim(*xs):
return tree_unflatten(out_tree, err_and_out)
error_checks[shard_map.shard_map_p] = shard_map_error_check

def custom_jvp_call_rule(in_err, enabled_errors, *in_vals, num_consts,
jvp_jaxpr_thunk, call_jaxpr, **params):
def custom_jvp_call_rule(in_err: Error,
enabled_errors: set, *in_vals, num_consts,
jvp_jaxpr_fun: lu.WrappedFun,
call_jaxpr: core.ClosedJaxpr, **params):
# The types to have in mind are:
# jvp : (a -> b) -> (a, T a) -> (b, T b)
# checkify : (a -> b) -> a -> Err b
Expand All @@ -1021,10 +1026,11 @@ def custom_jvp_call_rule(in_err, enabled_errors, *in_vals, num_consts,
err_vals, err_tree = jtu.tree_flatten(in_err)
partial_checkify = lu.wrap_init(
functools.partial(checkify_jaxpr_flat, call_jaxpr.jaxpr,
call_jaxpr.consts, enabled_errors, err_tree))
call_jaxpr.consts, enabled_errors, err_tree),
debug_info=call_jaxpr.jaxpr.debug_info)
partial_checkify, f_metadata = _flatten_and_get_error_metadata_thunk(
partial_checkify)
jvp = lift_jvp(err_tree.num_leaves, num_consts, jvp_jaxpr_thunk)
jvp = lift_jvp(err_tree.num_leaves, num_consts, jvp_jaxpr_fun)
jvp, jvp_out_tree = flatten_fun_output(jvp)
all_outs = custom_derivatives.custom_jvp_call_p.bind(
partial_checkify, jvp, *err_vals, *in_vals, **params)
Expand All @@ -1041,17 +1047,17 @@ def custom_jvp_call_rule(in_err, enabled_errors, *in_vals, num_consts,

# Compared to custom_derivatives.lift_jvp, we're handling the extra inputs and
# outputs that checkify adds (just forwarding the error data's primal and
# tangent components). The jaxpr in jvp_jaxpr_thunk doesn't expect those.
# tangent components). The jaxpr in jvp_jaxpr_fun doesn't expect those.
# TODO(mattjj): can we simplify this, or dedup with custom_derivatives.lift_jvp?
# Adding another layer of lu.transformation was tricky, though maybe doable.
def lift_jvp(num_errs, num_consts, jvp_jaxpr_thunk):
@lu.wrap_init
def lift_jvp(num_errs: int, num_consts: int,
jvp_jaxpr_fun: lu.WrappedFun) -> lu.WrappedFun:
def jvp(*xs):
n, ragged = divmod(len(xs), 2)
assert not ragged
primals, tangents = xs[num_consts+num_errs:n], xs[n+num_consts+num_errs:]
zeros = [type(t) is SymbolicZero for t in tangents]
jvp_jaxpr, jvp_consts, out_zeros = jvp_jaxpr_thunk(*zeros)
jvp_jaxpr, jvp_consts, out_zeros = jvp_jaxpr_fun.call_wrapped(*zeros)
nonzero_tangents = [t for t in tangents if type(t) is not SymbolicZero]
out = core.eval_jaxpr(jvp_jaxpr, jvp_consts, *primals, *nonzero_tangents)
out_primals, nz_out_tangents = split_list(out, [len(out_zeros)])
Expand All @@ -1063,7 +1069,7 @@ def jvp(*xs):
primal_errs = xs[num_consts:num_consts+num_errs]
tangent_errs = xs[n+num_consts:n+num_consts+num_errs]
return [*primal_errs, *out_primals, *tangent_errs, *out_tangents]
return jvp
return lu.wrap_init(jvp, debug_info=jvp_jaxpr_fun.debug_info)

def custom_vjp_call_jaxpr_rule(in_err, enabled_errors, *in_vals,
fun_jaxpr: core.ClosedJaxpr,
Expand Down
2 changes: 2 additions & 0 deletions jax/_src/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,8 @@ def __init__(self, constvars: Sequence[Var], invars: Sequence[Var],
self._eqns = list(eqns)
self._effects = effects
self._debug_info = debug_info and debug_info.resolve_result_paths()
if debug_info is None:
assert False # DO_NOT_SUBMIT
# TODO(necula): re-enable these safety checks
# assert (not debug_info or len(debug_info.arg_names) == len(invars)), (debug_info, invars)
# assert (not debug_info or len(debug_info.result_paths) == len(outvars)), (debug_info, outvars)
Expand Down
26 changes: 13 additions & 13 deletions jax/_src/custom_derivatives.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,20 +378,19 @@ def get_bind_params(self, params):
new_params = dict(params)
call_jaxpr: core.ClosedJaxpr = new_params.pop('call_jaxpr')
num_consts: int = new_params.pop('num_consts')
jvp_jaxpr_thunk = new_params.pop('jvp_jaxpr_thunk')
jvp_jaxpr_fun = new_params.pop('jvp_jaxpr_fun')
fun = lu.wrap_init(core.jaxpr_as_fun(call_jaxpr),
debug_info=call_jaxpr.jaxpr.debug_info)
jvp = lift_jvp(num_consts, jvp_jaxpr_thunk, call_jaxpr.jaxpr.debug_info)
jvp = lift_jvp(num_consts, jvp_jaxpr_fun)
return [fun, jvp], new_params

def lift_jvp(num_consts: int, jvp_jaxpr_thunk: Callable,
debug_info: core.DebugInfo | None) -> lu.WrappedFun:
def lift_jvp(num_consts: int, jvp_jaxpr_fun: lu.WrappedFun) -> lu.WrappedFun:
def jvp(*xs):
n, ragged = divmod(len(xs), 2)
assert not ragged
primals, tangents = xs[num_consts:n], xs[n+num_consts:]
zeros = [type(t) is SymbolicZero for t in tangents]
jvp_jaxpr, jvp_consts, out_zeros = jvp_jaxpr_thunk(*zeros)
jvp_jaxpr, jvp_consts, out_zeros = jvp_jaxpr_fun.call_wrapped(*zeros)
nonzero_tangents = [t for t in tangents if type(t) is not SymbolicZero]
out = core.eval_jaxpr(jvp_jaxpr, jvp_consts, *primals, *nonzero_tangents)
out_primals, nz_out_tangents = split_list(out, [len(out_zeros)])
Expand All @@ -401,26 +400,26 @@ def jvp(*xs):
for p, z in zip(out_primals, out_zeros)]
assert next(nz_out_tangents_, None) is None
return [*out_primals, *out_tangents]
return lu.wrap_init(jvp, debug_info=debug_info)
return lu.wrap_init(jvp, debug_info=jvp_jaxpr_fun.debug_info)

effects.custom_derivatives_allowed_effects.add_type(lax.InOutFeedEffect)

custom_jvp_call_p = CustomJVPCallPrimitive('custom_jvp_call')

def _custom_jvp_call_typecheck(_, *in_avals, call_jaxpr, jvp_jaxpr_thunk,
def _custom_jvp_call_typecheck(_, *in_avals, call_jaxpr, jvp_jaxpr_fun,
num_consts, symbolic_zeros):
# TODO(mattjj): could do more checking here...
del in_avals, jvp_jaxpr_thunk, num_consts
del in_avals, jvp_jaxpr_fun, num_consts
disallowed_effects = effects.custom_derivatives_allowed_effects.filter_not_in(call_jaxpr.effects)
if disallowed_effects:
raise NotImplementedError(
f'Effects not supported in `custom_jvp`: {disallowed_effects}')
return call_jaxpr.out_avals, call_jaxpr.effects
core.custom_typechecks[custom_jvp_call_p] = _custom_jvp_call_typecheck

def _custom_jvp_call_mlir_translation(ctx, *args, call_jaxpr, jvp_jaxpr_thunk,
def _custom_jvp_call_mlir_translation(ctx, *args, call_jaxpr, jvp_jaxpr_fun,
num_consts, symbolic_zeros):
del jvp_jaxpr_thunk, num_consts, symbolic_zeros
del jvp_jaxpr_fun, num_consts, symbolic_zeros
consts = mlir._ir_consts(call_jaxpr.consts)
out, tokens = mlir.jaxpr_subcomp(ctx.module_context, call_jaxpr.jaxpr,
ctx.name_stack, ctx.tokens_in, consts,
Expand Down Expand Up @@ -452,7 +451,7 @@ def _custom_jvp_call_dce(
return [False] * len(eqn.invars), None

call_jaxpr = eqn.params["call_jaxpr"]
jvp_jaxpr_thunk = eqn.params["jvp_jaxpr_thunk"]
jvp_jaxpr_fun = eqn.params["jvp_jaxpr_fun"]
# We must set instantiate=True because some inputs that are unused by the
# DCE'ed primal might be used in the JVP rule.
dce_call_jaxpr, used_ins = _cached_closed_call_dce_instantiate(
Expand All @@ -461,7 +460,7 @@ def _custom_jvp_call_dce(

@pe._memoize
def dce_jvp_jaxpr_thunk(*in_zeros):
jvp_jaxpr, consts, out_zeros = jvp_jaxpr_thunk(*in_zeros)
jvp_jaxpr, consts, out_zeros = jvp_jaxpr_fun.call_wrapped(*in_zeros)
dce_jvp_jaxpr, _ = pe.dce_jaxpr(jvp_jaxpr, [*used_outs, *used_outs], True)
dce_out_zeros = [v for used, v in zip(used_outs, out_zeros) if used]
return dce_jvp_jaxpr, consts, dce_out_zeros
Expand All @@ -470,7 +469,8 @@ def dce_jvp_jaxpr_thunk(*in_zeros):
new_params = dict(
eqn.params,
call_jaxpr=dce_call_jaxpr,
jvp_jaxpr_thunk=dce_jvp_jaxpr_thunk,
jvp_jaxpr_fun=lu.wrap_init(dce_jvp_jaxpr_thunk,
debug_info=jvp_jaxpr_fun.debug_info)
)
new_eqn = pe.new_jaxpr_eqn(
eqn.invars, outvars, eqn.primitive, new_params, dce_call_jaxpr.effects,
Expand Down
4 changes: 3 additions & 1 deletion jax/_src/interpreters/ad.py
Original file line number Diff line number Diff line change
Expand Up @@ -725,7 +725,9 @@ def f_tangent(*args):

nz_tangents_in = [t for (t, nz) in zip(tangents, nzs_in) if nz]
nz_tangents_out = call_primitive.bind_with_trace(
self.tangent_trace, (lu.wrap_init(f_tangent), *residuals, *nz_tangents_in), new_params)
self.tangent_trace, (lu.wrap_init(f_tangent,
debug_info=lin_jaxpr.debug_info),
*residuals, *nz_tangents_in), new_params)
nz_tangents_out_iter = iter(nz_tangents_out)
tangents_out = [next(nz_tangents_out_iter) if nz else Zero.from_primal_value(primal)
for nz, primal in zip(nzs_out, primals_out)]
Expand Down
9 changes: 6 additions & 3 deletions jax/_src/interpreters/partial_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -502,7 +502,7 @@ def _closed_call_param_updater(params, _, __):
return dict(params, call_jaxpr=core.ClosedJaxpr(jaxpr, ()))
call_param_updaters[core.closed_call_p] = _closed_call_param_updater

def abstract_eval_fun(fun, *avals, debug_info=None, **params):
def abstract_eval_fun(fun: Callable, *avals, debug_info=None, **params):
_, avals_out, _, () = trace_to_jaxpr_dynamic(
lu.wrap_init(fun, params, debug_info=debug_info), avals)
assert all(isinstance(aval, AbstractValue) for aval in avals_out)
Expand Down Expand Up @@ -1992,7 +1992,9 @@ def process_map(self, map_primitive, f: lu.WrappedFun,
self.frame.add_eqn(eqn)
return out_tracers

def process_custom_jvp_call(self, prim, fun, jvp, tracers, symbolic_zeros):
def process_custom_jvp_call(self, prim, fun: lu.WrappedFun,
jvp: lu.WrappedFun, tracers,
symbolic_zeros: bool):
tracers = map(self.to_jaxpr_tracer, tracers)
in_avals = [t.aval for t in tracers]
in_tangent_avals = [t.to_tangent_aval() for t in in_avals]
Expand All @@ -2014,7 +2016,8 @@ def jvp_jaxpr_thunk(*in_zeros):
outvars = map(self.makevar, out_tracers)
eqn = new_jaxpr_eqn([*constvars, *invars], outvars, prim,
dict(call_jaxpr=closed_fun_jaxpr,
jvp_jaxpr_thunk=jvp_jaxpr_thunk,
jvp_jaxpr_fun=lu.wrap_init(jvp_jaxpr_thunk,
debug_info=jvp.debug_info),
num_consts=len(consts),
symbolic_zeros=symbolic_zeros),
fun_jaxpr.effects,
Expand Down
15 changes: 8 additions & 7 deletions jax/_src/lax/control_flow/loops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1255,7 +1255,9 @@ def arrange_jaxpr_args_for_wrapped(args):
)
# TODO(cperivol): avoid tracing the jaxpr twice. When doing so don't
# forget to manage the effects.
new_jaxpr, _, (), () = pe.trace_to_jaxpr_dynamic(lu.wrap_init(wrapped), avals_for_wrapped_no_refs)
new_jaxpr, _, (), () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(wrapped, debug_info=discharged_jaxpr.debug_info),
avals_for_wrapped_no_refs)
all_out = scan_p.bind(*args_for_wrapped,
jaxpr=core.ClosedJaxpr(new_jaxpr, ()),
length=length,
Expand Down Expand Up @@ -1922,9 +1924,9 @@ def new_body(*consts_refs_carry):
carry, refs_out = split_list(carry_refs, [num_carry])
return [*refs_out, *carry]
new_body_jaxpr, _, new_body_consts, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(new_body), [*remaining_body_const_avals, *[a.inner_aval for a
in ref_avals],
*carry_avals])
lu.wrap_init(new_body, debug_info=discharged_body_jaxpr.debug_info),
[*remaining_body_const_avals, *[a.inner_aval for a in ref_avals],
*carry_avals])
if new_body_consts: raise NotImplementedError

# Since some `Ref`s that were previously consts are now carries, we need to
Expand All @@ -1936,9 +1938,8 @@ def new_cond(*consts_refs_carry):
del refs # We don't use them here!
return core.eval_jaxpr(cond_jaxpr, cond_jaxpr_consts, *consts, *carry)
new_cond_jaxpr, _, new_cond_consts, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(new_cond), [*cond_consts_avals,
*[a.inner_aval for a in ref_avals],
*carry_avals])
lu.wrap_init(new_cond, debug_info=cond_jaxpr.debug_info),
[*cond_consts_avals, *[a.inner_aval for a in ref_avals], *carry_avals])
if new_cond_consts: raise NotImplementedError

out = while_p.bind(*cond_consts, *remaining_body_consts, *refs, *carry,
Expand Down
3 changes: 2 additions & 1 deletion jax/_src/lax/control_flow/solves.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,8 @@ def _root_jvp(const_lengths, jaxprs, primals, tangents):
linearize_and_solve = partial(
core.jaxpr_as_fun(jaxprs.l_and_s), *params.l_and_s)
f_at_solution = lambda *params: f(*params, *solution)
_, rhs = ad.jvp(lu.wrap_init(f_at_solution)).call_wrapped(
_, rhs = ad.jvp(lu.wrap_init(f_at_solution,
debug_info=jaxprs.f.jaxpr.debug_info)).call_wrapped(
params.f, params_dot.f)
solution_dot = _map(
operator.neg, linearize_and_solve(*solution, *rhs))
Expand Down
13 changes: 11 additions & 2 deletions jax/_src/linear_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,13 +65,14 @@ def trans1(static_arg, *dynamic_args, **kwargs):

from collections.abc import Callable, Sequence
from functools import partial
import re
from typing import Any, NamedTuple
import weakref

from jax._src import config
from jax._src import core
from jax._src import traceback_util
from jax._src.tree_util import keystr, generate_key_paths
from jax._src.tree_util import keystr, KeyPath, generate_key_paths
from jax._src.util import curry, cache_clearing_funs, HashableFunction


Expand Down Expand Up @@ -165,6 +166,8 @@ def __init__(self, f: Callable,
self.params = params
self.in_type = in_type
self.debug_info = debug_info
if debug_info is None:
assert False # DO_NOT_SUBMIT

@property
def __name__(self):
Expand Down Expand Up @@ -329,10 +332,16 @@ def wrap_init(f: Callable, params=None, *,
return fun


# We replace <flat index 0> by 0
_re_clean_keystr_arg_names = re.compile(r"<flat index ([^>]+)>")
def _clean_keystr_arg_names(k: KeyPath) -> str:
res = keystr(k)
return _re_clean_keystr_arg_names.sub(r"\1", res)

@transformation_with_aux2
def _get_result_paths_thunk(_fun: Callable, _store: Store, *args, **kwargs):
ans = _fun(*args, **kwargs)
result_paths = [keystr(path) for path, _ in generate_key_paths(ans)]
result_paths = [_clean_keystr_arg_names(path) for path, _ in generate_key_paths(ans)]
if _store:
# In some instances a lu.WrappedFun is called multiple times, e.g.,
# the bwd function in a custom_vjp
Expand Down
4 changes: 2 additions & 2 deletions jax/_src/pallas/mosaic/lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -3026,11 +3026,11 @@ def _custom_jvp_call_lowering_rule(
ctx: LoweringRuleContext,
*args,
call_jaxpr: jax_core.Jaxpr,
jvp_jaxpr_thunk: Callable,
jvp_jaxpr_fun: lu.WrappedFun,
num_consts: int,
symbolic_zeros: bool,
):
del jvp_jaxpr_thunk
del jvp_jaxpr_fun
if symbolic_zeros: raise NotImplementedError
if num_consts: raise NotImplementedError
if call_jaxpr.consts: raise NotImplementedError
Expand Down
4 changes: 3 additions & 1 deletion jax/_src/pallas/pallas_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,9 @@ def _block_map_function(new_idx, *args):

with grid_mapping.trace_env():
block_mapping_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(_block_map_function), idx_avals)
lu.wrap_init(_block_map_function,
debug_info=block_mapping.index_map_jaxpr.jaxpr.debug_info),
idx_avals)
shape = block_mapping.block_shape
if dim is batching.not_mapped:
new_block_shape = shape
Expand Down
2 changes: 1 addition & 1 deletion jax/_src/state/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def _hoist(*consts_args):
return core.eval_jaxpr(jaxpr, all_consts, *args0, *args1)

hoisted_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(_hoist), in_avals)
lu.wrap_init(_hoist, debug_info=jaxpr.debug_info), in_avals)
assert not consts, "All consts should have been converted to refs"
return hoisted_jaxpr

Expand Down
Loading

0 comments on commit a8d60e3

Please sign in to comment.