Skip to content

Commit

Permalink
[pallas] DMA start discharge.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 713280422
  • Loading branch information
cperivol authored and Google-ML-Automation committed Jan 8, 2025
1 parent 5511949 commit 8701c81
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 28 deletions.
77 changes: 57 additions & 20 deletions jax/_src/pallas/mosaic/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -550,8 +550,8 @@ def _dma_start_pp_eqn(eqn: jax_core.JaxprEqn,

jax_core.pp_eqn_rules[dma_start_p] = _dma_start_pp_eqn

def dma_start_discharge_rule(in_avals, out_avals,
*args, tree, device_id_type):
def dma_start_partial_discharge_rule(should_discharge, in_avals, out_avals,
*args, tree, device_id_type):
(
src_ref,
src_transforms,
Expand All @@ -575,19 +575,42 @@ def dma_start_discharge_rule(in_avals, out_avals,
_,
) = tree_util.tree_unflatten(tree, in_avals)
del out_avals

(
_,
_,
dst_discharge,
_,
dst_sem_discharge,
_,
*maybe_src_sem_discharge,
) = tree_util.tree_unflatten(tree, should_discharge)
is_remote = device_id is not None
src_sem_discharge = None
if is_remote:
src_sem_discharge = maybe_src_sem_discharge[0]

if not is_remote:
# Local async copies only use one semaphore.
assert src_sem is None
assert src_sem_transforms is None

# If there is nothing to do just return.
if not (dst_discharge or dst_sem_discharge or src_sem_discharge):
return (None,) * len(should_discharge), []

num_src_sem_transforms = len(tree_util.tree_leaves(src_sem_transforms_avals))
num_dst_sem_transforms = len(tree_util.tree_leaves(dst_sem_transforms_avals))
num_src_transform_vals = len(tree_util.tree_leaves(src_transforms_avals))
num_dst_transform_vals = len(tree_util.tree_leaves(dst_transforms_avals))

updates = state_discharge.transform_array(src_ref, src_transforms)
local_src = updates
# We only transform if we have to. This is because transforms may
# try to slice references that are not to be discharged, which is
# invalid.
local_src = updates = None
if is_remote or dst_sem_discharge or dst_discharge:
updates = state_discharge.transform_array(src_ref[...], src_transforms)
local_src = updates

if is_remote:
# Note that this code only works in SPMD mode. If not all devices execute
Expand Down Expand Up @@ -641,21 +664,28 @@ def dma_start_discharge_rule(in_avals, out_avals,
global_dst_transforms,
)

_, new_dst = state_discharge.transform_swap_array(
dst_ref, dst_transforms, updates
)
if dst_discharge:
_, new_dst = state_discharge.transform_swap_array(
dst_ref, dst_transforms, updates
)
else:
new_dst = None

# Update semaphore values.
# TODO(justinfu): Potentially handle asymmetric copy sizes.
recv_size = jnp.minimum(updates.size, pl_core.SEMAPHORE_MAX_VALUE)
recv_size = jnp.array(recv_size, dtype=pl_core.SEMAPHORE_INTERPRET_DTYPE)
dst_sem_value = _transform_semaphore(
dst_sem, dst_sem_transforms, dst_sem_aval
)
_, new_dst_sem = state_discharge.transform_swap_array(
dst_sem, dst_sem_transforms, dst_sem_value + recv_size
)
if is_remote:
if dst_sem_discharge:
recv_size = jnp.minimum(updates.size, pl_core.SEMAPHORE_MAX_VALUE)
recv_size = jnp.array(recv_size, dtype=pl_core.SEMAPHORE_INTERPRET_DTYPE)
dst_sem_value = _transform_semaphore(
dst_sem, dst_sem_transforms, dst_sem_aval
)
_, new_dst_sem = state_discharge.transform_swap_array(
dst_sem, dst_sem_transforms, dst_sem_value + recv_size
)
else:
new_dst_sem = None

if src_sem_discharge:
send_size = jnp.minimum(local_src.size, pl_core.SEMAPHORE_MAX_VALUE)
send_size = jnp.array(send_size, dtype=pl_core.SEMAPHORE_INTERPRET_DTYPE)
src_sem_value = _transform_semaphore(
Expand All @@ -681,7 +711,7 @@ def dma_start_discharge_rule(in_avals, out_avals,
len(in_avals)), f"{len(new_vals), new_vals} != {len(in_avals)}"
return new_vals, []

state_discharge.register_discharge_rule(dma_start_p)(dma_start_discharge_rule)
state_discharge.register_partial_discharge_rule(dma_start_p)(dma_start_partial_discharge_rule)


dma_wait_p = jax_core.Primitive('dma_wait')
Expand Down Expand Up @@ -719,8 +749,9 @@ def _dma_wait_pp_eqn(eqn: jax_core.JaxprEqn,

jax_core.pp_eqn_rules[dma_wait_p] = _dma_wait_pp_eqn

def dma_wait_discharge_rule(in_avals, out_avals,
*args, tree, device_id_type):
def dma_wait_partial_discharge_rule(should_discharge,
in_avals, out_avals,
*args, tree, device_id_type):
# TODO(b/370563115): perform ref update in dma_wait discharge rule instead of dma_start
del out_avals, device_id_type
_, _, dst_ref, dst_ref_transforms, dst_sem, dst_sem_transforms, _, _, _ = (
Expand All @@ -735,6 +766,12 @@ def dma_wait_discharge_rule(in_avals, out_avals,
src_sem_transforms_avals,
device_id_aval,
) = tree_util.tree_unflatten(tree, in_avals)

# The only operand we can discharge here is the dst semaphore (index 4 of the unflattened tree).
should_discharge_unflattened = tree_util.tree_unflatten(tree, should_discharge)
if not should_discharge_unflattened[4]:
return (None,) * len(in_avals), []

num_sem_transforms = len(tree_util.tree_leaves(dst_sem_transforms_avals))
num_transforms = len(tree_util.tree_leaves(dst_ref_transforms_avals))
updates = state_discharge.transform_array(dst_ref, dst_ref_transforms)
Expand All @@ -754,7 +791,7 @@ def dma_wait_discharge_rule(in_avals, out_avals,
new_vals += (None,) * len(tree_util.tree_leaves(src_sem_transforms_avals))
new_vals += (None,) * len(tree_util.tree_leaves(device_id_aval)) # device_id
return new_vals, []
state_discharge.register_discharge_rule(dma_wait_p)(dma_wait_discharge_rule)
state_discharge.register_partial_discharge_rule(dma_wait_p)(dma_wait_partial_discharge_rule)

def _get_ref_and_transforms(ref):
if isinstance(ref, state.TransformedRef):
Expand Down
28 changes: 20 additions & 8 deletions jax/_src/pallas/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -928,17 +928,29 @@ def _run_scoped_discharge_rule(

@functools.partial(mlir.register_lowering, run_scoped_p)
def _run_scoped_lowering_rule(ctx, *args, jaxpr):
# This lowering rule gets triggered when run_scoped is not discharged.
# In this case there are no stateful effects to handle.
should_discharge = [
isinstance(aval, state.AbstractRef) for aval in ctx.avals_in
]

jaxpr_noconst = pe.convert_constvars_jaxpr(jaxpr)
num_return_values = len(jaxpr_noconst.outvars)
should_discharge = should_discharge + [
isinstance(var.aval, state.AbstractRef) for var in jaxpr.invars
]
discharged_body, new_consts = state_discharge.discharge_state(
jaxpr_noconst, [], should_discharge=should_discharge)
if new_consts:
raise NotImplementedError(
"Cannot handle new consts created by state discharge.")

def _lower_fun(*lower_fun_args):
updates, out = _run_scoped_discharge_rule(
should_discharge,
[], [], *lower_fun_args,
jaxpr=jaxpr)
assert len(updates) == 0, 'Cannot lower run_scoped with effects.'
return out
# Create inputs filled with uninitialized values to the body.
num_consts = len(lower_fun_args)
body_avals = [v.aval for v in discharged_body.invars[num_consts:]]
init_vals = [uninitialized_value(
aval.shape, aval.dtype) for aval in body_avals]
init_vals_with_consts = lower_fun_args + tuple(init_vals)
out = jax_core.eval_jaxpr(discharged_body, [], *init_vals_with_consts)
return out[:num_return_values]

return mlir.lower_fun(_lower_fun, multiple_results=True)(ctx, *args)

0 comments on commit 8701c81

Please sign in to comment.