Commit

Fix import for Windows platforms
PiperOrigin-RevId: 718982994
Rifur13 authored and Google-ML-Automation committed Jan 27, 2025
1 parent 727d036 commit 959579e
Showing 1 changed file with 32 additions and 25 deletions.
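
Why the module failed to import on Windows: the test file previously created the alias `BlockSizes = attention.BlockSizes` at module scope, but on Windows the GPU Pallas ops are never imported and `attention` is bound to `None`, so merely importing the test module raised an error. The commit moves the alias inside the import guard and gives it a `None` fallback. A minimal sketch of the resulting pattern, assuming the guard is a platform check (the actual condition is not shown in this diff):

```python
import sys

# Sketch only: the guard condition here is an assumption; the diff below
# shows only the bodies of the if/else branches.
if sys.platform != "win32":
  from jax.experimental.pallas.ops.gpu import attention
  BlockSizes = attention.BlockSizes  # alias created only when attention imports
else:
  attention = None
  BlockSizes = None  # keeps the name defined so the module still imports
```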
57 changes: 32 additions & 25 deletions tests/pallas/gpu_ops_test.py
@@ -33,16 +33,16 @@
   from jax.experimental.pallas.ops.gpu import layer_norm
   from jax.experimental.pallas.ops.gpu import rms_norm
   from jax.experimental.pallas.ops.gpu import softmax
+  BlockSizes = attention.BlockSizes
 else:
   attention = None
   layer_norm = None
   rms_norm = None
   softmax = None
+  BlockSizes = None
 import jax.numpy as jnp
 import numpy as np
 
-BlockSizes = attention.BlockSizes
-
 # TODO(sharadmv): Update signatures of pallas_call to correct inputs/outputs.
 # pylint: disable=no-value-for-parameter
 
@@ -155,9 +155,9 @@ def setUp(self):
       num_heads=(1, 2, 8),
       head_dim=(32, 64, 128),
       block_sizes=(
-          BlockSizes.get_default(),
-          BlockSizes(block_q=64,block_k=64),
-          BlockSizes(block_q=64,block_k=128),
+          (("block_q", 128), ("block_k", 128)),
+          (("block_q", 64), ("block_k", 64)),
+          (("block_q", 64), ("block_k", 128)),
       ),
       causal=(True, False),
       use_fwd=(True, False),
@@ -199,7 +199,7 @@ def impl(q, k, v):
         v, _ = jax.vjp(
             functools.partial(
                 attention.mha,
-                block_sizes=block_sizes,
+                block_sizes=BlockSizes(**dict(block_sizes)),
                 causal=causal,
                 segment_ids=segment_ids,
                 interpret=self.INTERPRET,
@@ -213,7 +213,7 @@ def impl(q, k, v):
     else:
       impl = functools.partial(
           attention.mha,
-          block_sizes=block_sizes,
+          block_sizes=BlockSizes(**dict(block_sizes)),
           causal=causal,
           segment_ids=segment_ids,
           interpret=self.INTERPRET,
@@ -228,23 +228,30 @@ def impl(q, k, v):
       num_heads=(1, 2),
       head_dim=(32, 64, 128,),
       block_sizes=(
-          BlockSizes.get_default(),
-          BlockSizes(
-              block_q=128,
-              block_k=128,
-              block_q_dkv=64,
-              block_kv_dkv=64,
-              block_q_dq=64,
-              block_kv_dq=64,
-          ),
-          BlockSizes(
-              block_q=128,
-              block_k=128,
-              block_q_dkv=64,
-              block_kv_dkv=128,
-              block_q_dq=128,
-              block_kv_dq=64,
-          ),
+          (
+              ("block_q", 128),
+              ("block_k", 128),
+              ("block_q_dkv", 128),
+              ("block_kv_dkv", 128),
+              ("block_q_dq", 128),
+              ("block_kv_dq", 128),
+          ),
+          (
+              ("block_q", 64),
+              ("block_k", 64),
+              ("block_q_dkv", 64),
+              ("block_kv_dkv", 64),
+              ("block_q_dq", 64),
+              ("block_kv_dq", 64),
+          ),
+          (
+              ("block_q", 64),
+              ("block_k", 128),
+              ("block_q_dkv", 64),
+              ("block_kv_dkv", 128),
+              ("block_q_dq", 128),
+              ("block_kv_dq", 64),
+          ),
       ),
       causal=(True, False),
       use_segment_ids=(True, False),
@@ -280,7 +287,7 @@ def test_fused_attention_bwd(
     def f(q, k, v):
       return attention.mha(
           q, k, v,
-          block_sizes=block_sizes,
+          block_sizes=BlockSizes(**dict(block_sizes)),
           causal=causal,
           segment_ids=segment_ids,
           interpret=self.INTERPRET).sum()

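Because the parameterization decorator's arguments are evaluated when the test class is defined, i.e. at import time, they can no longer construct `BlockSizes` objects (the name is `None` on Windows). The parameters are therefore plain tuples of `(field, value)` pairs, and each test rebuilds the dataclass via `BlockSizes(**dict(block_sizes))` only inside its body. A self-contained sketch of that pattern, using a stand-in dataclass and `absl.testing.parameterized` purely for illustration:

```python
import dataclasses

from absl.testing import absltest, parameterized


@dataclasses.dataclass
class BlockSizes:
  """Stand-in for attention.BlockSizes, for illustration only."""
  block_q: int = 128
  block_k: int = 128


class BlockSizesParamTest(parameterized.TestCase):

  @parameterized.product(
      # Decorator arguments run at import time, so keep them as plain data.
      block_sizes=(
          (("block_q", 64), ("block_k", 64)),
          (("block_q", 64), ("block_k", 128)),
      ),
  )
  def test_rebuild(self, block_sizes):
    # The dataclass is constructed only inside the test body, mirroring
    # BlockSizes(**dict(block_sizes)) in the diff above.
    sizes = BlockSizes(**dict(block_sizes))
    self.assertEqual(sizes.block_q, 64)


if __name__ == "__main__":
  absltest.main()
```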