
Support dynamic masks in splash attention #25213

Merged 1 commit into jax-ml:main on Jan 28, 2025

Conversation

@Rifur13 (Collaborator) commented on Dec 2, 2024:

Adding support for dynamic masks in the splash attention kernel.

Currently, splash attention expects a static mask. The mask is preprocessed, and only the interesting (not fully masked) parts are passed to the kernel. This change allows users to pass in a jax.Array instead. Since we can't know the number of partial mask blocks at trace time, the entire mask is materialized in partial_mask_blocks.
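For intuition, here is a minimal sketch (plain NumPy, with illustrative names and shapes rather than the kernel's actual MaskInfo layout) of the block-level preprocessing being described: each (q block, kv block) tile is classified as empty, full, or partial, and with a dynamic mask every tile is materialized because the number of partial tiles is unknown at trace time. The prefetch indices (mask_next, data_next) are omitted here.

import numpy as np

def build_block_mask(dense_mask, q_block_size, kv_block_size):
    # dense_mask: (q_seq_len, kv_seq_len) boolean array, True = "may attend".
    q_len, kv_len = dense_mask.shape
    q_blocks, kv_blocks = q_len // q_block_size, kv_len // kv_block_size
    tiles = (dense_mask
             .reshape(q_blocks, q_block_size, kv_blocks, kv_block_size)
             .swapaxes(1, 2))
    counts = tiles.sum(axis=(-1, -2))
    is_empty = counts == 0
    is_full = counts == q_block_size * kv_block_size
    # block_mask: 0 = fully masked (skip), 1 = partially masked, 2 = fully unmasked.
    block_mask = np.where(is_empty, 0, np.where(is_full, 2, 1)).astype(np.int32)
    # A static mask would keep (and deduplicate) only the partial tiles; a traced
    # jax.Array cannot be inspected at trace time, so every tile is kept.
    partial_mask_blocks = tiles.reshape(-1, q_block_size, kv_block_size).astype(np.int32)
    return block_mask, partial_mask_blocks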

data_next=data_next,
mask_next=mask_next,
block_mask=block_mask,
partial_mask_blocks=partial_mask_blocks,
@apghml commented on Dec 2, 2024:

Does partial_mask_blocks get stored in scalar memory? What happens if we exceed the size of scalar memory? More generally, how would this handle, e.g., a 100k+ length context where partial_mask_blocks is quite large? (Since we can't do deduplication for dynamic arrays.)

Some possible things that might help:

  • Support int8 for partial_mask_blocks since its entries are all 0/1, or maybe even packed bool values (a rough sketch of the savings follows this list).
  • Have some sort of sharding support for partial_mask_blocks.
  • Probably the most "comprehensive" solution would be to support ComputableMask (or perhaps another sibling class) having jax.tree_util.Partial() as the callable with the ability to specify sharding information for the arrays stored in the partial object. This would even allow users to implement support for the first bullet point themselves, avoiding the complexity of supporting it in the kernel directly.
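To put rough numbers on the first suggestion, a hedged back-of-the-envelope comparison; the shapes are illustrative and np.packbits is used only to show the idea of packed bool storage, not the kernel's actual format:

import numpy as np

block = np.random.rand(8, 128) < 0.5            # a small 0/1 mask block

as_int32 = block.astype(np.int32)               # 4 bytes per entry
as_int8  = block.astype(np.int8)                # 1 byte per entry
packed   = np.packbits(block, axis=-1)          # 1 bit per entry

restored = np.unpackbits(packed, axis=-1, count=block.shape[-1]).astype(bool)
assert np.array_equal(block, restored)

# At 100k+ context lengths the difference is dramatic; for a 128k x 128k mask:
entries = (128 * 1024) ** 2
print(entries * 4 / 2**30, "GiB as int32")      # 64 GiB
print(entries * 1 / 2**30, "GiB as int8")       # 16 GiB
print(entries / 8 / 2**30, "GiB packed")        # 2 GiB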

@Rifur13 (Collaborator, Author) replied:

partial_mask_blocks lives in HBM; scalar memory is tiny on TPUs. Using mask_next, the right mask block is prefetched into VMEM for each kernel invocation. You still need to fit partial_mask_blocks in HBM, so sharding and smaller data types help here.
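As a rough illustration of that lookup, a minimal jax.numpy sketch; the real kernel performs this through Pallas prefetching, and the index layout here is illustrative, not the actual MaskInfo layout:

import jax
import jax.numpy as jnp

def select_mask_block(partial_mask_blocks, mask_next, head, q_blk, kv_blk):
    # mask_next records, for each (head, q block, kv block) grid step, which
    # entry of partial_mask_blocks that step needs; only that single block has
    # to be resident in VMEM while the step runs.
    idx = mask_next[head, q_blk, kv_blk]
    return jax.lax.dynamic_index_in_dim(
        partial_mask_blocks, idx, axis=0, keepdims=False)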

We're running into a known edge case of Mosaic here by using int8/bools, but people are working on it. One workaround for now is to use smaller data types and upcast to int32 later.

Thanks for the comments, let me think some more about sharding and get back to you.

@apghml replied:

I just wanted to give a gentle ping about this.

@Rifur13 (Collaborator, Author) replied:

Thanks for your patience. I’ll add sharding support in a follow-up PR to unblock you for now. Specifically, sharding partial_mask_blocks and the MaskInfo. Stay tuned!
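As a rough idea of what sharding partial_mask_blocks could look like, a hedged sketch with the public jax.sharding API; the axis names, shapes, and the choice to shard the leading block axis are assumptions for illustration, not the follow-up PR's design:

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Illustrative layout: (num_blocks, q_block_size, kv_block_size), stored as int8.
partial_mask_blocks = np.zeros((1024, 512, 512), dtype=np.int8)

# Assumes the block count divides evenly across the devices on the mesh axis.
mesh = Mesh(np.array(jax.devices()), axis_names=("x",))
sharded_blocks = jax.device_put(
    partial_mask_blocks, NamedSharding(mesh, P("x", None, None)))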

@apghml commented on Dec 2, 2024:

Thanks a lot!

@sharadmv (Collaborator) left a review:

Very nice!


mask_next = jnp.where(
jnp.logical_or(is_empty_mask, is_full_mask),
0,
@sharadmv (Collaborator):

Leave TODO/comment explaining choice of 0

@Rifur13 (Collaborator, Author) replied:

Done.
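For readers following along, a hedged reconstruction of the reasoning behind the 0 (the comment actually added to the PR may be worded differently, and partial_block_index is a hypothetical name): for empty or fully unmasked blocks the kernel never loads a partial mask block, so mask_next is never read there and any in-range placeholder works.

import jax.numpy as jnp

# Four blocks: empty, full, partial, partial (toy values).
is_empty_mask       = jnp.array([True,  False, False, False])
is_full_mask        = jnp.array([False, True,  False, False])
partial_block_index = jnp.array([0,     0,     0,     1])  # hypothetical running index

mask_next = jnp.where(
    jnp.logical_or(is_empty_mask, is_full_mask),
    0,                    # placeholder: never read for empty/full blocks
    partial_block_index,
)
# mask_next == [0, 0, 0, 1]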

@google-ml-butler bot added the kokoro:force-run and pull ready (Ready for copybara import and testing) labels on Jan 13, 2025.
.swapaxes(-1, -2)
.reshape(*block_mask_shape, kv_block_size, q_block_size)
.swapaxes(-1, -2)
.astype(np.int32)

It looks like this needs to be updated to bool for jax 0.4.39 compatibility? I'm not sure if any other changes might be needed too.
Also, would the existing tests catch this dtype issue if they were rerun today?

@Rifur13 (Collaborator, Author) replied:

Rebased. The tests did indeed catch this.


if downcast_smem_data:
block_mask = block_mask.astype(np.int8) # values are in the range [0, 1, 2]
data_next = _downcast(data_next, kv_seq_len if is_dkv else q_seq_len)

Should this be kv_blocks_count and q_blocks_count respectively?

@Rifur13 (Collaborator, Author) replied:

Good catch, done.
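For context, a hedged sketch of what a downcasting helper in this spirit can look like: given the largest value the index array may hold (the block count rather than the sequence length, per the fix above), pick the narrowest integer dtype that fits. This is an illustration, not the PR's actual _downcast implementation.

import numpy as np

def downcast_to_fit(x, max_value):
    # Pick the smallest signed integer dtype that can hold max_value, keeping
    # SMEM-resident metadata such as data_next as small as possible.
    for dtype in (np.int8, np.int16, np.int32):
        if max_value <= np.iinfo(dtype).max:
            return x.astype(dtype)
    return x.astype(np.int64)

data_next = np.arange(512)                              # illustrative block indices
data_next = downcast_to_fit(data_next, max_value=512)   # fits in int16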

The copybara-service bot merged commit bc130c7 into jax-ml:main on Jan 28, 2025.
19 checks passed
Labels: pull ready (Ready for copybara import and testing)
4 participants