[sharding_in_types] Make vmap work with shard_map + pallas
PiperOrigin-RevId: 718511907
yashk2810 authored and Google-ML-Automation committed Jan 22, 2025
1 parent 64e9b07 commit b6b45f8
Showing 7 changed files with 38 additions and 18 deletions.
13 changes: 9 additions & 4 deletions jax/_src/core.py
@@ -39,6 +39,7 @@
 from jax._src import effects
 from jax._src import compute_on
 from jax._src import mesh as mesh_lib
+from jax._src.mesh import AxisTypes
 from jax._src.partition_spec import PartitionSpec as P
 from jax._src.errors import (
     ConcretizationTypeError, TracerArrayConversionError, TracerBoolConversionError,
@@ -1687,21 +1688,25 @@ def _invalid_shape_error(shape: Shape, context: str=""):

 # TODO(yashkatariya): Only works with User/Auto. Generalize it to work with
 # Collective too.
-def modify_spec_for_hidden(spec, mesh) -> P:
+def modify_spec_for_hidden_collective(spec, mesh) -> P:
   if all(s is None for s in spec):
     return spec
   new_spec = []  # type: ignore
   for s in spec:
     if s is None:
       new_spec.append(s)
     else:
       temp_s = s[0] if isinstance(s, tuple) else s
       new_spec.append(
-          None if mesh._name_to_type[temp_s] == mesh_lib.AxisTypes.Hidden else s)
+          None
+          if mesh._name_to_type[temp_s] in (AxisTypes.Hidden, AxisTypes.Collective)
+          else s)
   return P(*new_spec)

 def _maybe_modify_sharding(sharding):
-  if mesh_lib.AxisTypes.Hidden not in sharding.mesh.axis_types:
+  if sharding.mesh._are_all_axes_visible:
     return sharding
-  new_spec = modify_spec_for_hidden(sharding.spec, sharding.mesh)
+  new_spec = modify_spec_for_hidden_collective(sharding.spec, sharding.mesh)
   return sharding.with_spec(new_spec)
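The net effect of the renamed helper: any spec entry that names a Hidden or Collective mesh axis is replaced with None, so that dimension is treated as unsharded at the type level, while entries naming Visible axes pass through. A minimal standalone sketch of that logic (plain Python stand-ins, not the JAX internals):

    # Illustrative only: mimics the spec rewriting with plain Python values
    # instead of PartitionSpec and AbstractMesh.
    HIDDEN, COLLECTIVE, VISIBLE = "hidden", "collective", "visible"

    def drop_hidden_collective(spec, name_to_type):
      new_spec = []
      for s in spec:
        if s is None:
          new_spec.append(None)
          continue
        first = s[0] if isinstance(s, tuple) else s  # a tuple entry names several axes
        new_spec.append(
            None if name_to_type[first] in (HIDDEN, COLLECTIVE) else s)
      return tuple(new_spec)

    assert drop_hidden_collective(
        ("x", None, "y"), {"x": HIDDEN, "y": VISIBLE}) == (None, None, "y")
    assert drop_hidden_collective(
        (("x", "y"), None), {"x": COLLECTIVE, "y": VISIBLE}) == (None, None)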
16 changes: 10 additions & 6 deletions jax/_src/lax/lax.py
@@ -5585,7 +5585,7 @@ def _reduce_prod_jvp_rule(primals, tangents, *, axes):

 reduce_prod_p = standard_primitive(
     _reduce_op_shape_rule, partial(_reduce_number_dtype_rule, 'reduce_prod'),
-    'reduce_prod')
+    'reduce_prod', sharding_rule=_reduce_op_sharding_rule)
 ad.primitive_jvps[reduce_prod_p] = _reduce_prod_jvp_rule
 batching.defreducer(reduce_prod_p, _get_prod_identity)
 pe.padding_rules[reduce_prod_p] = partial(_reducer_padding, _reduce_prod,
@@ -5613,8 +5613,9 @@ def _reduce_chooser_jvp_rule(g, ans, operand, *, axes):
 batching.ragged_prop_rules[reduce_max_p] = batching.ragged_mask_elementwise_rule


-reduce_min_p = standard_primitive(_reduce_op_shape_rule, _input_dtype,
-                                  'reduce_min')
+reduce_min_p = standard_primitive(
+    _reduce_op_shape_rule, _input_dtype, 'reduce_min',
+    sharding_rule=_reduce_op_sharding_rule)
 ad.defjvp2(reduce_min_p, _reduce_chooser_jvp_rule)
 batching.defreducer(reduce_min_p, _get_min_identity)
 pe.padding_rules[reduce_min_p] = partial(_reducer_padding, _reduce_min,
@@ -5705,22 +5706,25 @@ def _reduce_logical_shape_rule(operand, *, axes):
     raise TypeError(f"logical reduction requires operand dtype bool or int, got {operand.dtype}.")
   return tuple(np.delete(operand.shape, axes))

+def _reduce_logical_sharding_rule(operand, *, axes):
+  return operand.sharding.with_spec(tuple_delete(operand.sharding.spec, axes))
+
 reduce_or_p = standard_primitive(
     _reduce_logical_shape_rule, _input_dtype, 'reduce_or',
-    weak_type_rule=_strip_weak_type)
+    weak_type_rule=_strip_weak_type, sharding_rule=_reduce_logical_sharding_rule)
 batching.defreducer(reduce_or_p, _get_bitwise_or_identity)


 reduce_and_p = standard_primitive(
     _reduce_logical_shape_rule, _input_dtype, 'reduce_and',
-    weak_type_rule=_strip_weak_type)
+    weak_type_rule=_strip_weak_type, sharding_rule=_reduce_logical_sharding_rule)
 batching.defreducer(reduce_and_p, _get_bitwise_and_identity)
 batching.ragged_prop_rules[reduce_and_p] = batching.ragged_mask_elementwise_rule


 reduce_xor_p = standard_primitive(
     _reduce_logical_shape_rule, _input_dtype, 'reduce_xor',
-    weak_type_rule=_strip_weak_type)
+    weak_type_rule=_strip_weak_type, sharding_rule=_reduce_logical_sharding_rule)
 batching.defreducer(reduce_xor_p, _get_bitwise_or_identity)
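The new sharding rules mirror the shape rules: a reduction deletes the reduced dimensions, so the output spec is the operand spec with the same positions removed. A hedged standalone sketch of that bookkeeping (plain tuples rather than JAX's sharding objects):

    # Illustrative only: the output spec of a reduction drops the entries at the
    # reduced positions, matching how the shape rule deletes those dimensions.
    def tuple_delete(t, positions):
      return tuple(x for i, x in enumerate(t) if i not in positions)

    operand_spec = ("x", None, "y")  # dims sharded over mesh axes 'x' and 'y'
    axes = (1,)                      # reduce over the unsharded middle dimension
    assert tuple_delete(operand_spec, axes) == ("x", "y")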
3 changes: 2 additions & 1 deletion jax/_src/lax/utils.py
@@ -50,7 +50,8 @@ def _get_array_abstraction_level(a): return a.array_abstraction_level
 def call_sharding_rule(prim, rule, num_out, *avals, **kwargs):
   if config.sharding_in_types.value:
     if rule is None:
-      if mesh_lib.get_abstract_mesh()._are_all_axes_hidden:  # type: ignore
+      cur_mesh = mesh_lib.get_abstract_mesh()
+      if cur_mesh._are_all_axes_hidden or cur_mesh._are_all_axes_collective:  # type: ignore
         return None if num_out is None else [None] * num_out
       else:
         raise ValueError(
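With this change the no-rule fallback covers fully collective meshes as well as fully hidden ones: when no mesh axis is visible, there is nothing to infer, so the output shardings are left unspecified instead of raising. A sketch of just that decision (hypothetical names, outside JAX):

    # Illustrative only: the fallback when a primitive has no sharding rule.
    def fallback_out_shardings(num_out, all_hidden, all_collective):
      if all_hidden or all_collective:
        # No visible axes: leave the output shardings unspecified.
        return None if num_out is None else [None] * num_out
      raise ValueError("visible mesh axes present but no sharding rule registered")

    assert fallback_out_shardings(3, all_hidden=True, all_collective=False) == [None, None, None]
    assert fallback_out_shardings(None, all_hidden=False, all_collective=True) is None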
4 changes: 4 additions & 0 deletions jax/_src/mesh.py
@@ -475,6 +475,10 @@ def _are_all_axes_collective(self) -> bool:
   def _are_all_axes_hidden(self) -> bool:
     return all(t == AxisTypes.Hidden for t in self.axis_types.keys())

+  @functools.cached_property
+  def _are_all_axes_visible(self) -> bool:
+    return all(t == AxisTypes.Visible for t in self.axis_types.keys())
+
   @functools.cached_property
   def _any_axis_collective(self) -> bool:
     return any(t == AxisTypes.Collective for t in self.axis_types.keys())
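The new predicate rounds out the existing `_are_all_axes_hidden` / `_are_all_axes_collective` checks and is what lets `_maybe_modify_sharding` in core.py return early when every axis is Visible. A trivial stand-in showing the all/any pattern over the axis-type mapping (hypothetical values, not the real AxisTypes enum):

    # Illustrative only: axis_types here is a plain dict keyed by axis type.
    axis_types = {"visible": ("x", "y")}
    all_visible = all(t == "visible" for t in axis_types.keys())
    any_collective = any(t == "collective" for t in axis_types.keys())
    assert all_visible and not any_collective  # early-return path: nothing to rewrite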
3 changes: 2 additions & 1 deletion jax/_src/pallas/mosaic/lowering.py
@@ -921,7 +921,8 @@ def write_env(var: jax_core.Var, val):
         name_stack=ctx.name_stack + eqn.source_info.name_stack
     )
     loc = mlir._source_info_to_location(ctx, eqn.primitive, source_info)
-    with source_info_util.user_context(eqn.source_info.traceback), loc:
+    with (source_info_util.user_context(eqn.source_info.traceback), loc,
+          eqn.ctx.manager):
       if eqn.primitive in lowering_rules:
         if eqn.primitive not in skip_mlir_conversions:
           invals = [_ensure_mlir_value(x, v.aval)
13 changes: 9 additions & 4 deletions jax/_src/pallas/pallas_call.py
@@ -1117,10 +1117,15 @@ def index_rewrite_kernel(*indexer_args):
     # assert ragged_axis_length is not None
     args = (ragged_axis_length, *args)
   assert all(isinstance(aval, jax_core.ShapedArray) for aval in out_avals)
-  batched_out_avals = tuple(
-      aval.update(shape=tuple_insert(aval.shape, 0, axis_size))
-      for aval in out_avals
-  )
+
+  batched_out_avals = []
+  for aval in out_avals:
+    sharding = (aval.sharding.with_spec(tuple_insert(aval.sharding.spec, 0, None))
+                if config.sharding_in_types.value else None)
+    shape = tuple_insert(aval.shape, 0, axis_size)
+    batched_out_avals.append(aval.update(shape=shape, sharding=sharding))
+  batched_out_avals = tuple(batched_out_avals)  # type: ignore
+
   out = pallas_call_p.bind(
       *dynamic_grid_args,
       *args,
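The rewritten loop keeps each output aval's sharding consistent with its new rank: the vmapped dimension goes in front of the shape, and (when sharding-in-types is enabled) a matching None goes in front of the spec, since the batch dimension is unsharded. Sketch with plain tuples:

    # Illustrative only: inserting the batch dimension into shape and spec together.
    def tuple_insert(t, idx, val):
      return t[:idx] + (val,) + t[idx:]

    axis_size = 8
    shape, spec = (128, 256), ("x", None)
    assert tuple_insert(shape, 0, axis_size) == (8, 128, 256)
    assert tuple_insert(spec, 0, None) == (None, "x", None)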
4 changes: 2 additions & 2 deletions jax/_src/pjit.py
@@ -2840,7 +2840,7 @@ def hidden_axes(fun, *, axes: str | tuple[str, ...] | None = None,
   def decorator(*args, **kwargs):
     new_mesh = _get_new_mesh(axes, mesh_lib.AxisTypes.Hidden)
     with mesh_lib.set_abstract_mesh(new_mesh):
-      in_specs = tree_map(lambda a: core.modify_spec_for_hidden(
+      in_specs = tree_map(lambda a: core.modify_spec_for_hidden_collective(
           a.aval.sharding.spec, new_mesh), args)
       args = mesh_cast(args, in_specs)
       out = fun(*args, **kwargs)
@@ -2861,7 +2861,7 @@ def decorator(*args, **kwargs):
     with mesh_lib.set_abstract_mesh(new_mesh):
       args = mesh_cast(args, in_shardings)
       out = fun(*args, **kwargs)
-      out_specs = tree_map(lambda o: core.modify_spec_for_hidden(
+      out_specs = tree_map(lambda o: core.modify_spec_for_hidden_collective(
           o.aval.sharding.spec, mesh_lib.get_abstract_mesh()), out)
       return mesh_cast(out, out_specs)
   return decorator
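Both call sites are pure renames; the surrounding flow is unchanged: rewrite the argument specs so hidden/collective axis names disappear, mesh_cast into the new abstract mesh, run the wrapped function, and (in the counterpart) rewrite and cast the outputs on the way back out. A compact schematic of that flow with plain-Python stand-ins (hypothetical helpers, not the real API):

    # Schematic only: stand-ins for modify_spec_for_hidden_collective and mesh_cast.
    def rewrite_spec(spec, hidden_names):
      return tuple(None if s in hidden_names else s for s in spec)

    def mesh_cast(value, spec):
      # The real mesh_cast re-types the array under the current abstract mesh;
      # here we just record the spec it would be cast to.
      return {"value": value, "spec": spec}

    def hidden_axes_like(fun, hidden_names):
      def decorator(arg_value, arg_spec):
        in_spec = rewrite_spec(arg_spec, hidden_names)  # drop hidden axis names
        arg = mesh_cast(arg_value, in_spec)             # cast into the hidden-axes mesh
        return fun(arg)
      return decorator

    f = hidden_axes_like(lambda a: a, hidden_names={"x"})
    assert f(42, ("x", None, "y")) == {"value": 42, "spec": (None, None, "y")}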
