Fix assertion in ScanLowering for num_ctas>1 #5680

Open
wants to merge 10 commits into main
7 changes: 6 additions & 1 deletion include/triton/Analysis/Utility.h
@@ -109,8 +109,13 @@ class ScanLoweringHelper {
unsigned getNonAxisNumElementsPerThread();
// Return the number of threads per warp along non-axis dims.
unsigned getNonAxisNumThreadsPerWarp();
// Return the number of warps per CTA along non-axis dims.
unsigned getNonAxisNumWarpsPerCTA();
// Return the number of CTAs per CGA along non-axis dims.
unsigned getNonAxisNumCTAsPerCGA();
// Return the flat number of threads computing independent scan results.
unsigned getNonAxisNumThreadsPerCTA();
unsigned getNonAxisNumThreadsPerCTA(); // per CTA
unsigned getNonAxisNumThreadsPerCGA(); // per CGA
// Return the number of warps per CTA along axis dim.
unsigned getAxisNumWarps();
// Return the number of warps per CTA along axis dim with unique data.
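For intuition, here is a small Python sketch (hypothetical layout values, not taken from this PR) of how these non-axis counts compose: the helper conceptually sets the scan-axis entry to 1 and multiplies the remaining entries.

from math import prod

# Hypothetical layout; the scan axis is excluded from each product.
threads_per_warp = [4, 8]
warps_per_cta = [2, 2]
ctas_per_cga = [2, 1]
axis = 1

def non_axis_product(dims, axis):
    # Product over every dimension except the scan axis.
    return prod(d for i, d in enumerate(dims) if i != axis)

non_axis_threads_per_warp = non_axis_product(threads_per_warp, axis)  # 4
non_axis_warps_per_cta = non_axis_product(warps_per_cta, axis)        # 2
non_axis_ctas_per_cga = non_axis_product(ctas_per_cga, axis)          # 2
non_axis_threads_per_cta = non_axis_threads_per_warp * non_axis_warps_per_cta  # 8
non_axis_threads_per_cga = non_axis_threads_per_cta * non_axis_ctas_per_cga    # 16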
50 changes: 31 additions & 19 deletions lib/Analysis/Utility.cpp
@@ -244,13 +244,25 @@ unsigned ScanLoweringHelper::getNonAxisNumThreadsPerWarp() {
return product<unsigned>(threadsPerWarp);
}

// Return the flat number of threads computing independent scan results.
unsigned ScanLoweringHelper::getNonAxisNumThreadsPerCTA() {
unsigned numParallelThreadsPerWarp = getNonAxisNumThreadsPerWarp();
unsigned ScanLoweringHelper::getNonAxisNumWarpsPerCTA() {
auto warpsPerCTA = getWarpsPerCTA(getEncoding());
warpsPerCTA[getAxis()] = 1;
unsigned numParallelWarpsPerCTA = product<unsigned>(warpsPerCTA);
return numParallelThreadsPerWarp * numParallelWarpsPerCTA;
return product<unsigned>(warpsPerCTA);
}

unsigned ScanLoweringHelper::getNonAxisNumCTAsPerCGA() {
auto CTAsPerCGA = getCTAsPerCGA(getEncoding());
CTAsPerCGA[getAxis()] = 1;
return product<unsigned>(CTAsPerCGA);
}

// Return the flat number of threads computing independent scan results.
unsigned ScanLoweringHelper::getNonAxisNumThreadsPerCTA() {
return getNonAxisNumThreadsPerWarp() * getNonAxisNumWarpsPerCTA();
}
unsigned ScanLoweringHelper::getNonAxisNumThreadsPerCGA() {
return getNonAxisNumThreadsPerWarp() * getNonAxisNumWarpsPerCTA() *
getNonAxisNumCTAsPerCGA();
}

unsigned ScanLoweringHelper::getAxisNumWarps() {
@@ -265,24 +277,25 @@ unsigned ScanLoweringHelper::getAxisNumBlocks() {
auto sizePerThreads = getSizePerThread(getEncoding());
auto threadsPerWarp = getThreadsPerWarp(getEncoding());
auto warpsPerCTA = getWarpsPerCTA(getEncoding());
auto shape = getShapePerCTA(getEncoding(), getShape());
unsigned axis = getAxis();
return ceil<unsigned>(
getShape()[axis],
shape[axis],
(sizePerThreads[axis] * threadsPerWarp[axis] * warpsPerCTA[axis]));
}

unsigned ScanLoweringHelper::getNonAxisNumBlocks() {
auto sizePerThreads = getSizePerThread(getEncoding());
auto threadsPerWarp = getThreadsPerWarp(getEncoding());
auto warpsPerCTA = getWarpsPerCTA(getEncoding());
auto shape = getShapePerCTA(getEncoding(), getShape());
unsigned axis = getAxis();
unsigned numBlocks = 1;
for (unsigned i = 0; i < sizePerThreads.size(); i++) {
if (i == axis)
continue;
numBlocks *=
ceil<unsigned>(getShape()[i], (sizePerThreads[i] * threadsPerWarp[i] *
warpsPerCTA[i]));
numBlocks *= ceil<unsigned>(
shape[i], (sizePerThreads[i] * threadsPerWarp[i] * warpsPerCTA[i]));
}
return numBlocks;
}
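The switch from getShape() to getShapePerCTA(getEncoding(), getShape()) in these block computations makes the counts per CTA rather than per tensor. A rough sketch of the assumed getShapePerCTA semantics (hypothetical shape and split values; the clamping behavior is an assumption):

# Each dimension of the tensor shape is assumed to be divided across
# CTASplitNum CTAs, clamped to the dimension size.
shape = [64, 32]
cta_split_num = [2, 1]
shape_per_cta = [s // min(split, s) for s, split in zip(shape, cta_split_num)]  # [32, 32]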
@@ -296,13 +309,11 @@ bool ScanLoweringHelper::isSupported() {
}

unsigned ScanLoweringHelper::getScratchSizeInElems() {
auto mod = scanOp->getParentOfType<ModuleOp>();
unsigned numWarps = TritonGPUDialect::getNumWarps(mod);
unsigned numNonAxisElementsPerWarp =
getNonAxisNumThreadsPerWarp() * getNonAxisNumElementsPerThread();
unsigned numElements = numWarps * numNonAxisElementsPerWarp *
getAxisNumBlocks() * getNonAxisNumBlocks();
return numElements;
unsigned parallelElementsPerThread = getNonAxisNumElementsPerThread();
unsigned numParallelLane = getNonAxisNumThreadsPerCGA();
unsigned axisNumWarps = getAxisNumWarpsWithUniqueData();
unsigned numBlocks = getNonAxisNumBlocks() * getAxisNumBlocks();
return parallelElementsPerThread * numParallelLane * axisNumWarps * numBlocks;
}
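As a sanity check, a hedged worked example of the new scratch-size formula; the four factors below are hypothetical values standing in for the helper calls named in the comments:

parallel_elements_per_thread = 2  # getNonAxisNumElementsPerThread()
num_parallel_lane = 16            # getNonAxisNumThreadsPerCGA()
axis_num_warps = 4                # getAxisNumWarpsWithUniqueData()
num_blocks = 2 * 1                # getNonAxisNumBlocks() * getAxisNumBlocks()
scratch_elems = (parallel_elements_per_thread * num_parallel_lane
                 * axis_num_warps * num_blocks)  # 256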

unsigned ScanLoweringHelper::getScratchSizeInBytes() {
@@ -611,12 +622,13 @@ unsigned ScanLoweringHelper::getAxisBlockStride() {
auto sizePerThreads = getSizePerThread(getEncoding());
auto threadsPerWarp = getThreadsPerWarp(getEncoding());
auto warpsPerCTA = getWarpsPerCTA(getEncoding());
auto shape = getShapePerCTA(getEncoding(), getShape());
for (unsigned dim : order) {
if (dim == getAxis())
return stride;
stride *= ceil<unsigned int>(getShape()[dim], sizePerThreads[dim] *
threadsPerWarp[dim] *
warpsPerCTA[dim]);
stride *= ceil<unsigned int>(shape[dim], sizePerThreads[dim] *
threadsPerWarp[dim] *
warpsPerCTA[dim]);
}
llvm_unreachable("Axis not found in order");
}
47 changes: 34 additions & 13 deletions lib/Conversion/TritonGPUToLLVM/ScanOpToLLVM.cpp
@@ -98,7 +98,7 @@ static void storeWarpAccumulator(SmallVector<SmallVector<Value>> &srcValues,
auto b = TritonLLVMOpBuilder(loc, rewriter);
unsigned scanElementsPerThreads = helper.getAxisNumElementsPerThread();
unsigned scanDim = helper.getAxisNumThreadsPerWarpWithUniqueData();
unsigned numParallelLane = helper.getNonAxisNumThreadsPerCTA();
unsigned numParallelLane = helper.getNonAxisNumThreadsPerCGA();
unsigned axisNumWarps = helper.getAxisNumWarpsWithUniqueData();
unsigned chunkId = 0;
unsigned elementStride = helper.getAxisElementStride();
Expand Down Expand Up @@ -136,7 +136,7 @@ static void AddPartialReduce(SmallVector<SmallVector<Value>> &srcValues,
Value laneIdAxis, Value parallelLaneId) {
Location loc = helper.getLoc();
auto b = TritonLLVMOpBuilder(loc, rewriter);
unsigned numParallelLane = helper.getNonAxisNumThreadsPerCTA();
unsigned numParallelLane = helper.getNonAxisNumThreadsPerCGA();
unsigned scanElementsPerThreads = helper.getAxisNumElementsPerThread();
unsigned parallelElementsPerThread = helper.getNonAxisNumElementsPerThread();
unsigned elementStride = helper.getAxisElementStride();
Expand Down Expand Up @@ -251,7 +251,7 @@ static void AddPartialReduceOneWarp(SmallVector<SmallVector<Value>> &srcValues,
unsigned elementStride = helper.getAxisElementStride();
unsigned threadStride = helper.getAxisThreadStride();
unsigned axisNumWarps = helper.getAxisNumWarpsWithUniqueData();
unsigned numParallelLane = helper.getNonAxisNumThreadsPerCTA();
unsigned numParallelLane = helper.getNonAxisNumThreadsPerCGA();
unsigned scanDim = helper.getAxisNumThreadsPerWarpWithUniqueData();
Value maskFirstWarp = b.icmp_eq(warpId, b.i32_val(0));
Value maskFirstLane = b.icmp_eq(laneIdAxis, b.i32_val(0));
Expand Down Expand Up @@ -348,10 +348,10 @@ struct ScanOpConversion
SmallVector<Value> getMultiDimWarpId(ConversionPatternRewriter &rewriter,
ScanLoweringHelper &helper,
Value warpId) const;
std::tuple<Value, Value, Value>
std::tuple<Value, Value, Value, Value, Value>
getDelinearizedIds(ConversionPatternRewriter &rewriter,
ScanLoweringHelper &helper, Value laneId,
Value warpId) const;
ScanLoweringHelper &helper, Value laneId, Value warpId,
Value cctaId) const;
LogicalResult emitFastScan(triton::ScanOp op, triton::ScanOpAdaptor adaptor,
ConversionPatternRewriter &rewriter,
const TargetInfoBase &targetInfo) const;
@@ -366,7 +366,6 @@ ScanOpConversion::getMultiDimLaneId(ConversionPatternRewriter &rewriter,
auto srcEncoding = helper.getEncoding();

auto threadsPerWarp = triton::gpu::getThreadsPerWarp(srcEncoding);
auto warpsPerCTA = triton::gpu::getWarpsPerCTA(srcEncoding);
auto order = triton::gpu::getOrder(srcEncoding);
return delinearize(rewriter, loc, laneId, threadsPerWarp, order);
}
@@ -379,34 +378,38 @@ ScanOpConversion::getMultiDimWarpId(ConversionPatternRewriter &rewriter,
unsigned axis = helper.getAxis();
auto srcEncoding = helper.getEncoding();

auto threadsPerWarp = triton::gpu::getThreadsPerWarp(srcEncoding);
auto warpsPerCTA = triton::gpu::getWarpsPerCTA(srcEncoding);
auto warpOrder = triton::gpu::getWarpOrder(srcEncoding);
return delinearize(rewriter, loc, warpId, warpsPerCTA, warpOrder);
}

// Break up the threadId into lane and warp id along the scan dimension and
// compute a flat id for the parallel dimensions.
std::tuple<Value, Value, Value>
std::tuple<Value, Value, Value, Value, Value>
ScanOpConversion::getDelinearizedIds(ConversionPatternRewriter &rewriter,
ScanLoweringHelper &helper, Value laneId,
Value warpId) const {
Value warpId, Value cctaId) const {
auto loc = helper.getLoc();
auto b = TritonLLVMOpBuilder(loc, rewriter);
unsigned axis = helper.getAxis();
auto srcEncoding = helper.getEncoding();

auto threadsPerWarp = triton::gpu::getThreadsPerWarp(srcEncoding);
auto warpsPerCTA = triton::gpu::getWarpsPerCTA(srcEncoding);
auto CTAsPerCGA = triton::gpu::getCTAsPerCGA(srcEncoding);
auto threadOrder = triton::gpu::getThreadOrder(srcEncoding);
auto warpOrder = triton::gpu::getWarpOrder(srcEncoding);
auto CTAOrder = triton::gpu::getCTAOrder(srcEncoding);
SmallVector<Value> multiDimLaneId =
delinearize(rewriter, loc, laneId, threadsPerWarp, threadOrder);
SmallVector<Value> multiDimWarpId =
delinearize(rewriter, loc, warpId, warpsPerCTA, warpOrder);
SmallVector<Value> multiDimCCTAId =
delinearize(rewriter, loc, cctaId, CTAsPerCGA, CTAOrder);

Value laneIdAxis = multiDimLaneId[axis];
Value warpIdAxis = multiDimWarpId[axis];
Value cctaIdAxis = multiDimCCTAId[axis];

multiDimLaneId[axis] = b.i32_val(0);
threadsPerWarp[axis] = 1;
@@ -416,10 +419,16 @@ ScanOpConversion::getDelinearizedIds(ConversionPatternRewriter &rewriter,
warpsPerCTA[axis] = 1;
Value warpIdParallel =
linearize(rewriter, loc, multiDimWarpId, warpsPerCTA, warpOrder);
multiDimCCTAId[axis] = b.i32_val(0);
CTAsPerCGA[axis] = 1;
Value cctaIdParallel =
linearize(rewriter, loc, multiDimCCTAId, CTAsPerCGA, CTAOrder);

Value flatIdParallel = b.add(
laneIdParallel,
b.mul(warpIdParallel, b.i32_val(helper.getNonAxisNumThreadsPerWarp())));
return std::make_tuple(laneIdAxis, warpIdAxis, flatIdParallel);
return std::make_tuple(laneIdAxis, warpIdAxis, flatIdParallel, cctaIdAxis,
cctaIdParallel);
}
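A minimal Python sketch of the delinearize/linearize pattern used above (intended to mirror the LLVM helpers only in spirit; the real helpers operate on IR Values, and the traversal order is an assumption here):

def delinearize(flat_id, dims, order):
    # Split a flat id into one coordinate per dimension, fastest-varying dim first in `order`.
    coords = [0] * len(dims)
    for d in order:
        coords[d] = flat_id % dims[d]
        flat_id //= dims[d]
    return coords

def linearize(coords, dims, order):
    # Inverse of delinearize for the same `order`.
    flat_id, stride = 0, 1
    for d in order:
        flat_id += coords[d] * stride
        stride *= dims[d]
    return flat_id

# E.g. warp_id=5 with warpsPerCTA=[2, 4] and order=[0, 1] -> [1, 2]; zeroing the
# scan-axis coordinate and linearizing again yields the warp's parallel (non-axis) id.
assert delinearize(5, [2, 4], [0, 1]) == [1, 2]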

SmallVector<SmallVector<Value>>
Expand Down Expand Up @@ -479,6 +488,7 @@ ScanOpConversion::emitFastScan(triton::ScanOp op, triton::ScanOpAdaptor adaptor,
Value warpSize = b.i32_val(iWarpSize);
Value warpId = b.udiv(threadId, warpSize);
Value laneId = b.urem(threadId, warpSize);
Value cctaId = targetInfo.getClusterCTAId(rewriter, loc);

// Clamp the lane ID to just threads with unique data within a warp.
LinearLayout layout =
@@ -488,8 +498,19 @@ ScanOpConversion::emitFastScan(triton::ScanOp op, triton::ScanOpAdaptor adaptor,
laneMask = (layout.getInDimSize(kLane) - 1) & ~laneMask;
laneId = b.and_(laneId, b.i32_val(laneMask));

auto [laneIdAxis, warpIdAxis, flatIdParallel] =
getDelinearizedIds(rewriter, helper, laneId, warpId);
auto [laneIdAxis, warpIdAxis, flatIdParallel, cctaIdAxis, cctaIdParallel] =
getDelinearizedIds(rewriter, helper, laneId, warpId, cctaId);
Comment on lines +501 to +502 (Collaborator): I don't think the math should consider cctaId if the scan is not happening across CTAs.


// We assume that cctaIdAxis==0 and gather cctaIdParallel into the
// flatIdParallel. This requires numParallelLane to be computed per
// CGA instead of per CTA in storeWarpAccumulator() and AddPartialReduce().
rewriter.create<triton::AssertOp>(
loc, b.icmp_eq(cctaIdAxis, b.i32_val(0)),
StringRef("(cctaIdAxis==0) && Scan axis accross CTAs is not supported."));
Comment on lines +504 to +509 (Collaborator): we definitely don't want to create asserts as those can have a high cost.

flatIdParallel = b.add(
flatIdParallel,
b.mul(cctaIdParallel, b.i32_val(helper.getNonAxisNumThreadsPerCTA())));
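Continuing the delinearize sketch above, the per-CGA flat parallel id is composed from three strided terms; the numbers below are hypothetical helper values, not taken from a specific layout:

# flatIdParallel = laneIdParallel
#                + warpIdParallel * getNonAxisNumThreadsPerWarp()
#                + cctaIdParallel * getNonAxisNumThreadsPerCTA()
lane_parallel, warp_parallel, ccta_parallel = 3, 1, 1
non_axis_threads_per_warp, non_axis_threads_per_cta = 4, 8
flat_id_parallel = (lane_parallel
                    + warp_parallel * non_axis_threads_per_warp
                    + ccta_parallel * non_axis_threads_per_cta)  # 15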

auto axisNumWarps = helper.getAxisNumWarpsWithUniqueData();
warpIdAxis = b.urem(warpIdAxis, b.i32_val(axisNumWarps));
auto srcValues =
66 changes: 63 additions & 3 deletions python/test/unit/language/test_core.py
@@ -2689,6 +2689,7 @@ def kernel(X, Y, N, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):


scan_layouts = [
# thread_size=4, num_warps=4, num_ctas=1
BlockedLayout([1, 4], [4, THREADS_PER_WARP // 4], [4, 1], [0, 1], [1, 1], [1, 1], [0, 1]),
BlockedLayout([1, 4], [8, THREADS_PER_WARP // 8], [4, 1], [0, 1], [1, 1], [1, 1], [0, 1]),
BlockedLayout([4, 1], [4, THREADS_PER_WARP // 4], [1, 4], [0, 1], [1, 1], [1, 1], [0, 1]),
@@ -2699,7 +2700,55 @@ def kernel(X, Y, N, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
BlockedLayout([4, 1], [4, THREADS_PER_WARP // 4], [1, 4], [1, 0], [1, 1], [1, 1], [0, 1]),
BlockedLayout([2, 2], [4, THREADS_PER_WARP // 4], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]),
BlockedLayout([2, 2], [8, THREADS_PER_WARP // 8], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]),
BlockedLayout([1, 2], [1, THREADS_PER_WARP // 1], [1, 4], [1, 0], [1, 1], [1, 1], [0, 1]),
BlockedLayout([1, 4], [1, THREADS_PER_WARP // 1], [1, 4], [1, 0], [1, 1], [1, 1], [0, 1]),
# thread_size=4, num_warps=1, num_ctas=4: CTASplitNum=[1,1]
BlockedLayout([1, 4], [4, THREADS_PER_WARP // 4], [1, 1], [0, 1], [4, 1], [1, 1], [0, 1]),
BlockedLayout([1, 4], [8, THREADS_PER_WARP // 8], [1, 1], [0, 1], [4, 1], [1, 1], [0, 1]),
BlockedLayout([4, 1], [4, THREADS_PER_WARP // 4], [1, 1], [0, 1], [1, 4], [1, 1], [0, 1]),
BlockedLayout([2, 2], [4, THREADS_PER_WARP // 4], [1, 1], [0, 1], [2, 2], [1, 1], [0, 1]),
BlockedLayout([2, 2], [8, THREADS_PER_WARP // 8], [1, 1], [0, 1], [2, 2], [1, 1], [0, 1]),
BlockedLayout([1, 4], [4, THREADS_PER_WARP // 4], [1, 1], [0, 1], [4, 1], [1, 1], [1, 0]),
BlockedLayout([1, 4], [8, THREADS_PER_WARP // 8], [1, 1], [0, 1], [4, 1], [1, 1], [1, 0]),
BlockedLayout([4, 1], [4, THREADS_PER_WARP // 4], [1, 1], [0, 1], [1, 4], [1, 1], [1, 0]),
BlockedLayout([2, 2], [4, THREADS_PER_WARP // 4], [1, 1], [0, 1], [2, 2], [1, 1], [1, 0]),
BlockedLayout([2, 2], [8, THREADS_PER_WARP // 8], [1, 1], [0, 1], [2, 2], [1, 1], [1, 0]),
BlockedLayout([1, 4], [1, THREADS_PER_WARP // 1], [1, 1], [0, 1], [1, 4], [1, 1], [1, 0]),
# thread_size=4, num_warps=1, num_ctas=4: CTASplitNum=CTAsPerCGA
BlockedLayout([1, 4], [4, THREADS_PER_WARP // 4], [1, 1], [0, 1], [4, 1], [4, 1], [0, 1]),
BlockedLayout([1, 4], [8, THREADS_PER_WARP // 8], [1, 1], [0, 1], [4, 1], [4, 1], [0, 1]),
BlockedLayout([4, 1], [4, THREADS_PER_WARP // 4], [1, 1], [0, 1], [1, 4], [1, 4], [0, 1]),
BlockedLayout([2, 2], [4, THREADS_PER_WARP // 4], [1, 1], [0, 1], [2, 2], [2, 2], [0, 1]),
BlockedLayout([2, 2], [8, THREADS_PER_WARP // 8], [1, 1], [0, 1], [2, 2], [2, 2], [0, 1]),
BlockedLayout([1, 4], [4, THREADS_PER_WARP // 4], [1, 1], [0, 1], [4, 1], [4, 1], [1, 0]),
BlockedLayout([1, 4], [8, THREADS_PER_WARP // 8], [1, 1], [0, 1], [4, 1], [4, 1], [1, 0]),
BlockedLayout([4, 1], [4, THREADS_PER_WARP // 4], [1, 1], [0, 1], [1, 4], [1, 4], [1, 0]),
BlockedLayout([2, 2], [4, THREADS_PER_WARP // 4], [1, 1], [0, 1], [2, 2], [2, 2], [1, 0]),
BlockedLayout([2, 2], [8, THREADS_PER_WARP // 8], [1, 1], [0, 1], [2, 2], [2, 2], [1, 0]),
BlockedLayout([1, 4], [1, THREADS_PER_WARP // 1], [1, 1], [0, 1], [1, 4], [1, 4], [1, 0]),
Comment on lines +2705 to +2727 (Collaborator): do we really need to test all those combinations?

# thread_size=1, num_warps=4, num_ctas=4: CTASplitNum=[1,1]
BlockedLayout([1, 1], [4, THREADS_PER_WARP // 4], [1, 4], [0, 1], [4, 1], [1, 1], [0, 1]),
BlockedLayout([1, 1], [8, THREADS_PER_WARP // 8], [1, 4], [0, 1], [4, 1], [1, 1], [0, 1]),
BlockedLayout([1, 1], [4, THREADS_PER_WARP // 4], [4, 1], [0, 1], [1, 4], [1, 1], [0, 1]),
BlockedLayout([1, 1], [4, THREADS_PER_WARP // 4], [2, 2], [0, 1], [2, 2], [1, 1], [0, 1]),
BlockedLayout([1, 1], [8, THREADS_PER_WARP // 8], [2, 2], [0, 1], [2, 2], [1, 1], [0, 1]),
BlockedLayout([1, 1], [4, THREADS_PER_WARP // 4], [1, 4], [0, 1], [4, 1], [1, 1], [1, 0]),
BlockedLayout([1, 1], [8, THREADS_PER_WARP // 8], [1, 4], [0, 1], [4, 1], [1, 1], [1, 0]),
BlockedLayout([1, 1], [4, THREADS_PER_WARP // 4], [4, 1], [0, 1], [1, 4], [1, 1], [1, 0]),
BlockedLayout([1, 1], [4, THREADS_PER_WARP // 4], [2, 2], [0, 1], [2, 2], [1, 1], [1, 0]),
BlockedLayout([1, 1], [8, THREADS_PER_WARP // 8], [2, 2], [0, 1], [2, 2], [1, 1], [1, 0]),
BlockedLayout([1, 1], [1, THREADS_PER_WARP // 1], [1, 4], [0, 1], [1, 4], [1, 1], [1, 0]),
# thread_size=1, num_warps=4, num_ctas=4: CTASplitNum=CTAsPerCGA
BlockedLayout([1, 1], [4, THREADS_PER_WARP // 4], [1, 4], [0, 1], [4, 1], [4, 1], [0, 1]),
BlockedLayout([1, 1], [8, THREADS_PER_WARP // 8], [1, 4], [0, 1], [4, 1], [4, 1], [0, 1]),
BlockedLayout([1, 1], [4, THREADS_PER_WARP // 4], [4, 1], [0, 1], [1, 4], [1, 4], [0, 1]),
BlockedLayout([1, 1], [4, THREADS_PER_WARP // 4], [2, 2], [0, 1], [2, 2], [2, 2], [0, 1]),
BlockedLayout([1, 1], [8, THREADS_PER_WARP // 8], [2, 2], [0, 1], [2, 2], [2, 2], [0, 1]),
BlockedLayout([1, 1], [4, THREADS_PER_WARP // 4], [1, 4], [0, 1], [4, 1], [4, 1], [1, 0]),
BlockedLayout([1, 1], [8, THREADS_PER_WARP // 8], [1, 4], [0, 1], [4, 1], [4, 1], [1, 0]),
BlockedLayout([1, 1], [4, THREADS_PER_WARP // 4], [4, 1], [0, 1], [1, 4], [1, 4], [1, 0]),
BlockedLayout([1, 1], [4, THREADS_PER_WARP // 4], [2, 2], [0, 1], [2, 2], [2, 2], [1, 0]),
BlockedLayout([1, 1], [8, THREADS_PER_WARP // 8], [2, 2], [0, 1], [2, 2], [2, 2], [1, 0]),
BlockedLayout([1, 1], [1, THREADS_PER_WARP // 1], [1, 4], [0, 1], [1, 4], [1, 4], [1, 0]),
]
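For readability, a hedged reminder of how the BlockedLayout arguments in this list appear to line up with the comments above; the parameter labels are mine, inferred from those comments, not the constructor's own names:

# BlockedLayout(size_per_thread, threads_per_warp, warps_per_cta, order,
#               ctas_per_cga, cta_split_num, cta_order)
example_layout = BlockedLayout([1, 4], [4, THREADS_PER_WARP // 4], [1, 1], [0, 1],
                               [4, 1],  # 4 CTAs along dim 0 -> num_ctas = 4
                               [1, 1],  # CTASplitNum = [1, 1]: data appears replicated across CTAs
                               [0, 1])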


@@ -2709,6 +2758,16 @@ def kernel(X, Y, N, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
@pytest.mark.parametrize("add_overflow_check", [False, True])
def test_scan_layouts(M, N, src_layout, axis, add_overflow_check, device, tmp_path: pathlib.Path):

num_warps = int(np.prod(src_layout.warps_per_cta))
num_ctas = int(np.prod(src_layout.ctas_per_cga))
cluster_dims = tuple((src_layout.ctas_per_cga[i] if i < len(src_layout.ctas_per_cga) else 1) for i in range(3))

if num_ctas > 1 and not is_hopper():
return pytest.skip("num_ctas > 1 is only supported after sm90.")

if cluster_dims[axis] > 1:
return pytest.skip(f"scan axis accross CTAs is not supported (cluster_dims[axis]={cluster_dims[axis]}).")

overflow_check = """
%17 = arith.extsi %arg2 : i32 to i64
%18 = arith.extsi %arg3 : i32 to i64
@@ -2723,7 +2782,7 @@ def test_scan_layouts(M, N, src_layout, axis, add_overflow_check, device, tmp_pa

ir = f"""
#blocked = {src_layout}
module attributes {{"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32, "ttg.threads-per-warp" = {THREADS_PER_WARP} : i32}} {{
module attributes {{"ttg.num-warps" = {num_warps} : i32, "ttg.num-ctas" = {num_ctas} : i32, "ttg.threads-per-warp" = {THREADS_PER_WARP} : i32}} {{
tt.func public @kernel_0d1d(%arg0: !tt.ptr<i32> {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr<i32> {{tt.divisibility = 16 : i32}}) {{
%cst = arith.constant dense<{N}> : tensor<{M}x1xi32, #blocked>
%0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #blocked}}>>
Expand Down Expand Up @@ -2754,7 +2813,8 @@ def test_scan_layouts(M, N, src_layout, axis, add_overflow_check, device, tmp_pa

temp_file = tmp_path / "test_scan_layouts.ttgir"
temp_file.write_text(ir)
kernel = triton.compile(str(temp_file))
kernel = triton.compile(str(temp_file), options=dict(num_warps=num_warps, num_ctas=num_ctas,
cluster_dims=cluster_dims))

rs = RandomState(17)
x = rs.randint(-100, 100, (M, N)).astype('int32')