Skip to content

Commit

Permalink
Don't computing forwarding information if we're going to inline.
Browse files Browse the repository at this point in the history
Computing forwarding information is pointless because inlining does everything forwarding would do.

PiperOrigin-RevId: 719321640
  • Loading branch information
hawkinsp authored and Google-ML-Automation committed Jan 24, 2025
1 parent 617e79f commit c4abed2
Showing 1 changed file with 11 additions and 7 deletions.
18 changes: 11 additions & 7 deletions jax/_src/pjit.py
Original file line number Diff line number Diff line change
Expand Up @@ -1804,26 +1804,30 @@ def _pjit_lower(


def pjit_staging_rule(trace, *args, **params):
jaxpr, in_fwd, out_shardings, out_layouts = _pjit_forwarding(
params['jaxpr'], params['out_shardings'], params['out_layouts'])
params = dict(params, jaxpr=jaxpr, out_shardings=out_shardings,
out_layouts=out_layouts)
# If we're inlining, no need to compute forwarding information; the inlined
# computation will in effect forward things.
if (params["inline"] and
all(isinstance(i, UnspecifiedValue) for i in params["in_shardings"]) and
all(isinstance(o, UnspecifiedValue) for o in params["out_shardings"]) and
all(i is None for i in params["in_layouts"]) and
all(o is None for o in params["out_layouts"])):
jaxpr = params["jaxpr"]
if config.dynamic_shapes.value:
# Inline jaxpr doesn't handle dynamic shapes when inlining. If dynamic
# shapes are enabled, use eval_jaxpr, which uses the tracing machinery,
# but redundantly performs abstract evaluation again.
with core.set_current_trace(trace):
out_tracers = core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, *args,
return core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, *args,
propagate_source_info=False)
else:
out_tracers = pe.inline_jaxpr_into_trace(
return pe.inline_jaxpr_into_trace(
trace, jaxpr.jaxpr, jaxpr.consts, *args)
elif config.dynamic_shapes.value:

jaxpr, in_fwd, out_shardings, out_layouts = _pjit_forwarding(
params['jaxpr'], params['out_shardings'], params['out_layouts'])
params = dict(params, jaxpr=jaxpr, out_shardings=out_shardings,
out_layouts=out_layouts)
if config.dynamic_shapes.value:
source_info = source_info_util.current()
out_tracers = []
for aval in _out_type(jaxpr):
Expand Down

0 comments on commit c4abed2

Please sign in to comment.