Remove axis_name from unmapped_aval

PiperOrigin-RevId: 718558713

yashk2810 authored and Google-ML-Automation committed Jan 22, 2025
1 parent f6243ff commit 23d360b
Showing 10 changed files with 28 additions and 32 deletions.
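
The change in one line: core.unmapped_aval drops its unused axis_name parameter, so unmapping handlers now take (size, axis, aval). A minimal sketch of the new calling convention; note that jax._src.core is a private module, so these names may shift between releases:

from jax._src import core   # private internals, shown for illustration only
import jax.numpy as jnp

aval = core.get_aval(jnp.zeros((3, 4)))    # ShapedArray, f32[3,4]

# Before this commit: core.unmapped_aval(8, core.no_axis_name, 0, aval)
# After this commit:
stacked = core.unmapped_aval(8, 0, aval)   # ShapedArray, f32[8,3,4]
assert stacked.shape == (8, 3, 4)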
3 changes: 1 addition & 2 deletions jax/_src/api.py
@@ -2415,8 +2415,7 @@ def device_put_replicated(x: Any, devices: Sequence[xc.Device]):  # noqa: F811
     raise ValueError("`devices` argument to `device_put_replicated must be "
                      "a non-empty sequence.")
   def _device_put_replicated(x):
-    aval = core.unmapped_aval(len(devices), core.no_axis_name, 0,
-                              core.get_aval(x))
+    aval = core.unmapped_aval(len(devices), 0, core.get_aval(x))
     assert isinstance(aval, ShapedArray)
     sharding_spec = sharding_specs.create_pmap_sharding_spec(aval.shape)
     if config.pmap_no_rank_reduction.value:
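
For context, a usage sketch of the public wrapper this helper backs (assumes a multi-device runtime); the replicated result's leading device axis is exactly the unmapped aval computed above:

import jax
import jax.numpy as jnp

x = jnp.arange(4.)
y = jax.device_put_replicated(x, jax.devices())
# y holds one copy of x per device; its aval is f32[len(jax.devices()), 4].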
3 changes: 1 addition & 2 deletions jax/_src/callback.py
@@ -159,8 +159,7 @@ def callback_batching_rule(
   new_args = [arg if dim is batching.not_mapped else
               batching.moveaxis(arg, dim, 0) for arg, dim in zip(args, dims)]
   batched_result_avals = tuple(
-      core.unmapped_aval(axis_size, core.no_axis_name, 0, aval)
-      for aval in result_avals)
+      core.unmapped_aval(axis_size, 0, aval) for aval in result_avals)
 
   # For FFI calls we must update the layouts. We handle the output layouts
   # here, but the input layout updates depend on the vmap_method parameter.
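
This batching rule fires when a callback is vmapped; a hedged sketch using the public jax.pure_callback entry point (assumes a JAX version with the vmap_method parameter):

import jax
import jax.numpy as jnp
import numpy as np

def host_sin(x):
  return np.sin(x)  # runs on the host, outside of tracing

def f(x):
  # Per-element result aval is f32[]; under vmap the rule above unmaps it
  # to f32[axis_size] by inserting the batch axis at position 0.
  return jax.pure_callback(host_sin, jax.ShapeDtypeStruct((), x.dtype), x,
                           vmap_method="sequential")

ys = jax.vmap(f)(jnp.arange(4.))  # ys has aval f32[4]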
17 changes: 8 additions & 9 deletions jax/_src/core.py
@@ -2346,11 +2346,11 @@ def mapped_aval(size: AxisSize, axis: int | None,
   else:
     raise TypeError(f"no mapping handler for {aval} of type {type(aval)}")
 
-def unmapped_aval(size: AxisSize, axis_name, axis: int | None,
+def unmapped_aval(size: AxisSize, axis: int | None,
                   aval: AbstractValue) -> AbstractValue:
   _, handler = aval_mapping_handlers.get(type(aval), (None, None))
   if handler is not None:
-    return handler(size, axis_name, axis, aval)
+    return handler(size, axis, aval)
   else:
     raise TypeError(f"no unmapping handler for {aval} of type {type(aval)}")
@@ -2366,11 +2366,10 @@ def _map_shaped_array(
                      weak_type=aval.weak_type, sharding=sharding)
 
 def _unmap_shaped_array(
-    size: int, axis_name: AxisName, axis: int | None, aval: ShapedArray
-  ) -> ShapedArray:
+    size: int, axis: int | None, aval: ShapedArray) -> ShapedArray:
   if axis is None: return aval
   elif type(axis) is int:
-    sharding = (aval.sharding.with_spec(tuple_insert(aval.sharding.spec, axis, axis_name))
+    sharding = (aval.sharding.with_spec(tuple_insert(aval.sharding.spec, axis, None))
                 if config.sharding_in_types.value else None)
     return ShapedArray(tuple_insert(aval.shape, axis, size), aval.dtype,
                        weak_type=aval.weak_type, sharding=sharding)
@@ -2383,7 +2382,7 @@ def _map_dshaped_array(
                       aval.weak_type)
 
 def _unmap_dshaped_array(
-    size: AxisSize, axis_name: AxisName, axis: int | None, aval: DShapedArray
+    size: AxisSize, axis: int | None, aval: DShapedArray
 ) -> DShapedArray:
   if axis is None: return aval
   elif type(axis) is int:
@@ -2396,7 +2395,7 @@ def _unmap_dshaped_array(
 aval_mapping_handlers: dict[type, AvalMapHandlerPair] = {
     DShapedArray: (_map_dshaped_array, _unmap_dshaped_array),
     ShapedArray: (_map_shaped_array, _unmap_shaped_array),
-    AbstractToken: (lambda _, __, a: a, lambda _, __, ___, a: a)
+    AbstractToken: (lambda _, __, a: a, lambda _, __, a: a)
 }
 
 # When a mapped function is given no axis name, we generate a name object based
@@ -2777,7 +2776,7 @@ def _check_map(ctx_factory, prim, in_avals, params):
     raise JaxprTypeError(f"Map primitive {prim} missing 'out_axes' parameter")
   out_axes = params["out_axes"]
 
-  binder_avals = [unmapped_aval(axis_size, axis_name, in_axis, v.aval)
+  binder_avals = [unmapped_aval(axis_size, in_axis, v.aval)
                   if in_axis is not None else v.aval
                   for v, in_axis in zip(call_jaxpr.invars, in_axes)]
   for binder_aval, in_aval in zip(binder_avals, in_avals):
@@ -2789,7 +2788,7 @@ def _check_map(ctx_factory, prim, in_avals, params):
   _check_jaxpr(ctx_factory, call_jaxpr)
 
   mapped_out_avals = [v.aval for v in call_jaxpr.outvars]
-  out_avals = [unmapped_aval(axis_size, axis_name, out_axis, aval)
+  out_avals = [unmapped_aval(axis_size, out_axis, aval)
                if out_axis is not None else aval
                for aval, out_axis in zip(mapped_out_avals, out_axes)]
   return out_avals, filter_named_axis_effects(call_jaxpr.effects, {axis_name})
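
A round-trip sketch of the pair under the new signatures (again private internals; the handler table dispatches on the aval's type):

from jax._src import core
import jax.numpy as jnp

aval = core.get_aval(jnp.zeros((8, 3)))       # f32[8,3]
per_elt = core.mapped_aval(8, 0, aval)        # strip axis 0    -> f32[3]
stacked = core.unmapped_aval(8, 0, per_elt)   # re-insert axis  -> f32[8,3]
assert stacked.shape == aval.shape

# With config.sharding_in_types enabled, the re-inserted dimension's spec
# entry is now None (unsharded) rather than the mapped axis name.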
2 changes: 1 addition & 1 deletion jax/_src/interpreters/ad.py
@@ -1000,7 +1000,7 @@ def out_axes_thunk():
     assert len(in_axes) == len(arg_cts)
     def unmap_zero(zero, in_axis):
       return (zero if in_axis is None else
-              Zero(core.unmapped_aval(params['axis_size'], params['axis_name'], in_axis, zero.aval)))
+              Zero(core.unmapped_aval(params['axis_size'], in_axis, zero.aval)))
     arg_cts = (unmap_zero(arg_ct, in_axis) if type(arg_ct) is Zero else
                arg_ct if in_axis is not None else
                arg_ct.sum(0)
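
Zero is a symbolic cotangent that carries only an aval, so unmapping it is just unmapping the aval; a short sketch (private internals):

from jax._src import core
from jax._src.interpreters import ad
import jax.numpy as jnp

z = ad.Zero(core.get_aval(jnp.zeros(3)))         # symbolic zero, f32[3]
zu = ad.Zero(core.unmapped_aval(8, 0, z.aval))   # unmapped zero, f32[8,3]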
10 changes: 5 additions & 5 deletions jax/_src/interpreters/batching.py
@@ -223,7 +223,7 @@ def __init__(self, a): self.a = a
       if isinstance(d, RaggedAxis):
         raise NotImplementedError
       else:
-        new_avals.append(core.unmapped_aval(sz, axis_name, d, a))  # type: ignore
+        new_avals.append(core.unmapped_aval(sz, d, a))  # type: ignore
 
   mentioned = {d for a in new_avals if type(a) is core.DShapedArray
                for d in a.shape if type(d) is Name}
@@ -750,7 +750,7 @@ def _batch_jaxpr2(
       handle_ragged(closed_jaxpr.in_avals, dim, aval)
       if isinstance(dim, RaggedAxis) else (dim, aval)
       for dim, aval in zip(in_axes, closed_jaxpr.in_avals)])
-  avals_in2 = [core.unmapped_aval(axis_data.size, axis_data.name, b, aval)
+  avals_in2 = [core.unmapped_aval(axis_data.size, b, aval)
               if b is not not_mapped else aval
               for aval, b in unsafe_zip(avals_in, in_axes2)]
   jaxpr_out, _, consts, () = pe.trace_to_jaxpr_dynamic(f, avals_in2)
@@ -787,7 +787,7 @@ def _batch_jaxpr_axes(closed_jaxpr, axis_data, in_axes, out_axes_dest):
   f, out_axes = _batch_jaxpr_inner(f, axis_data)
   f, out_batched = _match_axes_jaxpr(f, axis_data, out_axes_dest, out_axes)
   f = _batch_jaxpr_outer(f, axis_data, in_axes)
-  avals_in = [core.unmapped_aval(axis_data.size, axis_data.name, b, aval) if b is not not_mapped
+  avals_in = [core.unmapped_aval(axis_data.size, b, aval) if b is not not_mapped
              else aval for aval, b in unsafe_zip(closed_jaxpr.in_avals, in_axes)]
   jaxpr_out, _, consts, () = pe.trace_to_jaxpr_dynamic(f, avals_in)
   return core.ClosedJaxpr(jaxpr_out, consts), out_batched()
@@ -906,9 +906,9 @@ def _matchaxis_symbolic_zeros(axis_name, sz, name, src, dst, x, sum_match=False):
     return x
   elif type(src) == type(dst) == int:
     aval = core.mapped_aval(sz, src, x.aval)
-    return Zero(core.unmapped_aval(sz, name, dst, aval))
+    return Zero(core.unmapped_aval(sz, dst, aval))
   elif src is not_mapped and dst is not not_mapped:
-    return Zero(core.unmapped_aval(sz, name, dst, x.aval))
+    return Zero(core.unmapped_aval(sz, dst, x.aval))
   elif dst is not_mapped and sum_match:
     return Zero(core.mapped_aval(sz, src, x.aval))
   else:
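
These are the aval computations behind jax.vmap, observable from the public API:

import jax
import jax.numpy as jnp

f = lambda x: x[:3]                  # per-element output aval: f32[3]
ys = jax.vmap(f)(jnp.zeros((8, 5)))
# The batched output aval is core.unmapped_aval(8, 0, f32[3]) == f32[8,3].
assert ys.shape == (8, 3)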
4 changes: 2 additions & 2 deletions jax/_src/interpreters/partial_eval.py
@@ -372,7 +372,7 @@ def const_out_axes_thunk():
         out_axes=tuple(staged_out_axes), call_jaxpr=call_jaxpr)
     del staged_params['out_axes_thunk']
     # The outputs of the staged-out call are Tracers with the new eqn as recipe.
-    out_avals = [unmapped_aval(params['axis_size'], params['axis_name'], ax, a)
+    out_avals = [unmapped_aval(params['axis_size'], ax, a)
                  for ax, a in zip(staged_out_axes, out_avals_mapped)]
     out_tracers = [JaxprTracer(self, PartialVal.unknown(a), None)
                    for a in out_avals]
@@ -1956,7 +1956,7 @@ def process_map(self, map_primitive, f, tracers, params):
       raise ValueError("Ordered effects not supported for "
                        f"map primitives: {ordered_effects}")
     out_axes = params['out_axes_thunk']()
-    out_avals = [core.unmapped_aval(axis_size, axis_name, out_axis, a)
+    out_avals = [core.unmapped_aval(axis_size, out_axis, a)
                  if out_axis is not None else a
                  for a, out_axis in zip(reduced_out_avals, out_axes)]
     source_info = source_info_util.current()
4 changes: 2 additions & 2 deletions jax/_src/interpreters/pxla.py
@@ -914,7 +914,7 @@ def _pmap_unmap_shaped_array(
 def _pmap_unmapped_aval(size: core.AxisSize, axis_name, axis: int | None,
                         aval: core.AbstractValue) -> core.AbstractValue:
   if not config.pmap_no_rank_reduction.value:
-    return core.unmapped_aval(size, axis_name, axis, aval)
+    return core.unmapped_aval(size, axis, aval)
 
   _, handler = _pmap_aval_mapping_handlers.get(type(aval), (None, None))
   if handler is not None:
@@ -1350,7 +1350,7 @@ def _pmap_partial_eval_custom_params_updater(
   return new_params_known, new_params_staged
 
 def _pmap_partial_eval_custom_res_maker(params_known, aval):
-  return core.unmapped_aval(params_known['axis_size'], core.no_axis_name, 0, aval)
+  return core.unmapped_aval(params_known['axis_size'], 0, aval)
 
 def _pmap_dce_rule(used_outputs, eqn):
   # just like pe.dce_jaxpr_call_rule, except handles in_axes / out_axes
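
The pmap-specific path above only falls back to core.unmapped_aval when pmap_no_rank_reduction is off; a public-API sketch (assumes at least one local device):

import jax
import jax.numpy as jnp

n = jax.local_device_count()
y = jax.pmap(lambda x: x * 2)(jnp.arange(n, dtype=jnp.float32))
# y has a leading device axis of size n. With pmap_no_rank_reduction off,
# each device computes on the rank-reduced aval f32[] and the result is
# unmapped back to f32[n]; with the flag on, the pmap-specific handlers
# above do the unmapping instead.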
8 changes: 4 additions & 4 deletions jax/_src/lax/control_flow/loops.py
@@ -520,7 +520,7 @@ def _stage_jaxpr_abstract_eval(*_, jaxpr):
   return jaxpr.out_avals, jaxpr.effects
 
 def _prepend_dim_to_aval(sz, aval):
-  return core.unmapped_aval(sz, None, 0, aval)
+  return core.unmapped_aval(sz, 0, aval)
 
 def _scan_abstract_eval(*args, reverse, length, num_consts, num_carry, jaxpr,
                         linear, unroll, _split_transpose):
@@ -704,7 +704,7 @@ def _scan_partial_eval(trace, *tracers, reverse, length, num_consts, num_carry,
   extensive_res = _map(trace.new_instantiated_const, extensive_res)
   # Create output tracers for jaxpr_unknown bind, adapting extensive shapes.
   carry_avals, y_avals = split_list(jaxpr_unknown.out_avals, [sum(carry_uk)])
-  ys_avals = [core.unmapped_aval(length, None, 0, y_aval)
+  ys_avals = [core.unmapped_aval(length, 0, y_aval)
               for y_aval in y_avals]
   out_tracers = [pe.JaxprTracer(trace, pe.PartialVal.unknown(a), None)
                  for a in itertools.chain(carry_avals, ys_avals)]
@@ -1071,7 +1071,7 @@ def _scan_partial_eval_custom(saveable, unks_in, inst_in, eqn):
 
   # Create residual variables.
   intensive_avals, ext_avals_mapped = partition_list(loop_dep_res, res_avals)
-  ext_avals = [core.unmapped_aval(eqn.params['length'], None, 0, a)
+  ext_avals = [core.unmapped_aval(eqn.params['length'], 0, a)
               for a in ext_avals_mapped]
   newvar = core.gensym()
   intensive_res = _map(newvar, intensive_avals)
@@ -1149,7 +1149,7 @@ def _scan_typecheck(bind_time, *in_atoms, reverse, length, num_consts,
       jaxpr.in_avals, [num_consts, num_carry])
   carry_avals_jaxpr, y_avals_mapped = split_list(jaxpr.out_avals, [num_carry])
   x_avals_mapped = _map(partial(core.mapped_aval, length, 0), x_avals)
-  y_avals = [core.unmapped_aval(length, None, 0, a)
+  y_avals = [core.unmapped_aval(length, 0, a)
              for a in y_avals_mapped]
 
   if not all(_map(core.typematch, init_avals_jaxpr, carry_avals_jaxpr)):
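
The (length, 0) pattern mirrors what lax.scan does to its per-iteration outputs, observable from the public API:

import jax
import jax.numpy as jnp

def step(carry, x):
  y = carry + x                      # per-iteration y aval: f32[3]
  return carry, y

_, ys = jax.lax.scan(step, jnp.zeros(3), jnp.ones((5, 3)))
# ys stacks each y along a new leading axis of size `length`, i.e. its aval
# is core.unmapped_aval(5, 0, f32[3]) == f32[5,3].
assert ys.shape == (5, 3)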
5 changes: 2 additions & 3 deletions jax/_src/state/types.py
@@ -367,9 +367,8 @@ def __hash__(self):
 def _map_ref(size, axis, ref_aval):
   return AbstractRef(core.mapped_aval(size, axis, ref_aval.inner_aval))
 
-def _unmap_ref(size, axis_name, axis, ref_aval):
-  return AbstractRef(core.unmapped_aval(size, axis_name, axis,
-                                        ref_aval.inner_aval))
+def _unmap_ref(size, axis, ref_aval):
+  return AbstractRef(core.unmapped_aval(size, axis, ref_aval.inner_aval))
 
 core.aval_mapping_handlers[AbstractRef] = (_map_ref, _unmap_ref)
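
AbstractRef illustrates the extension point: any aval type can register a (map, unmap) handler pair, and after this commit the unmap side takes three arguments. A sketch with a hypothetical wrapper aval (MyBoxedAval is invented for illustration):

from jax._src import core

class MyBoxedAval(core.AbstractValue):   # hypothetical wrapper aval
  def __init__(self, inner_aval):
    self.inner_aval = inner_aval

def _map_boxed(size, axis, aval):
  return MyBoxedAval(core.mapped_aval(size, axis, aval.inner_aval))

def _unmap_boxed(size, axis, aval):      # no axis_name after this commit
  return MyBoxedAval(core.unmapped_aval(size, axis, aval.inner_aval))

core.aval_mapping_handlers[MyBoxedAval] = (_map_boxed, _unmap_boxed)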
4 changes: 2 additions & 2 deletions jax/experimental/shard_map.py
@@ -1613,7 +1613,7 @@ def fun(*res_and_args):
     res, args = split_list(res_and_args, [len(jaxpr.constvars)])
     res = [_rem_singleton(x) if w else x for x, w in zip(res, which)]
     return core.eval_jaxpr(jaxpr, res, *args)
-  res_avals = [core.unmapped_aval(1, None, 0, v.aval) if w else v.aval
+  res_avals = [core.unmapped_aval(1, 0, v.aval) if w else v.aval
                for v, w in zip(jaxpr.constvars, which)]
   in_avals = [*res_avals, *[v.aval for v in jaxpr.invars]]
   jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(fun, in_avals)
@@ -1740,7 +1740,7 @@ def staged(*args):
     res_, ins = split_list(args, [len(which)])
     res = [_rem_singleton(x) if w else x for x, w in zip(res_, which_)]
     return core.eval_jaxpr(jaxpr_staged, (), *res, *ins)
-  res_avals = [core.unmapped_aval(1, None, 0, v.aval) if w else v.aval
+  res_avals = [core.unmapped_aval(1, 0, v.aval) if w else v.aval
                for w, v in zip(which_, jaxpr_staged.invars[:len(which)])]
   avals_in = [*res_avals, *[v.aval for v in jaxpr_staged.invars[len(which):]]]
   jaxpr_staged, _, (), () = pe.trace_to_jaxpr_dynamic(staged, avals_in)
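
The (1, 0) unmapped avals above give shard_map residuals a singleton leading axis during partial evaluation; for context, a hedged usage sketch of the experimental public API (assumes the mesh size divides the input length):

import jax
import jax.numpy as jnp
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P

mesh = Mesh(jax.devices(), ('i',))
f = shard_map(lambda x: x * 2, mesh=mesh, in_specs=P('i'), out_specs=P('i'))
y = f(jnp.arange(8.))   # each shard computes on a slice of the leading axis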
