Skip to content

Commit

Permalink
[Pipeliner] Fix crash in rewriting TMA descriptor updates (#5843)
Browse files Browse the repository at this point in the history
Lots of our code assumes that `scf.if` has a non-empty else region, but
sometimes it can be empty, which typically happens due to one of the
`scf.if` canonicalizers. Just make sure to create `scf.if` with
non-empty regions.

This was split off from #5726
since others were hitting the crash.
  • Loading branch information
Mogball authored Feb 6, 2025
1 parent d56f4fe commit b71400a
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 6 deletions.
12 changes: 6 additions & 6 deletions lib/Dialect/TritonGPU/Transforms/Utility.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -703,14 +703,14 @@ scf::IfOp replaceIfOpWithNewSignature(
// Create a new loop before the existing one, with the extra operands.
auto resultTypes = llvm::to_vector<4>(ifOp.getResults().getTypes());
resultTypes.append(newResultTypes.begin(), newResultTypes.end());
scf::IfOp newIf = rewriter.create<scf::IfOp>(
ifOp.getLoc(), resultTypes, ifOp.getCondition(), /*withElse=*/true);
scf::IfOp newIf = rewriter.create<scf::IfOp>(ifOp.getLoc(), resultTypes,
ifOp.getCondition());
newIf->setAttrs(ifOp->getAttrs());

rewriter.inlineBlockBefore(ifOp.thenBlock(), newIf.thenBlock(),
newIf.thenBlock()->begin());
rewriter.inlineBlockBefore(ifOp.elseBlock(), newIf.elseBlock(),
newIf.elseBlock()->begin());
newIf.getThenRegion().takeBody(ifOp.getThenRegion());
newIf.getElseRegion().takeBody(ifOp.getElseRegion());
scf::IfOp::ensureTerminator(newIf.getThenRegion(), rewriter, ifOp.getLoc());
scf::IfOp::ensureTerminator(newIf.getElseRegion(), rewriter, ifOp.getLoc());

for (auto it : llvm::zip(ifOp.getResults(),
newIf.getResults().take_front(ifOp.getNumResults())))
Expand Down
30 changes: 30 additions & 0 deletions test/TritonGPU/matmul-loop-pipeline.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -47,3 +47,33 @@ tt.func public @scalar_load(%arg0: !tt.ptr<f32>, %arg1: i32, %arg2: i32, %arg3:
}

}

// -----

#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 2], order = [1, 0]}>

module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "cuda:90"} {

// CHECK-LABEL: @make_tensor_desc_epilogue
tt.func public @make_tensor_desc_epilogue(%arg0: i32, %arg1: !tt.ptr<f32>, %arg2: i32) {
%c0_i32 = arith.constant 0 : i32
%c1_i32 = arith.constant 1 : i32
%c1_i64 = arith.constant 1 : i64
// CHECK: scf.for
scf.for %arg3 = %c0_i32 to %arg0 step %c1_i32 : i32 {
%1 = tt.splat %arg1 {loop.cluster = 0 : i32, loop.stage = 0 : i32} : !tt.ptr<f32> -> tensor<128x256x!tt.ptr<f32>, #blocked>
%2 = tt.load %1 {loop.cluster = 0 : i32, loop.stage = 0 : i32} : tensor<128x256x!tt.ptr<f32>, #blocked>
%3 = arith.addf %2, %2 {loop.cluster = 5 : i32, loop.stage = 2 : i32} : tensor<128x256xf32, #blocked>
%4 = arith.cmpi eq, %arg3, %c1_i32 {loop.cluster = 5 : i32, loop.stage = 2 : i32} : i32
// CHECK: scf.if
scf.if %4 {
// CHECK-NOT: tt.make_tensor_descriptor
// CHECK: tt.experimental_tensormap_create
// CHECK-NEXT: tt.experimental_tensormap_fenceproxy_acquire
%5 = tt.make_tensor_descriptor %arg1, [%arg2, %arg2], [%c1_i64, %c1_i64] : <f32>, <tensor<128x256xf32>>
} {loop.cluster = 5 : i32, loop.stage = 2 : i32}
} {tt.num_stages = 3 : i32}
tt.return
}

}

0 comments on commit b71400a

Please sign in to comment.