Skip to content

Commit

Permalink
[better_errors] Make it explicit that debug_info is not None.
Browse files Browse the repository at this point in the history
Now all internal uses of lu.wrap_init and core.Jaxpr are with actual
debug info. This enables us to clean up the type declarations and
to remove the checks whether debug_info is present.

For usage outside of the JAX internals, we change
`jax.extend.linear_util.wrap_init` to be usable without debug_info,
for temporary backwards compatibility. We emit a deprecation
warning and fill-in some fake debugging info.

See #26480 for more details.

PiperOrigin-RevId: 726770483
  • Loading branch information
gnecula authored and Google-ML-Automation committed Feb 14, 2025
1 parent 60dcded commit a0812cd
Show file tree
Hide file tree
Showing 14 changed files with 113 additions and 108 deletions.
9 changes: 9 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,15 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
This package may safely be removed if it is present on your machine; JAX now
uses `libtpu` instead.

* Deprecations
* The internal function `linear_util.wrap_init` and the constructor
`core.Jaxpr` now must take a non-empty `core.DebugInfo` kwarg. For
a limited time, a `DeprecationWarning` is printed if
`jax.extend.linear_util.wrap_init` is used without debugging info.
A downstream effect of this several other internal functions need debug
info. This change does not affect public APIs.
See https://github.com/jax-ml/jax/issues/26480 for more detail.

## jax 0.5.0 (Jan 17, 2025)

As of this release, JAX now uses
Expand Down
8 changes: 4 additions & 4 deletions jax/_src/api_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -620,15 +620,15 @@ def fun_signature(fun: Callable) -> inspect.Signature | None:
return None

def save_wrapped_fun_sourceinfo(wrapper: Callable,
wrapped: Callable | core.DebugInfo | None) -> None:
wrapped: Callable | core.DebugInfo) -> None:
# Prefer this to functools.wraps because it does not create a reference to
# the wrapped function.
if isinstance(wrapped, core.DebugInfo):
func_src_info = wrapped.func_src_info
elif callable(wrapped):
func_src_info = fun_sourceinfo(wrapped)
else:
return
assert False, wrapped # Unreachable
setattr(wrapper, "__fun_sourceinfo__", func_src_info)

_fun_name_re = re.compile(r"(?:<built-in function (\S+)>)")
Expand Down Expand Up @@ -716,7 +716,7 @@ def register_class_with_attrs(t: type) -> None:
_class_with_attrs: set[type] = set()

# TODO(mattjj): make this function faster
def _check_no_aliased_ref_args(dbg: core.DebugInfo | None, avals, args):
def _check_no_aliased_ref_args(dbg: core.DebugInfo, avals, args):
assert config.mutable_array_checks.value
refs: dict[int, int] = {}
for i, (a, x) in enumerate(zip(avals, args)):
Expand All @@ -730,7 +730,7 @@ def _check_no_aliased_ref_args(dbg: core.DebugInfo | None, avals, args):
if dbg else
f"at both flat index {dup_idx} and flat index {i}") from None

def _check_no_aliased_closed_over_refs(dbg: core.DebugInfo | None, consts, args) -> None:
def _check_no_aliased_closed_over_refs(dbg: core.DebugInfo, consts, args) -> None:
assert config.mutable_array_checks.value
refs: set[int] = {id(core.get_referent(c)) for c in consts
if isinstance(core.get_aval(c), AbstractRef)}
Expand Down
16 changes: 11 additions & 5 deletions jax/_src/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ class Jaxpr:
_outvars: list[Atom]
_eqns: list[JaxprEqn]
_effects: Effects
_debug_info: DebugInfo | None
_debug_info: DebugInfo

@property
def constvars(self) -> list[Var]:
Expand All @@ -117,13 +117,17 @@ def effects(self) -> Effects:
return self._effects

@property
def debug_info(self) -> DebugInfo | None:
def debug_info(self) -> DebugInfo:
return self._debug_info

def __init__(self, constvars: Sequence[Var], invars: Sequence[Var],
outvars: Sequence[Atom], eqns: Sequence[JaxprEqn],
effects: Effects = no_effects,
debug_info: DebugInfo | None = None):
# We want all calls to pass a DebugInfo object, but for backwards
# compatibility we have to allow calls when the debug_info
# is missing.
debug_info: DebugInfo = None, # type: ignore[annotation-type-mismatch,assignment]
):
"""
Args:
constvars: list of variables introduced for constants. Array constants are
Expand All @@ -134,14 +138,16 @@ def __init__(self, constvars: Sequence[Var], invars: Sequence[Var],
eqns: list of equations.
effects: set of effects. The effects on a jaxpr are a superset of the
union of the effects for each equation.
debug_info: optional DebugInfo.
debug_info: debugging information.
"""
self._constvars = list(constvars)
self._invars = list(invars)
self._outvars = list(outvars)
self._eqns = list(eqns)
self._effects = effects
self._debug_info = debug_info and debug_info.resolve_result_paths()
# TODO(https://github.com/jax-ml/jax/issues/26480)
debug_info = debug_info or lu._missing_debug_info("core.Jaxpr")
self._debug_info = debug_info.resolve_result_paths()
# TODO(necula): re-enable these safety checks
# assert (not debug_info or len(debug_info.arg_names) == len(invars)), (debug_info, invars)
# assert (not debug_info or len(debug_info.result_paths) == len(outvars)), (debug_info, outvars)
Expand Down
24 changes: 9 additions & 15 deletions jax/_src/custom_derivatives.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@
from jax._src.ad_util import (
stop_gradient_p, SymbolicZero, Zero, zeros_like_aval)
from jax._src.api_util import (
argnums_partial, flatten_fun_nokwargs, resolve_kwargs, fun_signature,
_non_static_arg_names, prepend_static_args, debug_info)
argnums_partial, flatten_fun_nokwargs, resolve_kwargs,
prepend_static_args, debug_info)
from jax._src.errors import UnexpectedTracerError
from jax._src.state.types import AbstractRef
from jax._src.interpreters import ad
Expand Down Expand Up @@ -686,28 +686,22 @@ def __call__(self, *args: Any, **kwargs: Any) -> ReturnValue: # pytype: disable

@lu.transformation2
def _check_primal_refs(f: Callable, nondiff_argnums: Sequence[int],
debug_info: core.DebugInfo | None, *args):
debug_info: core.DebugInfo, *args):
_check_for_aliased_refs(f, nondiff_argnums, debug_info, args)
out = f(*args)
_check_for_returned_refs(f, out, 'primal')
return out

def _check_for_aliased_refs(f: Callable,
nondiff_argnums: Sequence[int],
debug: core.DebugInfo | None,
debug: core.DebugInfo,
args):
leaves = tree_leaves(args)
refs: dict[int, int] = {}
for i, x in enumerate(leaves):
if (isinstance((a := core.get_aval(x)), AbstractRef) and
(dup_idx := refs.setdefault(id(core.get_referent(x)), i)) != i):
if debug is not None:
arg_names = debug.safe_arg_names(len(leaves))
else:
# TODO(necula): drop this branch
arg_names = _non_static_arg_names(fun_signature(f), args, {}, nondiff_argnums, ())
if arg_names is None:
arg_names = [f'flat index {j}' for j in range(len(leaves))]
arg_names = debug.safe_arg_names(len(leaves))
raise ValueError(
"only one reference to a mutable array may be passed as an argument "
f"to a function, but custom_vjp function {f} got the same mutable "
Expand Down Expand Up @@ -763,8 +757,8 @@ def _check_for_tracers(x):
def _flatten_fwd(f: Callable, store: lu.EqualStore,
nondiff_argnums: Sequence[int],
symbolic_zeros: bool,
debug_primal: core.DebugInfo | None,
debug_fwd: core.DebugInfo | None,
debug_primal: core.DebugInfo,
debug_fwd: core.DebugInfo,
in_tree: PyTreeDef, maybe_out_type, *args):
primal_name = debug_primal.func_name if debug_primal else str(f)
fwd_name = debug_fwd.func_name if debug_fwd else "<unknown>"
Expand Down Expand Up @@ -1560,9 +1554,9 @@ def jvp(primals, tangents):
# simpler, but it would be worth revisiting this.
def optimize_remat_of_custom_vjp_fwd(
fun: Callable[..., ReturnValue],
debug_fun: core.DebugInfo | None,
debug_fun: core.DebugInfo,
fwd: Callable[..., tuple[ReturnValue, Any]],
debug_fwd: core.DebugInfo | None,
debug_fwd: core.DebugInfo,
nondiff_argnums: Sequence[int] = (),
symbolic_zeros: bool = False,
) -> Callable[..., tuple[ReturnValue, Any]]:
Expand Down
23 changes: 11 additions & 12 deletions jax/_src/interpreters/ad.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def jvpfun(f: Callable, instantiate, transform_stack, primals, tangents):
@lu.transformation_with_aux2
def linearize_subtrace(_f: Callable, _store: lu.Store, _tag: core.TraceTag,
nzs_in: Sequence[bool],
debug_info: core.DebugInfo | None,
debug_info: core.DebugInfo,
*primals, **params):
with core.take_current_trace() as parent_trace:
tangent_trace = pe.DynamicJaxprTrace(debug_info)
Expand Down Expand Up @@ -133,7 +133,7 @@ def jvp_subtrace_aux(f, store, tag, primals, tangents):
return out_primals, out_tangents

def convert_constvars_jaxpr_constvars_at_end(jaxpr: core.Jaxpr) -> core.Jaxpr:
dbg = jaxpr.debug_info and jaxpr.debug_info._replace(
dbg = jaxpr.debug_info._replace(
arg_names=jaxpr.debug_info.arg_names + (None,) * len(jaxpr.constvars))
return core.Jaxpr(constvars=(),
invars=jaxpr.invars + jaxpr.constvars,
Expand Down Expand Up @@ -768,7 +768,7 @@ def linearize_from_jvp(jvp: Callable,
multiple_results: bool,
nonzeros: Sequence[bool],
user_facing_symbolic_zeros: bool, instantiate_input_zeros: bool,
debug_info: core.DebugInfo | None,
debug_info: core.DebugInfo,
primals, params):
current_name_stack = source_info_util.current_name_stack()
with core.take_current_trace() as parent_trace:
Expand Down Expand Up @@ -1100,15 +1100,14 @@ def rearrange_binders(jaxpr: core.ClosedJaxpr, primals_in, tangents_in, primals_
new_invars = _perm(primals_in, tangents_in, jaxpr.jaxpr.invars)
new_outvars = _perm(primals_out, tangents_out, jaxpr.jaxpr.outvars)
new_debug_info = jaxpr.jaxpr.debug_info
if new_debug_info is not None:
new_arg_names = tuple(_perm(primals_in, tangents_in,
jaxpr.jaxpr.debug_info.safe_arg_names(len(jaxpr.jaxpr.invars))))
new_result_paths = tuple(_perm(primals_out, tangents_out,
jaxpr.jaxpr.debug_info.safe_result_paths(len(jaxpr.jaxpr.outvars))))
new_debug_info = new_debug_info._replace(
arg_names=new_arg_names,
result_paths=new_result_paths,
)
new_arg_names = tuple(_perm(primals_in, tangents_in,
jaxpr.jaxpr.debug_info.safe_arg_names(len(jaxpr.jaxpr.invars))))
new_result_paths = tuple(_perm(primals_out, tangents_out,
jaxpr.jaxpr.debug_info.safe_result_paths(len(jaxpr.jaxpr.outvars))))
new_debug_info = new_debug_info._replace(
arg_names=new_arg_names,
result_paths=new_result_paths,
)
new_jaxpr = core.Jaxpr(jaxpr.jaxpr.constvars,
new_invars, new_outvars, jaxpr.jaxpr.eqns,
jaxpr.jaxpr.effects,
Expand Down
Loading

0 comments on commit a0812cd

Please sign in to comment.