Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[pallas] Support DMA start partial discharge and run_scoped() does its own partial discharge. #25775

Merged
merged 1 commit into from
Feb 1, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 64 additions & 26 deletions jax/_src/pallas/mosaic/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -550,8 +550,8 @@ def _dma_start_pp_eqn(eqn: jax_core.JaxprEqn,

jax_core.pp_eqn_rules[dma_start_p] = _dma_start_pp_eqn

def dma_start_discharge_rule(in_avals, out_avals,
*args, tree, device_id_type):
def dma_start_partial_discharge_rule(should_discharge, in_avals, out_avals,
*args, tree, device_id_type):
(
src_ref,
src_transforms,
Expand All @@ -575,7 +575,22 @@ def dma_start_discharge_rule(in_avals, out_avals,
_,
) = tree_util.tree_unflatten(tree, in_avals)
del out_avals

(
_,
_,
dst_discharge,
_,
dst_sem_discharge,
_,
*maybe_src_sem_discharge,
) = tree_util.tree_unflatten(tree, should_discharge)
is_remote = device_id is not None
src_sem_discharge = None

if is_remote:
src_sem_discharge = maybe_src_sem_discharge[0]

if not is_remote:
# Local async copies only use one semaphore.
assert src_sem is None
Expand All @@ -586,7 +601,7 @@ def dma_start_discharge_rule(in_avals, out_avals,
num_src_transform_vals = len(tree_util.tree_leaves(src_transforms_avals))
num_dst_transform_vals = len(tree_util.tree_leaves(dst_transforms_avals))

updates = state_discharge.transform_array(src_ref, src_transforms)
updates = state_discharge.transform_array(src_ref[...], src_transforms)
local_src = updates

if is_remote:
Expand Down Expand Up @@ -641,47 +656,61 @@ def dma_start_discharge_rule(in_avals, out_avals,
global_dst_transforms,
)

_, new_dst = state_discharge.transform_swap_array(
dst_ref, dst_transforms, updates
)
def do_discharge_dst(dst_ref=dst_ref):
_, ret = state_discharge.transform_swap_array(
dst_ref, dst_transforms, updates
)
return ret

# Update semaphore values.
# TODO(justinfu): Potentially handle asymmetric copy sizes.
recv_size = jnp.minimum(updates.size, pl_core.SEMAPHORE_MAX_VALUE)
recv_size = jnp.array(recv_size, dtype=pl_core.SEMAPHORE_INTERPRET_DTYPE)
dst_sem_value = _transform_semaphore(
dst_sem, dst_sem_transforms, dst_sem_aval
)
_, new_dst_sem = state_discharge.transform_swap_array(
dst_sem, dst_sem_transforms, dst_sem_value + recv_size
)
if is_remote:
def do_discharge_dst_sem(dst_sem=dst_sem):
recv_size = jnp.minimum(updates.size, pl_core.SEMAPHORE_MAX_VALUE)
recv_size = jnp.array(recv_size, dtype=pl_core.SEMAPHORE_INTERPRET_DTYPE)
dst_sem_value = _transform_semaphore(
dst_sem, dst_sem_transforms, dst_sem_aval
)
_, ret = state_discharge.transform_swap_array(
dst_sem, dst_sem_transforms, dst_sem_value[...] + recv_size
)
return ret

def do_discharge_src_sem(src_sem=src_sem):
send_size = jnp.minimum(local_src.size, pl_core.SEMAPHORE_MAX_VALUE)
send_size = jnp.array(send_size, dtype=pl_core.SEMAPHORE_INTERPRET_DTYPE)
src_sem_value = _transform_semaphore(
src_sem, src_sem_transforms, src_sem_aval
)
_, new_src_sem = state_discharge.transform_swap_array(
src_sem, src_sem_transforms, src_sem_value + send_size
_, ret = state_discharge.transform_swap_array(
src_sem, src_sem_transforms, src_sem_value[...] + send_size
)
else:
new_src_sem = None
return ret

new_vals = (None,) # src_val
new_vals += (None,) * num_src_transform_vals
new_vals += (new_dst,) # dst_val
new_vals += (do_discharge_dst() if dst_discharge else None,) # dst_val
new_vals += (None,) * num_dst_transform_vals
new_vals += (new_dst_sem,) # dst_sem
new_vals += (do_discharge_dst_sem() if dst_sem_discharge else None,) # dst_sem
new_vals += (None,) * num_dst_sem_transforms
if is_remote:
new_vals += (new_src_sem,) # src_sem
new_vals += (do_discharge_src_sem() if src_sem_discharge else None,) # src_sem
new_vals += (None,) * num_src_sem_transforms
new_vals += (None,) # device_id
assert (len(new_vals) ==
len(in_avals)), f"{len(new_vals), new_vals} != {len(in_avals)}"

# If we didn't discharge everything we could we should keep writes
# to the references that are left over.
if not dst_discharge:
sp.ref_set(dst_ref, None, do_discharge_dst(dst_ref=dst_ref[...]))
if not dst_sem_discharge:
sp.ref_set(dst_sem, None, do_discharge_dst_sem(dst_sem=dst_sem[...]))
if is_remote and not src_sem_discharge:
sp.ref_set(src_sem, None, do_discharge_src_sem(src_sem=src_sem[...]))

return new_vals, []

state_discharge.register_discharge_rule(dma_start_p)(dma_start_discharge_rule)
state_discharge.register_partial_discharge_rule(dma_start_p)(dma_start_partial_discharge_rule)


dma_wait_p = jax_core.Primitive('dma_wait')
Expand Down Expand Up @@ -719,8 +748,9 @@ def _dma_wait_pp_eqn(eqn: jax_core.JaxprEqn,

jax_core.pp_eqn_rules[dma_wait_p] = _dma_wait_pp_eqn

def dma_wait_discharge_rule(in_avals, out_avals,
*args, tree, device_id_type):
def dma_wait_partial_discharge_rule(should_discharge,
in_avals, out_avals,
*args, tree, device_id_type):
# TODO(b/370563115): perform ref update in dma_wait discharge rule instead of dma_start
del out_avals, device_id_type
_, _, dst_ref, dst_ref_transforms, dst_sem, dst_sem_transforms, _, _, _ = (
Expand All @@ -735,6 +765,14 @@ def dma_wait_discharge_rule(in_avals, out_avals,
src_sem_transforms_avals,
device_id_aval,
) = tree_util.tree_unflatten(tree, in_avals)

# The only one we can discharge is the dst semaphore. The provided
# buffers are only specified for their types and not their value so
# it's completely irrelevant for us here if they are discharged.
should_discharge_unflattened = tree_util.tree_unflatten(tree, should_discharge)
if not should_discharge_unflattened[4]:
return (None,) * len(in_avals), []

num_sem_transforms = len(tree_util.tree_leaves(dst_sem_transforms_avals))
num_transforms = len(tree_util.tree_leaves(dst_ref_transforms_avals))
updates = state_discharge.transform_array(dst_ref, dst_ref_transforms)
Expand All @@ -754,7 +792,7 @@ def dma_wait_discharge_rule(in_avals, out_avals,
new_vals += (None,) * len(tree_util.tree_leaves(src_sem_transforms_avals))
new_vals += (None,) * len(tree_util.tree_leaves(device_id_aval)) # device_id
return new_vals, []
state_discharge.register_discharge_rule(dma_wait_p)(dma_wait_discharge_rule)
state_discharge.register_partial_discharge_rule(dma_wait_p)(dma_wait_partial_discharge_rule)

def _get_ref_and_transforms(ref):
if isinstance(ref, state.TransformedRef):
Expand Down
25 changes: 14 additions & 11 deletions jax/_src/pallas/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -931,17 +931,20 @@ def _run_scoped_discharge_rule(

@functools.partial(mlir.register_lowering, run_scoped_p)
def _run_scoped_lowering_rule(ctx, *args, jaxpr):
# This lowering rule gets triggered when run_scoped is not discharged.
# In this case there are no stateful effects to handle.
should_discharge = [
isinstance(aval, state.AbstractRef) for aval in ctx.avals_in
]
jaxpr_noconst = pe.convert_constvars_jaxpr(jaxpr)
num_return_values = len(jaxpr_noconst.outvars)
discharged_body, new_consts = state_discharge.discharge_state(
jaxpr_noconst, [], should_discharge=True)
if new_consts: raise NotImplementedError(
"Cannot handle new consts created by state discharge.")

def _lower_fun(*lower_fun_args):
updates, out = _run_scoped_discharge_rule(
should_discharge,
[], [], *lower_fun_args,
jaxpr=jaxpr)
assert len(updates) == 0, 'Cannot lower run_scoped with effects.'
return out
# Create inputs filled with uninitialized values to the body.
num_consts = len(lower_fun_args)
body_avals = [v.aval for v in discharged_body.invars[num_consts:]]
init_vals = [uninitialized_value(
aval.shape, aval.dtype) for aval in body_avals]
out = jax_core.eval_jaxpr(discharged_body, [], *lower_fun_args, *init_vals)
return out[:num_return_values]

return mlir.lower_fun(_lower_fun, multiple_results=True)(ctx, *args)
Loading