Commit

The code assumes WARP_SIZE to be equal to 32, which is not the case on ROCm (ROCm#406)

Signed-off-by: Gregory Shtrasberg <[email protected]>
gshtras authored Feb 5, 2025
1 parent ed3337d commit f65ecc9
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions csrc/moe/moe_align_sum_kernels.cu
@@ -207,8 +207,8 @@ __global__ void sgl_moe_align_block_size_kernel(
   __shared__ int32_t shared_counts[32][8];
   __shared__ int32_t local_offsets[256];

-  const int warp_id = threadIdx.x / WARP_SIZE;
-  const int lane_id = threadIdx.x % WARP_SIZE;
+  const int warp_id = threadIdx.x / 32;
+  const int lane_id = threadIdx.x % 32;
   const int experts_per_warp = 8;
   const int my_expert_start = warp_id * experts_per_warp;
