Commit 414449e
Merge pull request #26078 from gnecula:debug_info_jaxpr

PiperOrigin-RevId: 723151082
Google-ML-Automation committed Feb 4, 2025
2 parents: b1b88a3 + d12aead

Showing 22 changed files with 272 additions and 282 deletions.
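The recurring change in this PR: trace-time debug info now travels on the `lu.WrappedFun` itself, passed as `debug_info=` to `lu.wrap_init`, and `pe.trace_to_jaxpr_dynamic` no longer takes a separate `debug` argument. A minimal before/after sketch of that pattern, assuming the internal `jax._src` modules as they appear in this revision; the function `my_fun`, its arguments, and the "example" label are illustrative only.

```python
import jax.numpy as jnp
from jax._src import api_util, core
from jax._src import linear_util as lu
from jax._src.interpreters import partial_eval as pe
from jax.tree_util import tree_flatten

def my_fun(x, y):
  return jnp.sin(x) * y

args = (jnp.ones(3), jnp.ones(3))
debug = api_util.debug_info("example", my_fun, args, {})
args_flat, in_tree = tree_flatten(args)
in_avals = [core.get_aval(x) for x in args_flat]

# Before this commit: debug info was threaded separately into the tracer.
#   flat_fun, out_tree = api_util.flatten_fun_nokwargs(lu.wrap_init(my_fun), in_tree)
#   jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug)

# After this commit: debug info is attached to the WrappedFun at wrap_init
# time, and trace_to_jaxpr_dynamic picks it up from there.
flat_fun, out_tree = api_util.flatten_fun_nokwargs(
    lu.wrap_init(my_fun, debug_info=debug), in_tree)
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals)
print(jaxpr.debug_info)  # arg names, source info, and result paths once resolved
```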
7 changes: 4 additions & 3 deletions jax/_src/ad_checkpoint.py
@@ -420,9 +420,9 @@ def _trace_to_jaxpr(fun: Callable,
in_avals: Sequence[core.AbstractValue],
debug: core.DebugInfo
) -> tuple[core.Jaxpr, Sequence[Any], PyTreeDef]:
flat_fun, out_tree = api_util.flatten_fun(lu.wrap_init(fun), in_tree)
flat_fun, out_tree = api_util.flatten_fun(lu.wrap_init(fun, debug_info=debug), in_tree)
try:
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug)
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals)
except core.ConcretizationTypeError as e:
msg, = e.args
if 'for checkpoint' in msg:
@@ -699,7 +699,8 @@ def transposed(*args_flat):
assert next(ins_iter, None) is None
with source_info_util.extend_name_stack('rematted_computation'):
lin_jaxpr, _, consts = pe.trace_to_jaxpr_nounits(
lu.wrap_init(core.jaxpr_as_fun(jaxpr)), in_pvals, False)
lu.wrap_init(core.jaxpr_as_fun(jaxpr), debug_info=jaxpr.jaxpr.debug_info),
in_pvals, False)

# Transpose the linear jaxpr (which only has linear inputs).
out_cts_iter = iter(out_cts_flat)
8 changes: 3 additions & 5 deletions jax/_src/api.py
@@ -61,7 +61,7 @@
flatten_axes, donation_vector,
rebase_donate_argnums, _ensure_index, _ensure_index_tuple,
apply_flat_fun_nokwargs, check_callable, debug_info,
result_paths, flat_out_axes)
flat_out_axes)
from jax._src.lax import lax as lax_internal
from jax._src.lib import jax_jit
from jax._src.lib import xla_client as xc
@@ -1430,7 +1430,8 @@ def _prepare_pmap(fun: Callable, in_axes, out_axes, static_broadcasted_tuple,
"pmap", fun, args, kwargs,
static_argnums=static_broadcasted_tuple)

f = lu.wrap_init(fun)
f = lu.wrap_init(fun, debug_info=dbg)
del dbg
if static_broadcasted_tuple:
if max(static_broadcasted_tuple) >= len(args):
raise ValueError(
@@ -1477,9 +1478,6 @@ def _prepare_pmap(fun: Callable, in_axes, out_axes, static_broadcasted_tuple,
raise ValueError(msg) from None
local_axis_size = _mapped_axis_size(fun, in_tree, args, in_axes_flat, "pmap")

f, res_paths = result_paths(f)
dbg = dbg.add_result_paths(res_paths)
f = lu.add_debug_info(f, dbg)
f, out_axes_thunk = flat_out_axes(f, out_axes)
flat_fun, out_tree = flatten_fun(f, in_tree)

31 changes: 4 additions & 27 deletions jax/_src/api_util.py
@@ -590,7 +590,7 @@ def debug_info(
static_argnums: tuple[int, ...] = (),
static_argnames: tuple[str, ...] = (),
result_paths_thunk: Callable[[], tuple[str, ...]] | None = None,
# TODO(necula): check if we really need this, e.g., to speed up tracing.
# TODO(necula): check if we really need this, e.g., to speed up tracing?
sourceinfo: str | None = None,
signature: inspect.Signature | None = None,
) -> core.DebugInfo:
@@ -674,29 +674,6 @@ def _non_static_arg_names(fn_signature: inspect.Signature | None,
arg_names = args_arg_names + kwargs_arg_names
return arg_names

@lu.transformation_with_aux2
def result_paths(_fun, _store, *args, **kwargs):
"linear_util transform to get output pytree paths of pre-flattened function."
ans = _fun(*args, **kwargs)
_store.store([keystr(path) for path, _ in generate_key_paths(ans)])
return ans

# TODO(necula): simplify this function, all it needs is to add the trace_debug to the Jaxpr
def add_jaxpr_debug_info(jaxpr: core.Jaxpr,
debug: core.DebugInfo | None,
result_paths: tuple[str, ...] | None = None,
) -> core.Jaxpr:
"""Add debug info to jaxpr, given trace-time debug info and result paths."""
if debug is None:
return jaxpr
# TODO(necula): re-enable this safety check
# assert (result_paths is not None) ^ (trace_debug.result_paths_thunk is not None)
if result_paths is not None:
debug = debug._replace(result_paths=tuple(result_paths))
else:
debug = debug.resolve_result_paths()
return jaxpr.replace(debug_info=debug)

def hoist_obj_attrs(f, flat_args):
idxs, objs, flat_args_ = [], [], []
for i, x in enumerate(flat_args):
@@ -721,7 +698,7 @@ def register_class_with_attrs(t: type) -> None:
_class_with_attrs: set[type] = set()

# TODO(mattjj): make this function faster
def _check_no_aliased_ref_args(dbg, avals, args):
def _check_no_aliased_ref_args(dbg: core.DebugInfo | None, avals, args):
assert config.mutable_array_checks.value
refs: dict[int, int] = {}
for i, (a, x) in enumerate(zip(avals, args)):
@@ -735,7 +712,7 @@ def _check_no_aliased_ref_args(dbg, avals, args):
if dbg else
f"at both flat index {dup_idx} and flat index {i}") from None

def _check_no_aliased_closed_over_refs(dbg, consts, args) -> None:
def _check_no_aliased_closed_over_refs(dbg: core.DebugInfo | None, consts, args) -> None:
assert config.mutable_array_checks.value
refs: set[int] = {id(core.get_referent(c)) for c in consts
if isinstance(core.get_aval(c), AbstractRef)}
@@ -746,4 +723,4 @@ def _check_no_aliased_closed_over_refs(dbg, consts, args) -> None:
f"when tracing {dbg.func_src_info} for {dbg.traced_for}, a mutable "
f"array reference of type {a.str_short()} was both closed over and "
f"passed as the argument "
f"{dbg.arg_names[i]}" if dbg else "at flat index {i}")
f"{dbg.safe_arg_names(len(args))[i]}" if dbg else "at flat index {i}")
2 changes: 1 addition & 1 deletion jax/_src/checkify.py
@@ -1208,7 +1208,7 @@ def checked_fun(*args, **kwargs):
fun_, out_tree = api_util.flatten_fun(lu.wrap_init(closed_f,
debug_info=debug),
in_tree)
jaxpr_, _, consts, () = pe.trace_to_jaxpr_dynamic(fun_, (), debug)
jaxpr_, _, consts, () = pe.trace_to_jaxpr_dynamic(fun_, ())
jaxpr = pe.close_jaxpr(pe.convert_constvars_jaxpr(jaxpr_))
# checkify:
error, out_flat = checkify_jaxpr(jaxpr, errors, init_error, *consts)
10 changes: 6 additions & 4 deletions jax/_src/core.py
@@ -2398,7 +2398,8 @@ def bind_with_trace(self, trace, fun_and_args, params):
def get_bind_params(self, params):
new_params = dict(params)
jaxpr = new_params.pop('call_jaxpr')
subfun = lu.hashable_partial(lu.wrap_init(eval_jaxpr), jaxpr, ())
subfun = lu.hashable_partial(lu.wrap_init(eval_jaxpr, debug_info=jaxpr.debug_info),
jaxpr, ())
if config.dynamic_shapes.value:
subfun = lu.annotate(subfun, _jaxpr_type_to_callable_annotation(jaxpr))
return [subfun], new_params
@@ -2434,7 +2435,7 @@ def bind(self, *args, **params):
return self._true_bind(*args, **params)

def bind_with_trace(self, trace, fun_and_args, params):
fun = fun_and_args[0]
fun: lu.WrappedFun = fun_and_args[0]
args = fun_and_args[1:]
assert len(params['in_axes']) == len(args)
return trace.process_map(self, fun, args, params)
@@ -2444,8 +2445,9 @@ def process(self, trace, fun, tracers, params):

def get_bind_params(self, params):
new_params = dict(params)
jaxpr = new_params.pop('call_jaxpr')
subfun = lu.hashable_partial(lu.wrap_init(eval_jaxpr), jaxpr, ())
jaxpr: Jaxpr = new_params.pop('call_jaxpr')
subfun = lu.hashable_partial(lu.wrap_init(eval_jaxpr,
debug_info=jaxpr.debug_info), jaxpr, ())
axes = new_params.pop('out_axes')
new_params['out_axes_thunk'] = HashableFunction(lambda: axes, closure=axes)
return [subfun], new_params
2 changes: 1 addition & 1 deletion jax/_src/custom_batching.py
@@ -153,7 +153,7 @@ def __call__(self, *args, **kwargs):
lu.wrap_init(self.fun, debug_info=debug),
in_tree)
in_avals = [core.get_aval(x) for x in args_flat]
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug)
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals)
closed_call = core.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr), ())
in_tree = treedef_tuple((tree_structure(consts), in_tree))
assert self.vmap_rule is not None
11 changes: 6 additions & 5 deletions jax/_src/custom_dce.py
@@ -147,11 +147,11 @@ def __call__(self, *args, **kwargs):
)
static_args = [args[i] for i in self.static_argnums]
dce_rule = api_util.prepend_static_args(
lu.wrap_init(self.dce_rule), static_args
lu.wrap_init(self.dce_rule, debug_info=debug_rule), static_args
)
else:
fun = lu.wrap_init(self.fun, debug_info=debug)
dce_rule = lu.wrap_init(self.dce_rule)
dce_rule = lu.wrap_init(self.dce_rule, debug_info=debug_rule)
dyn_args = args

args_flat, in_tree = tree_util.tree_flatten(dyn_args)
@@ -176,7 +176,7 @@ def dce_jaxpr_thunk(
)
assert self.dce_rule is not None
dce_jaxpr, _, dce_consts, () = pe.trace_to_jaxpr_dynamic(
flat_rule, in_avals, debug_rule
flat_rule, in_avals
)

# This second round of DCE is used to work out which inputs are actually
@@ -191,7 +191,7 @@

return core.ClosedJaxpr(dce_jaxpr, dce_consts), used_ins

jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug)
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals)
closed_call = pe.close_jaxpr(pe.convert_constvars_jaxpr(jaxpr))
out_avals = closed_call.out_avals
out_flat = custom_dce_p.bind(
@@ -366,7 +366,7 @@ def custom_dce_jvp(primals, tangents, *, fun_jaxpr: core.ClosedJaxpr, **_):
# that most users of this API would compose this with a custom_jvp or
# custom_vjp, which makes this less urgent.
out = core.call_p.bind(
lu.wrap_init(core.jaxpr_as_fun(jvp_jaxpr)), *primals, *tangents
lu.wrap_init(core.jaxpr_as_fun(jvp_jaxpr),
debug_info=jvp_jaxpr.jaxpr.debug_info), *primals, *tangents
)

out_primals, out_tangents = util.split_list(out, [len(out_nz)])
4 changes: 2 additions & 2 deletions jax/_src/custom_partitioning.py
@@ -485,13 +485,13 @@ def __call__(self, *args, **kwargs):
_check_for_tracers(static_args)
else:
static_args = []
f_, dyn_args = lu.wrap_init(self.fun), args
f_, dyn_args = lu.wrap_init(self.fun, debug_info=debug), args
args_flat, in_tree = tree_util.tree_flatten(dyn_args)
flat_fun, out_tree = api_util.flatten_fun_nokwargs(f_, in_tree)
in_avals = [core.get_aval(x) for x in args_flat]
mesh = mesh_lib.thread_resources.env.physical_mesh
with core.extend_axis_env_nd(mesh.shape.items()):
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug)
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals)
assert not len(consts)
closed_call = core.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr), ())

15 changes: 9 additions & 6 deletions jax/_src/interpreters/ad.py
@@ -98,7 +98,7 @@ def linearize_subtrace(_f: Callable, _store, _tag, nzs_in, *primals, **params):
nzs_out = tuple(type(t) is not Zero for t in out_tangents)
out_tangents = tuple(t for t, nz in zip(out_tangents, nzs_out) if nz)
out_tangents = map(tangent_trace.to_jaxpr_tracer, out_tangents) # type: ignore[assignment]
jaxpr, consts, attrs_tracked = tangent_trace.to_jaxpr(out_tangents)
jaxpr, consts, attrs_tracked = tangent_trace.to_jaxpr(out_tangents, None)
residual_avals = map(get_aval, consts)
if attrs_tracked:
raise NotImplementedError("TODO: attrs")
@@ -167,16 +167,17 @@ def new_arg(trace, primal_aval, nz):
out_primals, out_tangents = unzip2(map(lin_trace.to_primal_tangent_pair, ans))
del lin_trace, ans, tracers, new_arg

debug_info = jaxpr.jaxpr.debug_info
nzs_out = [type(t) is not Zero for t in out_tangents]
out_tangents = tuple(tangent_trace.to_jaxpr_tracer(t)
for (nz, t) in zip(nzs_out, out_tangents) if nz)
tangent_jaxpr, tangent_consts, attrs_tracked = tangent_trace.to_jaxpr(out_tangents)
tangent_jaxpr, tangent_consts, attrs_tracked = tangent_trace.to_jaxpr(out_tangents, debug_info)
tangent_trace.invalidate()
if attrs_tracked:
raise NotImplementedError("TODO: attrs")
residuals_and_primals = (*tangent_consts, *out_primals)
residuals_and_primals = map(primal_trace.to_jaxpr_tracer, residuals_and_primals) # type: ignore[assignment]
primal_jaxpr, primal_consts, attrs_tracked = primal_trace.to_jaxpr(residuals_and_primals)
primal_jaxpr, primal_consts, attrs_tracked = primal_trace.to_jaxpr(residuals_and_primals, debug_info)
primal_trace.invalidate()
num_residuals = len(tangent_consts)
tangent_jaxpr = pe.close_jaxpr(convert_constvars_jaxpr_constvars_at_end(tangent_jaxpr))
@@ -209,7 +210,7 @@ def direct_linearize(traceable: lu.WrappedFun,
out_nzs = [type(t) is not Zero for t in out_tangents]
out_nz_tangents = [t for t, nz in zip(out_tangents, out_nzs) if nz]
out_nz_tangents = map(tangent_trace.to_jaxpr_tracer, out_nz_tangents) # type: ignore
jaxpr, consts, attrs_tracked = tangent_trace.to_jaxpr(out_nz_tangents)
jaxpr, consts, attrs_tracked = tangent_trace.to_jaxpr(out_nz_tangents, traceable.debug_info)
tangent_trace.invalidate()
out_tangents_pvals = [pe.PartialVal.unknown(core.get_aval(t)) if nz else
pe.PartialVal.known(zeros_like_aval(t.aval))
@@ -1026,12 +1027,14 @@ def jvp_jaxpr(jaxpr: core.ClosedJaxpr, nonzeros: Sequence[bool],
def _jvp_jaxpr(jaxpr: core.ClosedJaxpr,
nonzeros: Sequence[bool], instantiate: Sequence[bool]):
assert len(jaxpr.in_avals) == len(nonzeros)
f = lu.wrap_init(core.jaxpr_as_fun(jaxpr))
debug_info = jaxpr.jaxpr.debug_info
f = lu.wrap_init(core.jaxpr_as_fun(jaxpr), debug_info=debug_info)
f_jvp, out_nonzeros = f_jvp_traceable(jvp(f, instantiate=instantiate, transform_stack=False),
nonzeros)
tangent_avals = [aval.to_tangent_aval() for aval, nz in zip(jaxpr.in_avals, nonzeros) if nz]
avals_in = list(it.chain(jaxpr.in_avals, tangent_avals))
jaxpr_out, avals_out, literals_out, () = pe.trace_to_jaxpr_dynamic(f_jvp, avals_in)
jaxpr_out, avals_out, literals_out, () = pe.trace_to_jaxpr_dynamic(
f_jvp, avals_in)
return core.ClosedJaxpr(jaxpr_out, literals_out), out_nonzeros()

@lu.transformation_with_aux2
3 changes: 2 additions & 1 deletion jax/_src/interpreters/batching.py
@@ -758,7 +758,8 @@ def _batch_jaxpr2(
axis_data,
in_axes: tuple[int | NotMapped | RaggedAxis, ...],
) -> tuple[core.ClosedJaxpr, tuple[int | NotMapped, ...]]:
f = lu.wrap_init(core.jaxpr_as_fun(closed_jaxpr))
f = lu.wrap_init(core.jaxpr_as_fun(closed_jaxpr),
debug_info=closed_jaxpr.jaxpr.debug_info)
f, out_axes = _batch_jaxpr_inner(f, axis_data)
f = _batch_jaxpr_outer(f, axis_data, in_axes)
in_axes2, avals_in = unzip2([
1 change: 0 additions & 1 deletion jax/_src/interpreters/mlir.py
@@ -1406,7 +1406,6 @@ def lower_jaxpr_to_fun(
MLIR func op
"""
util.test_event("lower_jaxpr_to_fun", name)

# The first dimension variable may be the platform index
num_dim_vars = len(ctx.shape_poly_state.dim_vars)
dim_var_avals = [core.ShapedArray((), dtypes.canonicalize_dtype(np.int64))] * num_dim_vars
(Diffs for the remaining changed files were not loaded and are not shown here.)
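With debug info carried on the jaxpr itself, its presence is easy to check from user code. A quick sanity check, assuming the public `jax.make_jaxpr` API; the function and inputs below are illustrative only:

```python
import jax
import jax.numpy as jnp

def f(x, y):
  return {"out": jnp.sin(x) * y}

closed = jax.make_jaxpr(f)(1.0, 2.0)
# The ClosedJaxpr's inner Jaxpr carries a core.DebugInfo with the traced
# function's name, argument names, and (once resolved) result paths.
print(closed.jaxpr.debug_info)
```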
