[better_errors] Add debug info to more Jaxprs and Wrappedfun (step 1) #26078

Merged: 1 commit, Feb 4, 2025
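At a high level, this change moves debug-info handling to the point where a function is first wrapped: lu.wrap_init now receives the core.DebugInfo built by api_util.debug_info, and tracing entry points such as pe.trace_to_jaxpr_dynamic drop their separate debug argument because the WrappedFun already carries it. A minimal before/after sketch of the convention (a hedged illustration only, mirroring the ad_checkpoint.py hunk below; fun, in_tree, in_avals, and debug are placeholders):

# Before: debug info was threaded into the tracer as an extra argument.
# flat_fun, out_tree = api_util.flatten_fun(lu.wrap_init(fun), in_tree)
# jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug)

# After: the WrappedFun itself carries the debug info.
flat_fun, out_tree = api_util.flatten_fun(
    lu.wrap_init(fun, debug_info=debug), in_tree)
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals)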
7 changes: 4 additions & 3 deletions jax/_src/ad_checkpoint.py
@@ -420,9 +420,9 @@ def _trace_to_jaxpr(fun: Callable,
in_avals: Sequence[core.AbstractValue],
debug: core.DebugInfo
) -> tuple[core.Jaxpr, Sequence[Any], PyTreeDef]:
- flat_fun, out_tree = api_util.flatten_fun(lu.wrap_init(fun), in_tree)
+ flat_fun, out_tree = api_util.flatten_fun(lu.wrap_init(fun, debug_info=debug), in_tree)
try:
- jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug)
+ jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals)
except core.ConcretizationTypeError as e:
msg, = e.args
if 'for checkpoint' in msg:
@@ -699,7 +699,8 @@ def transposed(*args_flat):
assert next(ins_iter, None) is None
with source_info_util.extend_name_stack('rematted_computation'):
lin_jaxpr, _, consts = pe.trace_to_jaxpr_nounits(
- lu.wrap_init(core.jaxpr_as_fun(jaxpr)), in_pvals, False)
+ lu.wrap_init(core.jaxpr_as_fun(jaxpr), debug_info=jaxpr.jaxpr.debug_info),
+ in_pvals, False)

# Transpose the linear jaxpr (which only has linear inputs).
out_cts_iter = iter(out_cts_flat)
8 changes: 3 additions & 5 deletions jax/_src/api.py
@@ -61,7 +61,7 @@
flatten_axes, donation_vector,
rebase_donate_argnums, _ensure_index, _ensure_index_tuple,
apply_flat_fun_nokwargs, check_callable, debug_info,
- result_paths, flat_out_axes)
+ flat_out_axes)
from jax._src.lax import lax as lax_internal
from jax._src.lib import jax_jit
from jax._src.lib import xla_client as xc
@@ -1430,7 +1430,8 @@ def _prepare_pmap(fun: Callable, in_axes, out_axes, static_broadcasted_tuple,
"pmap", fun, args, kwargs,
static_argnums=static_broadcasted_tuple)

- f = lu.wrap_init(fun)
+ f = lu.wrap_init(fun, debug_info=dbg)
+ del dbg
if static_broadcasted_tuple:
if max(static_broadcasted_tuple) >= len(args):
raise ValueError(
@@ -1477,9 +1478,6 @@ def _prepare_pmap(fun: Callable, in_axes, out_axes, static_broadcasted_tuple,
raise ValueError(msg) from None
local_axis_size = _mapped_axis_size(fun, in_tree, args, in_axes_flat, "pmap")

- f, res_paths = result_paths(f)
- dbg = dbg.add_result_paths(res_paths)
- f = lu.add_debug_info(f, dbg)
f, out_axes_thunk = flat_out_axes(f, out_axes)
flat_fun, out_tree = flatten_fun(f, in_tree)

31 changes: 4 additions & 27 deletions jax/_src/api_util.py
@@ -590,7 +590,7 @@ def debug_info(
static_argnums: tuple[int, ...] = (),
static_argnames: tuple[str, ...] = (),
result_paths_thunk: Callable[[], tuple[str, ...]] | None = None,
- # TODO(necula): check if we really need this, e.g., to speed up tracing.
+ # TODO(necula): check if we really need this, e.g., to speed up tracing?
sourceinfo: str | None = None,
signature: inspect.Signature | None = None,
) -> core.DebugInfo:
@@ -674,29 +674,6 @@ def _non_static_arg_names(fn_signature: inspect.Signature | None,
arg_names = args_arg_names + kwargs_arg_names
return arg_names

- @lu.transformation_with_aux2
- def result_paths(_fun, _store, *args, **kwargs):
- "linear_util transform to get output pytree paths of pre-flattened function."
- ans = _fun(*args, **kwargs)
- _store.store([keystr(path) for path, _ in generate_key_paths(ans)])
- return ans
-
- # TODO(necula): simplify this function, all it needs is to add the trace_debug to the Jaxpr
- def add_jaxpr_debug_info(jaxpr: core.Jaxpr,
- debug: core.DebugInfo | None,
- result_paths: tuple[str, ...] | None = None,
- ) -> core.Jaxpr:
- """Add debug info to jaxpr, given trace-time debug info and result paths."""
- if debug is None:
- return jaxpr
- # TODO(necula): re-enable this safety check
- # assert (result_paths is not None) ^ (trace_debug.result_paths_thunk is not None)
- if result_paths is not None:
- debug = debug._replace(result_paths=tuple(result_paths))
- else:
- debug = debug.resolve_result_paths()
- return jaxpr.replace(debug_info=debug)
-
def hoist_obj_attrs(f, flat_args):
idxs, objs, flat_args_ = [], [], []
for i, x in enumerate(flat_args):
@@ -721,7 +698,7 @@ def register_class_with_attrs(t: type) -> None:
_class_with_attrs: set[type] = set()

# TODO(mattjj): make this function faster
- def _check_no_aliased_ref_args(dbg, avals, args):
+ def _check_no_aliased_ref_args(dbg: core.DebugInfo | None, avals, args):
assert config.mutable_array_checks.value
refs: dict[int, int] = {}
for i, (a, x) in enumerate(zip(avals, args)):
@@ -735,7 +712,7 @@ def _check_no_aliased_ref_args(dbg, avals, args):
if dbg else
f"at both flat index {dup_idx} and flat index {i}") from None

- def _check_no_aliased_closed_over_refs(dbg, consts, args) -> None:
+ def _check_no_aliased_closed_over_refs(dbg: core.DebugInfo | None, consts, args) -> None:
assert config.mutable_array_checks.value
refs: set[int] = {id(core.get_referent(c)) for c in consts
if isinstance(core.get_aval(c), AbstractRef)}
@@ -746,4 +723,4 @@ def _check_no_aliased_closed_over_refs(dbg, consts, args) -> None:
f"when tracing {dbg.func_src_info} for {dbg.traced_for}, a mutable "
f"array reference of type {a.str_short()} was both closed over and "
f"passed as the argument "
f"{dbg.arg_names[i]}" if dbg else "at flat index {i}")
f"{dbg.safe_arg_names(len(args))[i]}" if dbg else "at flat index {i}")
2 changes: 1 addition & 1 deletion jax/_src/checkify.py
@@ -1206,7 +1206,7 @@ def checked_fun(*args, **kwargs):
fun_, out_tree = api_util.flatten_fun(lu.wrap_init(closed_f,
debug_info=debug),
in_tree)
- jaxpr_, _, consts, () = pe.trace_to_jaxpr_dynamic(fun_, (), debug)
+ jaxpr_, _, consts, () = pe.trace_to_jaxpr_dynamic(fun_, ())
jaxpr = pe.close_jaxpr(pe.convert_constvars_jaxpr(jaxpr_))
# checkify:
error, out_flat = checkify_jaxpr(jaxpr, errors, init_error, *consts)
10 changes: 6 additions & 4 deletions jax/_src/core.py
@@ -2369,7 +2369,8 @@ def bind_with_trace(self, trace, fun_and_args, params):
def get_bind_params(self, params):
new_params = dict(params)
jaxpr = new_params.pop('call_jaxpr')
- subfun = lu.hashable_partial(lu.wrap_init(eval_jaxpr), jaxpr, ())
+ subfun = lu.hashable_partial(lu.wrap_init(eval_jaxpr, debug_info=jaxpr.debug_info),
+ jaxpr, ())
if config.dynamic_shapes.value:
subfun = lu.annotate(subfun, _jaxpr_type_to_callable_annotation(jaxpr))
return [subfun], new_params
@@ -2402,7 +2403,7 @@ class MapPrimitive(Primitive):
map_primitive = True

def bind_with_trace(self, trace, fun_and_args, params):
- fun = fun_and_args[0]
+ fun: lu.WrappedFun = fun_and_args[0]
args = fun_and_args[1:]
assert len(params['in_axes']) == len(args)
return trace.process_map(self, fun, args, params)
@@ -2412,8 +2413,9 @@ def process(self, trace, fun, tracers, params):

def get_bind_params(self, params):
new_params = dict(params)
- jaxpr = new_params.pop('call_jaxpr')
- subfun = lu.hashable_partial(lu.wrap_init(eval_jaxpr), jaxpr, ())
+ jaxpr: Jaxpr = new_params.pop('call_jaxpr')
+ subfun = lu.hashable_partial(lu.wrap_init(eval_jaxpr,
+ debug_info=jaxpr.debug_info), jaxpr, ())
axes = new_params.pop('out_axes')
new_params['out_axes_thunk'] = HashableFunction(lambda: axes, closure=axes)
return [subfun], new_params
2 changes: 1 addition & 1 deletion jax/_src/custom_batching.py
@@ -153,7 +153,7 @@ def __call__(self, *args, **kwargs):
lu.wrap_init(self.fun, debug_info=debug),
in_tree)
in_avals = [core.get_aval(x) for x in args_flat]
- jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug)
+ jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals)
closed_call = core.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr), ())
in_tree = treedef_tuple((tree_structure(consts), in_tree))
assert self.vmap_rule is not None
11 changes: 6 additions & 5 deletions jax/_src/custom_dce.py
@@ -147,11 +147,11 @@ def __call__(self, *args, **kwargs):
)
static_args = [args[i] for i in self.static_argnums]
dce_rule = api_util.prepend_static_args(
- lu.wrap_init(self.dce_rule), static_args
+ lu.wrap_init(self.dce_rule, debug_info=debug_rule), static_args
)
else:
fun = lu.wrap_init(self.fun, debug_info=debug)
- dce_rule = lu.wrap_init(self.dce_rule)
+ dce_rule = lu.wrap_init(self.dce_rule, debug_info=debug_rule)
dyn_args = args

args_flat, in_tree = tree_util.tree_flatten(dyn_args)
@@ -176,7 +176,7 @@ def dce_jaxpr_thunk(
)
assert self.dce_rule is not None
dce_jaxpr, _, dce_consts, () = pe.trace_to_jaxpr_dynamic(
- flat_rule, in_avals, debug_rule
+ flat_rule, in_avals
)

# This second round of DCE is used to work out which inputs are actually
@@ -191,7 +191,7 @@

return core.ClosedJaxpr(dce_jaxpr, dce_consts), used_ins

- jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug)
+ jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals)
closed_call = pe.close_jaxpr(pe.convert_constvars_jaxpr(jaxpr))
out_avals = closed_call.out_avals
out_flat = custom_dce_p.bind(
@@ -366,7 +366,8 @@ def custom_dce_jvp(primals, tangents, *, fun_jaxpr: core.ClosedJaxpr, **_):
# that most users of this API would compose this with a custom_jvp or
# custom_vjp, which makes this less urgent.
out = core.call_p.bind(
- lu.wrap_init(core.jaxpr_as_fun(jvp_jaxpr)), *primals, *tangents
+ lu.wrap_init(core.jaxpr_as_fun(jvp_jaxpr),
+ debug_info=jvp_jaxpr.jaxpr.debug_info), *primals, *tangents
)

out_primals, out_tangents = util.split_list(out, [len(out_nz)])
4 changes: 2 additions & 2 deletions jax/_src/custom_partitioning.py
@@ -485,13 +485,13 @@ def __call__(self, *args, **kwargs):
_check_for_tracers(static_args)
else:
static_args = []
- f_, dyn_args = lu.wrap_init(self.fun), args
+ f_, dyn_args = lu.wrap_init(self.fun, debug_info=debug), args
args_flat, in_tree = tree_util.tree_flatten(dyn_args)
flat_fun, out_tree = api_util.flatten_fun_nokwargs(f_, in_tree)
in_avals = [core.get_aval(x) for x in args_flat]
mesh = mesh_lib.thread_resources.env.physical_mesh
with core.extend_axis_env_nd(mesh.shape.items()):
- jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug)
+ jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals)
assert not len(consts)
closed_call = core.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr), ())

15 changes: 9 additions & 6 deletions jax/_src/interpreters/ad.py
@@ -98,7 +98,7 @@ def linearize_subtrace(_f: Callable, _store, _tag, nzs_in, *primals, **params):
nzs_out = tuple(type(t) is not Zero for t in out_tangents)
out_tangents = tuple(t for t, nz in zip(out_tangents, nzs_out) if nz)
out_tangents = map(tangent_trace.to_jaxpr_tracer, out_tangents) # type: ignore[assignment]
- jaxpr, consts, attrs_tracked = tangent_trace.to_jaxpr(out_tangents)
+ jaxpr, consts, attrs_tracked = tangent_trace.to_jaxpr(out_tangents, None)
residual_avals = map(get_aval, consts)
if attrs_tracked:
raise NotImplementedError("TODO: attrs")
@@ -166,16 +166,17 @@ def new_arg(trace, primal_aval, nz):
out_primals, out_tangents = unzip2(map(lin_trace.to_primal_tangent_pair, ans))
del lin_trace, ans, tracers, new_arg

+ debug_info = jaxpr.jaxpr.debug_info
nzs_out = [type(t) is not Zero for t in out_tangents]
out_tangents = tuple(tangent_trace.to_jaxpr_tracer(t)
for (nz, t) in zip(nzs_out, out_tangents) if nz)
- tangent_jaxpr, tangent_consts, attrs_tracked = tangent_trace.to_jaxpr(out_tangents)
+ tangent_jaxpr, tangent_consts, attrs_tracked = tangent_trace.to_jaxpr(out_tangents, debug_info)
tangent_trace.invalidate()
if attrs_tracked:
raise NotImplementedError("TODO: attrs")
residuals_and_primals = (*tangent_consts, *out_primals)
residuals_and_primals = map(primal_trace.to_jaxpr_tracer, residuals_and_primals) # type: ignore[assignment]
- primal_jaxpr, primal_consts, attrs_tracked = primal_trace.to_jaxpr(residuals_and_primals)
+ primal_jaxpr, primal_consts, attrs_tracked = primal_trace.to_jaxpr(residuals_and_primals, debug_info)
primal_trace.invalidate()
num_residuals = len(tangent_consts)
tangent_jaxpr = pe.close_jaxpr(convert_constvars_jaxpr_constvars_at_end(tangent_jaxpr))
@@ -207,7 +208,7 @@ def direct_linearize(traceable: lu.WrappedFun,
out_nzs = [type(t) is not Zero for t in out_tangents]
out_nz_tangents = [t for t, nz in zip(out_tangents, out_nzs) if nz]
out_nz_tangents = map(tangent_trace.to_jaxpr_tracer, out_nz_tangents) # type: ignore
- jaxpr, consts, attrs_tracked = tangent_trace.to_jaxpr(out_nz_tangents)
+ jaxpr, consts, attrs_tracked = tangent_trace.to_jaxpr(out_nz_tangents, traceable.debug_info)
tangent_trace.invalidate()
out_tangents_pvals = [pe.PartialVal.unknown(core.get_aval(t)) if nz else
pe.PartialVal.known(zeros_like_aval(t.aval))
@@ -1019,12 +1020,14 @@ def jvp_jaxpr(jaxpr: core.ClosedJaxpr, nonzeros: Sequence[bool],
def _jvp_jaxpr(jaxpr: core.ClosedJaxpr,
nonzeros: Sequence[bool], instantiate: Sequence[bool]):
assert len(jaxpr.in_avals) == len(nonzeros)
- f = lu.wrap_init(core.jaxpr_as_fun(jaxpr))
+ debug_info = jaxpr.jaxpr.debug_info
+ f = lu.wrap_init(core.jaxpr_as_fun(jaxpr), debug_info=debug_info)
f_jvp, out_nonzeros = f_jvp_traceable(jvp(f, instantiate=instantiate, transform_stack=False),
nonzeros)
tangent_avals = [aval.to_tangent_aval() for aval, nz in zip(jaxpr.in_avals, nonzeros) if nz]
avals_in = list(it.chain(jaxpr.in_avals, tangent_avals))
- jaxpr_out, avals_out, literals_out, () = pe.trace_to_jaxpr_dynamic(f_jvp, avals_in)
+ jaxpr_out, avals_out, literals_out, () = pe.trace_to_jaxpr_dynamic(
+ f_jvp, avals_in)
return core.ClosedJaxpr(jaxpr_out, literals_out), out_nonzeros()

@lu.transformation_with_aux2
3 changes: 2 additions & 1 deletion jax/_src/interpreters/batching.py
@@ -760,7 +760,8 @@ def _batch_jaxpr2(
axis_data,
in_axes: tuple[int | NotMapped | RaggedAxis, ...],
) -> tuple[core.ClosedJaxpr, tuple[int | NotMapped, ...]]:
- f = lu.wrap_init(core.jaxpr_as_fun(closed_jaxpr))
+ f = lu.wrap_init(core.jaxpr_as_fun(closed_jaxpr),
+ debug_info=closed_jaxpr.jaxpr.debug_info)
f, out_axes = _batch_jaxpr_inner(f, axis_data)
f = _batch_jaxpr_outer(f, axis_data, in_axes)
in_axes2, avals_in = unzip2([
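One idiom recurs across ad_checkpoint.py, custom_dce.py, ad.py, and batching.py in this PR: when a jaxpr is turned back into a traceable callable, the new WrappedFun is given the jaxpr's own debug info rather than none. A minimal sketch of the idiom (hedged; closed_jaxpr stands for any ClosedJaxpr, as in the batching change above):

# Re-wrap a ClosedJaxpr as a callable while keeping its stored debug info,
# so downstream error messages retain the original function and argument names.
f = lu.wrap_init(core.jaxpr_as_fun(closed_jaxpr),
                 debug_info=closed_jaxpr.jaxpr.debug_info)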
1 change: 0 additions & 1 deletion jax/_src/interpreters/mlir.py
@@ -1393,7 +1393,6 @@ def lower_jaxpr_to_fun(
MLIR func op
"""
util.test_event("lower_jaxpr_to_fun", name)
-
# The first dimension variable may be the platform index
num_dim_vars = len(ctx.shape_poly_state.dim_vars)
dim_var_avals = [core.ShapedArray((), dtypes.canonicalize_dtype(np.int64))] * num_dim_vars