From b9dfacc33a6e1ef03e4fa78360e0b5284c376533 Mon Sep 17 00:00:00 2001
From: Anatoly Myachev
Date: Fri, 24 Jan 2025 09:52:21 +0100
Subject: [PATCH 1/2] Don't use cpu for tensor comparison in
 `python/tutorials/01-vector-add.py` (#3242)

Signed-off-by: Anatoly Myachev
---
 python/tutorials/01-vector-add.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/python/tutorials/01-vector-add.py b/python/tutorials/01-vector-add.py
index 1e77ca7c1d..e527e5fc7a 100644
--- a/python/tutorials/01-vector-add.py
+++ b/python/tutorials/01-vector-add.py
@@ -87,10 +87,10 @@ def add(x: torch.Tensor, y: torch.Tensor):
 y = torch.rand(size, device=DEVICE)
 output_torch = x + y
 output_triton = add(x, y)
-print(output_torch.cpu())
-print(output_triton.cpu())
+print(output_torch)
+print(output_triton)
 print(f'The maximum difference between torch and triton is '
-      f'{torch.max(torch.abs(output_torch.cpu() - output_triton.cpu()))}')
+      f'{torch.max(torch.abs(output_torch - output_triton))}')
 
 # %%
 # Seems like we're good to go!

From b018ed69e9cfd590afe05537024aef27dc76b5cb Mon Sep 17 00:00:00 2001
From: Ettore Tiotto
Date: Fri, 24 Jan 2025 09:40:48 -0500
Subject: [PATCH 2/2] [triton-raise-block-ptr]: `scf.for` init arg list
 rewrite causes invalid IR in `tt.broadcast` contained inside the loop body
 (#3252)

Fixes issue #3254

---------

Signed-off-by: Tiotto, Ettore
---
 .../RaiseToBlockPointers/addptr_dim1.mlir     | 100 ++++++++++++++++++
 .../TritonRaiseBlockPointer.cpp               |  38 +++++--
 2 files changed, 127 insertions(+), 11 deletions(-)
 create mode 100644 test/Triton/Intel/RaiseToBlockPointers/addptr_dim1.mlir

diff --git a/test/Triton/Intel/RaiseToBlockPointers/addptr_dim1.mlir b/test/Triton/Intel/RaiseToBlockPointers/addptr_dim1.mlir
new file mode 100644
index 0000000000..a9549c583f
--- /dev/null
+++ b/test/Triton/Intel/RaiseToBlockPointers/addptr_dim1.mlir
@@ -0,0 +1,100 @@
+// RUN: triton-opt %s -triton-raise-block-pointer -canonicalize | FileCheck %s
+
+module {
+  tt.func @kernel(
+    %arg0 : !tt.ptr<bf16>,
+    %arg1 : i32
+  )
+  {
+    %0 = tt.make_range {end = 256 : i32, start = 0 : i32}:tensor<256xi32>
+    %1 = tt.expand_dims %0 {axis = 0 : i32} : tensor<256xi32> -> tensor<1x256xi32>
+
+    %splat_arg0 = tt.splat %arg0 : !tt.ptr<bf16> -> tensor<1x256x!tt.ptr<bf16>>
+    %2 = tt.addptr %splat_arg0, %1 : tensor<1x256x!tt.ptr<bf16>>, tensor<1x256xi32>
+    // source = %arg0, offset = [0, 0], size = [1, 256], stride = [0, 1]
+
+    // 1x256 pointer should have meaningful stride in outer dimension
+    %3 = tt.load %2 {cache = 1 : i32, evict = 1 : i32, isVolatile = false}: tensor<1x256x!tt.ptr<bf16>>
+
+    %4 = tt.splat %arg1 : i32 -> tensor<1x256xi32>
+    // 1x256 pointer should have meaningful stride in outer dimension
+    %5 = tt.addptr %2, %4 : tensor<1x256x!tt.ptr<bf16>>, tensor<1x256xi32>
+    // source = %arg0, offset = [%arg1, 0], size = [1, 256], stride = [0, 1]
+
+    tt.store %5, %3 : tensor<1x256x!tt.ptr<bf16>>
+
+    %10 = arith.constant 0.0 : bf16
+    %11 = tt.splat %10 : bf16 -> tensor<4x256xbf16>
+
+    %c0 = arith.constant 0 : index
+    %c12 = arith.constant 12 : index
+    %c3 = arith.constant 3 : index
+    %i_c3 = arith.constant 3 : i32
+    %c256 = arith.constant 256 : i32
+    %sum_out, %_ptr = scf.for %i = %c0 to %c12 step %c3 iter_args(%sum_iter = %11, %ptr = %2) -> (tensor<4x256xbf16>, tensor<1x256x!tt.ptr<bf16>>) {
+        %bptr = tt.broadcast %ptr : tensor<1x256x!tt.ptr<bf16>> -> tensor<4x256x!tt.ptr<bf16>>
+        // source = %arg0, offset = [0, 0], size = [4, 256], stride = [0, 1]
+
+        %20 = tt.make_range {end = 4 : i32, start = 0 : i32}:tensor<4xi32>
+        %i_i32 = arith.index_cast %i : index to i32
+        %21 = arith.muli %c256, %i_i32 : i32
+        %22 = tt.splat %21 : i32 -> tensor<4xi32>
+        %23 = arith.muli %20, %22 : tensor<4xi32>
+        %24 = tt.expand_dims %23 {axis = 1 : i32} : tensor<4xi32> -> tensor<4x1xi32>
+        %25 = tt.broadcast %24 : tensor<4x1xi32> -> tensor<4x256xi32>
+        // offset = [0, 0], size = [4, 256], stride = [i*256, 1]
+
+        // %bptr should have zero stride and %30 should have correct stride
+        %30 = tt.addptr %bptr, %25 : tensor<4x256x!tt.ptr<bf16>>, tensor<4x256xi32>
+        // source = %arg0, offset = [0, 0], size = [4, 256], stride = [i*256, 1]
+
+        %31 = tt.load %30 {cache = 1 : i32, evict = 1 : i32, isVolatile = false}: tensor<4x256x!tt.ptr<bf16>>
+        %32 = arith.addf %sum_iter, %31 : tensor<4x256xbf16>
+
+        %40 = tt.splat %c256 : i32 -> tensor<1x256xi32>
+        %41 = tt.addptr %ptr, %40 : tensor<1x256x!tt.ptr<bf16>>, tensor<1x256xi32>
+        // source = %arg0, offset = [i*256, 0], size = [4, 256], stride = [i*256, 1]
+
+        scf.yield %32, %41 : tensor<4x256xbf16>, tensor<1x256x!tt.ptr<bf16>>
+    }
+
+    %31 = tt.make_range {end = 4 : i32, start = 0 : i32}:tensor<4xi32>
+    %splat_c256 = tt.splat %c256 : i32 -> tensor<4xi32>
+    %32 = arith.muli %31, %splat_c256 : tensor<4xi32>
+    %33 = tt.expand_dims %32 {axis = 1 : i32} : tensor<4xi32> -> tensor<4x1xi32>
+    %34 = tt.broadcast %33 : tensor<4x1xi32> -> tensor<4x256xi32>
+    %35 = tt.broadcast %2 : tensor<1x256x!tt.ptr<bf16>> -> tensor<4x256x!tt.ptr<bf16>>
+    %36 = tt.addptr %35, %34 : tensor<4x256x!tt.ptr<bf16>>, tensor<4x256xi32>
+    tt.store %36, %sum_out : tensor<4x256x!tt.ptr<bf16>>
+    tt.return
+  }
+}
+
+// CHECK: tt.func @kernel([[PARAM_0_:%.+]]: !tt.ptr<bf16>, [[PARAM_1_:%.+]]: i32) {
+// CHECK-DAG: [[CST_:%.+]] = arith.constant dense<256> : tensor<1x256xi32>
+// CHECK-DAG: [[CST_0_:%.+]] = arith.constant dense<0.000000e+00> : tensor<4x256xbf16>
+// CHECK-DAG: [[CST_0_i32:%.+]] = arith.constant 0 : i32
+// CHECK-DAG: [[CST_0_i64:%.+]] = arith.constant 0 : i64
+// CHECK-DAG: [[CST_1_i64:%.+]] = arith.constant 1 : i64
+// CHECK-DAG: [[CST_256_i64:%.+]] = arith.constant 256 : i64
+// CHECK-NOT: separator of consecutive DAGs
+// CHECK-DAG: [[VAR_0_:%.+]] = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32>
+// CHECK-DAG: [[VAR_1_:%.+]] = tt.expand_dims [[VAR_0_]] {axis = 0 : i32} : tensor<256xi32> -> tensor<1x256xi32>
+// CHECK-DAG: [[VAR_2_:%.+]] = tt.splat [[PARAM_0_]] : !tt.ptr<bf16> -> tensor<1x256x!tt.ptr<bf16>>
+// CHECK-NOT: separator of consecutive DAGs
+// CHECK-DAG: [[VAR_3_:%.+]] = tt.make_tensor_ptr [[PARAM_0_]], {{\[}}[[CST_0_i64]], [[CST_0_i64]]], {{\[}}[[CST_0_i64]], [[CST_1_i64]]], {{\[}}[[CST_0_i32]], [[CST_0_i32]]] {{.*}} : <tensor<1x256xbf16>>
+// CHECK-DAG: [[VAR_4_:%.+]] = tt.addptr [[VAR_2_]], [[VAR_1_]] : tensor<1x256x!tt.ptr<bf16>>, tensor<1x256xi32>
+// CHECK-DAG: [[VAR_5_:%.+]] = tt.load [[VAR_3_]] : !tt.ptr<tensor<1x256xbf16>>
+// CHECK-DAG: [[VAR_6_:%.+]] = tt.advance [[VAR_3_]], {{\[}}[[CST_0_i32]], [[PARAM_1_]]] : <tensor<1x256xbf16>>
+// CHECK: tt.store [[VAR_6_]], [[VAR_5_]] : !tt.ptr<tensor<1x256xbf16>>
+// CHECK: [[VAR_7_:%.+]]:2 = scf.for [[VAR_arg2_:%.+]] = {{.*}} iter_args([[VAR_arg3_:%.+]] = [[CST_0_]], [[VAR_arg4_:%.+]] = [[VAR_4_]]) -> (tensor<4x256xbf16>, tensor<1x256x!tt.ptr<bf16>>) {
+// CHECK: [[VAR_9_:%.+]] = tt.broadcast [[VAR_arg4_]] : tensor<1x256x!tt.ptr<bf16>> -> tensor<4x256x!tt.ptr<bf16>>
+// CHECK-NOT: tt.make_tensor_ptr
+// CHECK-NOT: tt.advance
+// CHECK: [[VAR_20_:%.+]] = tt.addptr [[VAR_arg4_]], [[CST_]] : tensor<1x256x!tt.ptr<bf16>>, tensor<1x256xi32>
+// CHECK: scf.yield {{.*}}, [[VAR_20_]] : tensor<4x256xbf16>, tensor<1x256x!tt.ptr<bf16>>
+// CHECK: }
+// CHECK: [[VAR_8_:%.+]] = tt.make_tensor_ptr [[PARAM_0_]], {{\[}}[[CST_0_i64]], [[CST_0_i64]]], {{\[}}[[CST_256_i64]], [[CST_1_i64]]], {{\[}}[[CST_0_i32]], [[CST_0_i32]]] {{.*}} : <tensor<4x256xbf16>>
+// CHECK: tt.store [[VAR_8_]], [[VAR_7_]]#0 : !tt.ptr<tensor<4x256xbf16>>
+// CHECK: tt.return
+// CHECK: }
diff --git a/third_party/intel/lib/TritonRaiseBlockPointer/TritonRaiseBlockPointer.cpp b/third_party/intel/lib/TritonRaiseBlockPointer/TritonRaiseBlockPointer.cpp
index e78cecbdf7..0b4c4ff5c5 100644
--- a/third_party/intel/lib/TritonRaiseBlockPointer/TritonRaiseBlockPointer.cpp
+++ b/third_party/intel/lib/TritonRaiseBlockPointer/TritonRaiseBlockPointer.cpp
@@ -10,6 +10,7 @@
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/Verifier.h"
+#include "mlir/Support/LLVM.h"
 #include "triton/Dialect/Triton/IR/Dialect.h"
 #include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/Debug.h"
@@ -435,7 +436,7 @@ struct TritonRaiseBlockPointer
           if (failed(rewriteForOp(forOp))) {
             forOp->emitRemark(
                 "TritonRaiseToBlockPointer: Failed to rewrite ForOp");
-            return WalkResult::interrupt();
+            return WalkResult::advance();
           }
           return WalkResult::skip();
         })
@@ -452,17 +453,24 @@ struct TritonRaiseBlockPointer
     SmallVector> initArgIndex;
     OpBuilder builder(op);
 
+    auto canBeRewrittenUsingBlockPtr = [&](Operation *op) {
+      return TypeSwitch<Operation *, bool>(op)
+          .Case(
+              [](auto) { return true; })
+          .Default([](auto) { return false; });
+    };
+
     // Create a new list of init args
     for (auto [i, arg] : llvm::enumerate(op.getInitArgs())) {
       if (Value mappedV = ptrMap.lookupOrNull(arg)) {
         if (auto makeTensorPtrOp = mappedV.getDefiningOp<tt::MakeTensorPtrOp>()) {
           if (llvm::any_of(op.getRegionIterArgs()[i].getUsers(),
-                           [](Operation *user) {
-                             return isa<tt::ExpandDimsOp>(user);
+                           [&](Operation *user) {
+                             return !canBeRewrittenUsingBlockPtr(user);
                            })) {
-            op->emitRemark("TritonRaiseToBlockPointer: ExpandDims Ops in loops "
-                           "are currently not supported");
+            op->emitRemark("TritonRaiseToBlockPointer: Loop contains ops that "
+                           "cannot be rewritten using a block ptr");
             return failure();
           }
@@ -668,7 +676,7 @@ struct TritonRaiseBlockPointer
     OpBuilder builder(op);
     Location loc = op.getLoc();
 
-    auto ptr = op.getPtr();
+    Value ptr = op.getPtr();
 
     auto fillOffsets = [&](Value offset, unsigned rank,
                            SmallVector<Value> &offsets) {
@@ -726,11 +734,16 @@ struct TritonRaiseBlockPointer
       assert(!offsets.empty() && offsets.size() == rank &&
             "unexpected number of offsets");
 
-      auto advanceOp = builder.createOrFold<tt::AdvanceOp>(loc, ptr.getType(),
-                                                           ptr, offsets);
-      cleanUp.push_back(op);
+
+      Value basePtr = tt::isTensorPointerType(ptr.getType()) ? ptr : mappedV;
+      auto advanceOp = builder.createOrFold<tt::AdvanceOp>(
+          loc, basePtr.getType(), basePtr, offsets);
+
+      cleanUp.insert(op);
       ptrMap.map(op.getResult(), advanceOp);
 
+      LLVM_DEBUG(llvm::dbgs()
+                 << "Rewrote:\n\t" << op << "to:\n\t" << advanceOp << "\n");
       return success();
     } else {
       llvm_unreachable("Did not find tt::MakeTensorPtrOp");
@@ -755,9 +768,12 @@ struct TritonRaiseBlockPointer
 
     ptrMap.map(result, makePtrOp);
 
+    LLVM_DEBUG(llvm::dbgs()
+               << "Rewrote:\n\t" << op << "\nto:\n\t" << makePtrOp << "\n");
+
     // AddPtrOps that have been rewritten and no longer used in the code must
     // be removed in the pass to avoid type matching issue.
-    cleanUp.push_back(op);
+    cleanUp.insert(op);
 
     LLVM_DEBUG({
       auto modOp =
@@ -1039,7 +1055,7 @@ struct TritonRaiseBlockPointer
   }
 
 private:
-  SmallVector cleanUp;
+  SmallPtrSet cleanUp;
   llvm::SmallDenseMap knownPtrs;
   IRMapping ptrMap;
 };
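
Note: the new `canBeRewrittenUsingBlockPtr` helper in patch 2/2 is a whitelist over the users of a loop's pointer iter-arg; any user outside that whitelist (such as the `tt.broadcast` in the new addptr_dim1.mlir test) now makes `rewriteForOp` bail out and keep the tensor-of-pointers form, instead of rewriting the init arg and producing invalid IR. The sketch below only illustrates the llvm::TypeSwitch idiom involved; it is not the patch's exact code, and the op list in the `.Case<...>` clause and the function name are assumptions chosen for the example.

// Hedged sketch of the whitelist idiom used by canBeRewrittenUsingBlockPtr.
// The op list in .Case<...> is assumed for illustration only; the real list
// lives in TritonRaiseBlockPointer.cpp.
#include "mlir/IR/Operation.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "llvm/ADT/TypeSwitch.h"

namespace tt = mlir::triton;

// Returns true when `op` is an op the pass knows how to express through a
// block pointer (tt.make_tensor_ptr / tt.advance based form).
static bool canUseBlockPtrSketch(mlir::Operation *op) {
  return llvm::TypeSwitch<mlir::Operation *, bool>(op)
      .Case<tt::AddPtrOp, tt::LoadOp, tt::StoreOp>( // assumed whitelist
          [](auto) { return true; })
      .Default([](auto) { return false; }); // e.g. tt.broadcast lands here
}

Because the failing case in the walk now returns WalkResult::advance() rather than WalkResult::interrupt(), other loops in the same function can still be raised even when one loop is skipped.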