Support dynamic masks in splash attention #25213
Conversation
data_next=data_next,
mask_next=mask_next,
block_mask=block_mask,
partial_mask_blocks=partial_mask_blocks,
Does partial_mask_blocks get stored in scalar memory? What happens if we exceed the size of scalar memory? More generally, how would this handle, e.g., a 100k+ length context where partial_mask_blocks is quite large? (Since we can't do deduplication for dynamic arrays.)
Some possible things that might help:
- Support int8 for partial_mask_blocks since its entries are all 0/1, or maybe even packed bool values.
- Have some sort of sharding support for partial_mask_blocks.
- Probably the most "comprehensive" solution would be to support ComputableMask (or perhaps another sibling class) having jax.tree_util.Partial() as the callable, with the ability to specify sharding information for the arrays stored in the partial object. This would even allow users to implement support for the first bullet point themselves, avoiding the complexity of supporting it in the kernel directly. (A rough sketch of this idea follows below.)
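A rough, purely illustrative sketch of the Partial-based idea: the band_mask_fn callable and its captured arrays are made up for this example; only the generic jax.tree_util / jax.sharding machinery is real.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Hypothetical mask callable: per-row band bounds are captured as arrays.
def band_mask_fn(band_starts, band_ends, q_ids, kv_ids):
  return (kv_ids >= band_starts[q_ids]) & (kv_ids < band_ends[q_ids])

q_seq_len = 8192
band_starts = jnp.zeros((q_seq_len,), dtype=jnp.int32)
band_ends = jnp.minimum(jnp.arange(q_seq_len, dtype=jnp.int32) + 1, 1024)

# jax.tree_util.Partial is itself a pytree, so the captured arrays are
# ordinary leaves and can carry sharding like any other input.
mask_callable = jax.tree_util.Partial(band_mask_fn, band_starts, band_ends)

# Attach sharding to the captured arrays (assumes q_seq_len is divisible
# by the number of devices).
mesh = Mesh(np.array(jax.devices()), ("q",))
mask_callable = jax.device_put(mask_callable, NamedSharding(mesh, P("q")))

# The kernel (or its preprocessing) could then evaluate one tile at a time:
q_ids = jnp.arange(512)[:, None]
kv_ids = jnp.arange(512)[None, :]
tile = mask_callable(q_ids, kv_ids)  # bool[512, 512]
```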
partial_mask_blocks is in HBM; scalar memory is tiny on TPUs. Using mask_next, the right mask block is prefetched into VMEM for each kernel invocation. You still need to fit partial_mask_blocks in HBM, so sharding and using smaller data types help here.
We're running into a known Mosaic edge case by using int8/bools, but people are working on it. One workaround for now is to store the data in a smaller dtype and upcast to int32 later.
Thanks for the comments, let me think some more about sharding and get back to you.
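A minimal sketch of that workaround, with made-up shapes and function names (apply_mask_block is not the kernel's actual code):

```python
import jax.numpy as jnp

# Illustrative shape: (num_mask_blocks, q_block_size, kv_block_size).
partial_mask_blocks = jnp.ones((16, 512, 512), dtype=jnp.int32)

# Store the materialized blocks in int8 in HBM; entries are only 0/1.
partial_mask_blocks_i8 = partial_mask_blocks.astype(jnp.int8)

# Upcast to int32 only after the block selected by mask_next has been
# prefetched into VMEM, sidestepping the int8/bool Mosaic edge case.
def apply_mask_block(scores, mask_block_i8):
  mask_block = mask_block_i8.astype(jnp.int32)
  return jnp.where(mask_block == 1, scores, jnp.finfo(scores.dtype).min)

scores = jnp.zeros((512, 512), dtype=jnp.float32)
masked = apply_mask_block(scores, partial_mask_blocks_i8[0])
```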
I just wanted to give a gentle ping about this.
Thanks for your patience. I'll add sharding support in a follow-up PR to unblock you for now. Specifically, sharding partial_mask_blocks and the MaskInfo. Stay tuned!
Thanks a lot!
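Not part of this PR, but as a rough illustration of that follow-up direction, sharding the materialized blocks could look something like the sketch below (the mesh, axis name, and shapes are assumptions):

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Illustrative shape: (q_blocks, kv_blocks, q_block_size, kv_block_size).
partial_mask_blocks = jnp.ones((16, 16, 512, 512), dtype=jnp.int8)

# Shard the leading q-block dimension so each device only holds its slice of
# the mask in HBM (assumes the device count divides the q-block count).
mesh = Mesh(np.array(jax.devices()), ("q",))
partial_mask_blocks = jax.device_put(
    partial_mask_blocks, NamedSharding(mesh, P("q"))
)
```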
Branch updated from ada8002 to 02338e2.
Very nice!
mask_next = jnp.where(
    jnp.logical_or(is_empty_mask, is_full_mask),
    0,
Leave TODO/comment explaining choice of 0
Done.
.swapaxes(-1, -2)
.reshape(*block_mask_shape, kv_block_size, q_block_size)
.swapaxes(-1, -2)
.astype(np.int32)
It looks like this needs to be updated to bool for jax 0.4.39 compatibility? I'm not sure if any other changes might be needed too.
Also, would the existing tests catch this dtype issue if they were rerun today?
Rebased. The tests did indeed catch this.
jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_mask_info.py
if downcast_smem_data:
  block_mask = block_mask.astype(np.int8)  # values are in the range [0, 1, 2]
  data_next = _downcast(data_next, kv_seq_len if is_dkv else q_seq_len)
Should this be kv_blocks_count and q_blocks_count, respectively?
Good catch, done.
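For context, here's an approximate sketch of what a bound-based downcast does (downcast_to_fit is an illustrative stand-in, not the actual _downcast); since data_next holds block indices, bounding by the block count rather than the sequence length lets it pick a narrower dtype:

```python
import numpy as np

def downcast_to_fit(array: np.ndarray, upper_bound: int) -> np.ndarray:
  """Casts to the narrowest signed integer dtype that can hold upper_bound."""
  if upper_bound <= np.iinfo(np.int8).max:
    return array.astype(np.int8)
  if upper_bound <= np.iinfo(np.int16).max:
    return array.astype(np.int16)
  return array.astype(np.int32)

# data_next stores *block* indices, so the right bound is the number of kv
# (or q) blocks, not the sequence length. E.g. a 128k context with 512-wide
# blocks has only 256 blocks, which fits in int16; the raw sequence length
# would force int32.
data_next = np.arange(256, dtype=np.int32)
print(downcast_to_fit(data_next, 256).dtype)  # int16
```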
data_next=data_next,
mask_next=mask_next,
block_mask=block_mask,
partial_mask_blocks=partial_mask_blocks,
I just wanted to give a gentle ping about this.
Adding support for dynamic masks in the splash attention kernel.
Currently, splash attention expects a static mask. It's preprocessed, and only the interesting (not fully masked) parts of the mask are passed to the kernel. This change allows users to pass in a jax.Array instead. Since we can't know the number of partial mask blocks at trace time, the entire mask is materialized in partial_mask_blocks.
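To make the trade-off concrete, here's a rough sketch (shapes and block sizes are illustrative) of what materializing a dense jax.Array mask into per-block form amounts to:

```python
import jax.numpy as jnp

q_seq_len, kv_seq_len = 8192, 8192
q_block_size, kv_block_size = 512, 512

# A dynamic mask is just a boolean jax.Array; its contents are unknown at
# trace time, so no blocks can be dropped or deduplicated up front.
mask = jnp.tril(jnp.ones((q_seq_len, kv_seq_len), dtype=jnp.bool_))

# Reshape into (q_blocks, kv_blocks, q_block_size, kv_block_size). Every
# block is kept, which is why the whole mask lands in partial_mask_blocks
# and the HBM footprint grows quadratically with context length.
partial_mask_blocks = mask.reshape(
    q_seq_len // q_block_size, q_block_size,
    kv_seq_len // kv_block_size, kv_block_size,
).swapaxes(1, 2)

print(partial_mask_blocks.shape)  # (16, 16, 512, 512)
```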