Fix the signature of paged_attention by marking megacore_mode optional (
fenghuizhang authored Jan 23, 2025
1 parent 36dcba3 commit 9a1405d
Showing 2 changed files with 35 additions and 32 deletions.
57 changes: 30 additions & 27 deletions test/test_pallas.py
@@ -850,7 +850,6 @@ def test_paged_attention_wrapper_with_megacore_modes(self):
   @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 4,
                    "This test only works on TPUv4+.")
   def test_paged_attention_wrapper_with_dynamo(self):
-    from torch_xla.experimental.custom_kernel import paged_attention
     from jax.experimental.pallas.ops.tpu.paged_attention.paged_attention_kernel import paged_attention as jax_paged_attention

     max_kv_len = 2048
@@ -877,50 +876,54 @@ def test_paged_attention_wrapper_with_dynamo(self):
     page_indices_xla = page_indices.to("xla")

     def paged_attention_wrapper(q, k, v, seq_lens, page_indices,
-                                pages_per_compute_block):
+                                pages_per_compute_block, attn_logits_soft_cap):
       return torch.ops.xla.paged_attention(
           q,
           k,
           v,
           seq_lens,
           page_indices,
           pages_per_compute_block=pages_per_compute_block,
+          attn_logits_soft_cap=attn_logits_soft_cap,
       )

     compiled_paged_attention = torch.compile(
         paged_attention_wrapper, backend="openxla")

-    output = compiled_paged_attention(
-        q_xla,
-        k_pages_xla,
-        v_pages_xla,
-        seq_lens_xla,
-        page_indices_xla,
-        pages_per_compute_block=block_size // page_size,
-    )
-
     q_jax = jnp.array(q.numpy(), dtype=jnp.float32)
     k_pages_jax = jnp.array(k_pages.numpy(), dtype=jnp.float32)
     v_pages_jax = jnp.array(v_pages.numpy(), dtype=jnp.float32)
     seq_lens_jax = jnp.array(seq_lens.numpy(), dtype=jnp.int32)
     page_indices_jax = jnp.array(page_indices.numpy(), dtype=jnp.int32)
-    expected_output = torch.from_numpy(
-        np.array(
-            jax_paged_attention(
-                q_jax,
-                k_pages_jax,
-                v_pages_jax,
-                seq_lens_jax,
-                page_indices_jax,
-                pages_per_compute_block=block_size // page_size,
-            )))

-    self.assertTrue(
-        torch.allclose(
-            output.cpu()[seq_lens > 0],
-            expected_output.cpu()[seq_lens > 0],
-            atol=1e-5,
-            rtol=1e-5))
+    for attn_logits_soft_cap in (1.0, None):
+      output = compiled_paged_attention(
+          q_xla,
+          k_pages_xla,
+          v_pages_xla,
+          seq_lens_xla,
+          page_indices_xla,
+          pages_per_compute_block=block_size // page_size,
+          attn_logits_soft_cap=attn_logits_soft_cap,
+      )
+      expected_output = torch.from_numpy(
+          np.array(
+              jax_paged_attention(
+                  q_jax,
+                  k_pages_jax,
+                  v_pages_jax,
+                  seq_lens_jax,
+                  page_indices_jax,
+                  pages_per_compute_block=block_size // page_size,
+                  attn_logits_soft_cap=attn_logits_soft_cap,
+              )))
+
+      self.assertTrue(
+          torch.allclose(
+              output.cpu()[seq_lens > 0],
+              expected_output.cpu()[seq_lens > 0],
+              atol=1e-5,
+              rtol=1e-5))

   @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 4,
                    "This test only works on TPUv4+.")
10 changes: 5 additions & 5 deletions torch_xla/experimental/custom_kernel.py
@@ -1081,7 +1081,7 @@ def flash_attention_non_xla(q: torch.Tensor,

 XLA_LIB.define(
     "paged_attention(Tensor q, Tensor k_pages, Tensor v_pages, Tensor lengths, Tensor page_indices,"
-    " int pages_per_compute_block, str megacore_mode=None, float? attn_logits_soft_cap=None) -> Tensor",
+    " int pages_per_compute_block, str? megacore_mode=None, float? attn_logits_soft_cap=None) -> Tensor",
 )


@@ -1092,8 +1092,8 @@ def paged_attention_xla(q: torch.Tensor,
                         lengths: torch.Tensor,
                         page_indices: torch.Tensor,
                         pages_per_compute_block: int,
-                        megacore_mode: str = None,
-                        attn_logits_soft_cap: float = None):
+                        megacore_mode: str | None = None,
+                        attn_logits_soft_cap: float | None = None):
   return paged_attention(q, k_pages, v_pages, lengths, page_indices,
                          pages_per_compute_block, megacore_mode,
                          attn_logits_soft_cap)
@@ -1106,8 +1106,8 @@ def paged_attention_non_xla(q: torch.Tensor,
                             lengths: torch.Tensor,
                             page_indices: torch.Tensor,
                             pages_per_compute_block: int,
-                            megacore_mode: str = None,
-                            attn_logits_soft_cap: float = None):
+                            megacore_mode: str | None = None,
+                            attn_logits_soft_cap: float | None = None):
   return non_xla_attetion(q, k_pages, v_pages, "paged")

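For context on the schema change above: in a torch custom-op schema string, `str?` declares a nullable argument, so `str? megacore_mode=None` lets callers omit `megacore_mode` (or pass `None`), matching the Python-side default in the `@impl` functions. The snippet below is a minimal call-site sketch of what the corrected registration permits; the tensor shapes, sizes, and values are illustrative assumptions and are not taken from this commit, and it presumes a TPU-backed torch_xla install.

# Illustrative sketch only: shapes, sizes, and dtypes are assumptions chosen
# for demonstration, not values from this commit.
import torch
import torch_xla  # provides the "xla" device
import torch_xla.experimental.custom_kernel  # noqa: F401  registers torch.ops.xla.paged_attention

batch, num_heads, head_dim = 4, 8, 128
num_kv_heads, total_pages, page_size = 8, 64, 16
pages_per_seq = 16

q = torch.randn(batch, num_heads, head_dim).to("xla")
k_pages = torch.randn(num_kv_heads, total_pages, page_size, head_dim).to("xla")
v_pages = torch.randn_like(k_pages)
lengths = torch.full((batch,), 128, dtype=torch.int32).to("xla")
page_indices = torch.randint(
    0, total_pages, (batch, pages_per_seq), dtype=torch.int32).to("xla")

# megacore_mode is simply omitted; with the corrected `str?` schema the
# registered op accepts its default of None.
output = torch.ops.xla.paged_attention(
    q,
    k_pages,
    v_pages,
    lengths,
    page_indices,
    pages_per_compute_block=8,
    attn_logits_soft_cap=None,
)

The updated test exercises the same op through dynamo via torch.compile(..., backend="openxla"), for both a finite attn_logits_soft_cap and None.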
