From 811e97acb83b75ff842c5f84c792917ebce64f95 Mon Sep 17 00:00:00 2001 From: Pawel Szczerbuk Date: Fri, 7 Feb 2025 15:52:24 -0800 Subject: [PATCH] Moving to Unit Attr instead of Bool Attr --- .../TritonGPU/Transforms/Pipeliner/PipeliningUtility.cpp | 5 +---- python/triton/compiler/code_generator.py | 2 +- test/TritonGPU/mma-pipeline-blackwell.mlir | 6 +++--- 3 files changed, 5 insertions(+), 8 deletions(-) diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.cpp index 4910c651bd51..6df1c31f3855 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.cpp @@ -237,10 +237,7 @@ void mlir::triton::replaceUsesAndPropagateType(OpBuilder &builder, // Return true if the given ForOp has the attribute // `tt.disallow_acc_multi_buffer` set to true. bool mlir::triton::getDisallowAccMultiBuffer(scf::ForOp forOp) { - return forOp->hasAttr(mlir::triton::kDisallowAccMultiBufferAttrName) && - cast( - forOp->getAttr(mlir::triton::kDisallowAccMultiBufferAttrName)) - .getValue(); + return forOp->hasAttr(mlir::triton::kDisallowAccMultiBufferAttrName); } std::optional> diff --git a/python/triton/compiler/code_generator.py b/python/triton/compiler/code_generator.py index 17614175aabd..61a2a1d2b700 100644 --- a/python/triton/compiler/code_generator.py +++ b/python/triton/compiler/code_generator.py @@ -1087,7 +1087,7 @@ def visit_For(self, node): if loop_unroll_factor is not None: for_op.set_attr("tt.loop_unroll_factor", self.builder.get_int32_attr(loop_unroll_factor)) if disallow_acc_multi_buffer: - for_op.set_attr("tt.disallow_acc_multi_buffer", self.builder.get_bool_attr(True)) + for_op.set_attr("tt.disallow_acc_multi_buffer", self.builder.get_unit_attr()) if flatten: for_op.set_attr("tt.flatten", self.builder.get_unit_attr()) diff --git a/test/TritonGPU/mma-pipeline-blackwell.mlir b/test/TritonGPU/mma-pipeline-blackwell.mlir index 0782c8f063b4..d310eee6b810 100644 --- a/test/TritonGPU/mma-pipeline-blackwell.mlir +++ b/test/TritonGPU/mma-pipeline-blackwell.mlir @@ -251,7 +251,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ ttng.tc_gen5_mma %A_sh, %B_sh, %acc_tm, %true, %true : (!ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory, mutable>, !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory, mutable>, !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>, i1, i1) -> () %acc_res = ttng.tmem_load %acc_tm : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x128xf32, #blocked> tt.store %res_ptr, %acc_res : tensor<128x128x!tt.ptr, #blocked> - } {tt.disallow_acc_multi_buffer = true} + } {tt.disallow_acc_multi_buffer} ttg.local_dealloc %A_multibuf : !ttg.memdesc<1x128x128xf16, #shared, #ttg.shared_memory, mutable> ttg.local_dealloc %B_multibuf : !ttg.memdesc<1x128x128xf16, #shared, #ttg.shared_memory, mutable> tt.return @@ -533,7 +533,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ tt.store %res_ptr, %acc_res : tensor<128x128x!tt.ptr, #blocked> } scf.yield %new_acc : tensor<128x128xf32, #blocked> - } {tt.disallow_acc_multi_buffer = true} + } {tt.disallow_acc_multi_buffer} ttg.local_dealloc %A_multibuf : !ttg.memdesc<1x128x128xf16, #shared, #ttg.shared_memory, mutable> ttg.local_dealloc %B_multibuf : !ttg.memdesc<1x128x128xf16, #shared, #ttg.shared_memory, mutable> tt.return @@ -675,7 +675,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ tt.store %res_ptr, %acc_res : tensor<128x128x!tt.ptr, #blocked> } scf.yield %acc_res, %new_accUse : tensor<128x128xf32, #blocked>, i1 - } {tt.disallow_acc_multi_buffer = true} + } {tt.disallow_acc_multi_buffer} ttg.local_dealloc %A_multibuf : !ttg.memdesc<1x128x128xf16, #shared, #ttg.shared_memory, mutable> ttg.local_dealloc %B_multibuf : !ttg.memdesc<1x128x128xf16, #shared, #ttg.shared_memory, mutable> tt.return