Skip to content

Commit

Permalink
Moving to Unit Attr instead of Bool Attr
Browse files Browse the repository at this point in the history
  • Loading branch information
pawelszczerbuk committed Feb 7, 2025
1 parent aa28b07 commit 811e97a
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -237,10 +237,7 @@ void mlir::triton::replaceUsesAndPropagateType(OpBuilder &builder,
// Return true if the given ForOp carries the
// `tt.disallow_acc_multi_buffer` attribute.
// The attribute is now a UnitAttr, so its mere presence disallows
// multi-buffering of the accumulator; no boolean value is stored or read.
bool mlir::triton::getDisallowAccMultiBuffer(scf::ForOp forOp) {
  return forOp->hasAttr(mlir::triton::kDisallowAccMultiBufferAttrName);
}

std::optional<std::pair<int, int>>
Expand Down
2 changes: 1 addition & 1 deletion python/triton/compiler/code_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1087,7 +1087,7 @@ def visit_For(self, node):
if loop_unroll_factor is not None:
for_op.set_attr("tt.loop_unroll_factor", self.builder.get_int32_attr(loop_unroll_factor))
if disallow_acc_multi_buffer:
for_op.set_attr("tt.disallow_acc_multi_buffer", self.builder.get_bool_attr(True))
for_op.set_attr("tt.disallow_acc_multi_buffer", self.builder.get_unit_attr())
if flatten:
for_op.set_attr("tt.flatten", self.builder.get_unit_attr())

Expand Down
6 changes: 3 additions & 3 deletions test/TritonGPU/mma-pipeline-blackwell.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
ttng.tc_gen5_mma %A_sh, %B_sh, %acc_tm, %true, %true : (!ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory, mutable>, !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory, mutable>, !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>, i1, i1) -> ()
%acc_res = ttng.tmem_load %acc_tm : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x128xf32, #blocked>
tt.store %res_ptr, %acc_res : tensor<128x128x!tt.ptr<f32>, #blocked>
} {tt.disallow_acc_multi_buffer = true}
} {tt.disallow_acc_multi_buffer}
ttg.local_dealloc %A_multibuf : !ttg.memdesc<1x128x128xf16, #shared, #ttg.shared_memory, mutable>
ttg.local_dealloc %B_multibuf : !ttg.memdesc<1x128x128xf16, #shared, #ttg.shared_memory, mutable>
tt.return
Expand Down Expand Up @@ -533,7 +533,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
tt.store %res_ptr, %acc_res : tensor<128x128x!tt.ptr<f32>, #blocked>
}
scf.yield %new_acc : tensor<128x128xf32, #blocked>
} {tt.disallow_acc_multi_buffer = true}
} {tt.disallow_acc_multi_buffer}
ttg.local_dealloc %A_multibuf : !ttg.memdesc<1x128x128xf16, #shared, #ttg.shared_memory, mutable>
ttg.local_dealloc %B_multibuf : !ttg.memdesc<1x128x128xf16, #shared, #ttg.shared_memory, mutable>
tt.return
Expand Down Expand Up @@ -675,7 +675,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
tt.store %res_ptr, %acc_res : tensor<128x128x!tt.ptr<f32>, #blocked>
}
scf.yield %acc_res, %new_accUse : tensor<128x128xf32, #blocked>, i1
} {tt.disallow_acc_multi_buffer = true}
} {tt.disallow_acc_multi_buffer}
ttg.local_dealloc %A_multibuf : !ttg.memdesc<1x128x128xf16, #shared, #ttg.shared_memory, mutable>
ttg.local_dealloc %B_multibuf : !ttg.memdesc<1x128x128xf16, #shared, #ttg.shared_memory, mutable>
tt.return
Expand Down

0 comments on commit 811e97a

Please sign in to comment.