From 29912c074be9d6d91e6a7161d8bb41c0df2deb72 Mon Sep 17 00:00:00 2001
From: Puyan Lotfi <34139736+plotfi@users.noreply.github.com>
Date: Fri, 24 Jan 2025 13:20:36 -0800
Subject: [PATCH] [BACKEND] Promote tl.atomic_add to PTX ld.acquire when possible (#5187)

To optimize the case `tl.atomic_add(ptr, 0)` for scalars, this adds a new path that lowers to PTX `ld.acquire` with a scope qualifier (`.cta`, `.gpu`, or `.sys`). It does this by lowering the TTGIR AtomicRMWOp to `nvgpu.ld_acquire`, and then lowering NVGPU::LoadAcquireOp to inline PTX `ld.acquire`.

The purpose is to generate better code for synchronizing groups of threads during a cooperative grid launch.

- [x] I am not making a trivial change, such as fixing a typo in a comment.
- [x] I have written a PR description following these [rules](https://cbea.ms/git-commit/#why-not-how).
- [x] I have run `pre-commit run --from-ref origin/main --to-ref HEAD`.
- Select one of the following.
  - [x] I have added tests.
    - `/python/test` for end-to-end tests
---
 test/Conversion/atomic_ldst.mlir           | 29 +++++++
 .../include/Dialect/NVGPU/IR/NVGPUOps.td   | 38 +++++++++
 .../lib/NVGPUToLLVM/NVGPUToLLVMPass.cpp    | 42 +++++++++-
 .../LoadStoreOpToLLVM.cpp                  | 84 +++++++++++++++++++
 4 files changed, 192 insertions(+), 1 deletion(-)
 create mode 100644 test/Conversion/atomic_ldst.mlir

diff --git a/test/Conversion/atomic_ldst.mlir b/test/Conversion/atomic_ldst.mlir
new file mode 100644
index 000000000000..4c1e63c407ce
--- /dev/null
+++ b/test/Conversion/atomic_ldst.mlir
@@ -0,0 +1,29 @@
+// RUN: triton-opt %s -split-input-file --allocate-shared-memory --convert-triton-gpu-to-llvm=compute-capability=90 2>&1 | FileCheck %s --check-prefix=CHECK-TTG2NVGPU
+// RUN: triton-opt %s -split-input-file --allocate-shared-memory --convert-triton-gpu-to-llvm=compute-capability=90 --convert-nv-gpu-to-llvm 2>&1 | FileCheck %s --check-prefix=CHECK-NVGPU2LLVM
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} {
+  tt.func public @kernel_r(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}) attributes {noinline = false} {
+    %cst = arith.constant 0.000000e+00 : f32
+    %true = arith.constant true
+    %c128_i32 = arith.constant 128 : i32
+    %c512_i32 = arith.constant 512 : i32
+    %0 = tt.get_program_id x : i32
+    %1 = arith.muli %0, %c128_i32 : i32
+    %2 = arith.cmpi slt, %1, %c512_i32 : i32
+
+    // CHECK-TTG2NVGPU: nvgpu.ld_acquire acquire, gpu
+    // CHECK-NVGPU2LLVM: ld.global.gpu.acquire.b32
+    %3 = tt.atomic_rmw fadd, acquire, gpu, %arg0, %cst, %2 : (!tt.ptr<f32>, f32, i1) -> f32
+    tt.store %arg0, %3 : !tt.ptr<f32>
+
+    // CHECK-TTG2NVGPU: nvgpu.ld_acquire acquire, cta
+    // CHECK-NVGPU2LLVM: ld.global.cta.acquire.b32
+    %4 = tt.atomic_rmw fadd, acquire, cta, %arg0, %cst, %true : (!tt.ptr<f32>, f32, i1) -> f32
+    tt.store %arg0, %4 : !tt.ptr<f32>
+
+    // CHECK-TTG2NVGPU: nvgpu.ld_acquire acquire, sys
+    // CHECK-NVGPU2LLVM: ld.global.sys.acquire.b32
+    %5 = tt.atomic_rmw fadd, acquire, sys, %arg0, %cst, %2 : (!tt.ptr<f32>, f32, i1) -> f32
+    tt.store %arg0, %5 : !tt.ptr<f32>
+    tt.return
+  }
+}
diff --git a/third_party/nvidia/include/Dialect/NVGPU/IR/NVGPUOps.td b/third_party/nvidia/include/Dialect/NVGPU/IR/NVGPUOps.td
index ed2a2ec39175..2508fa22fa18 100644
--- a/third_party/nvidia/include/Dialect/NVGPU/IR/NVGPUOps.td
+++ b/third_party/nvidia/include/Dialect/NVGPU/IR/NVGPUOps.td
@@ -32,6 +32,33 @@ include "NVGPUAttrDefs.td"
 def LLVM_PointerGlobal : LLVM_PointerInAddressSpace<1>;
 def LLVM_PointerShared : LLVM_PointerInAddressSpace<3>;
 
+
+def NVGPU_Float :
+  AnyTypeOf<[F8E4M3FN, F8E4M3FNUZ, F8E5M2, F8E5M2FNUZ, F16, BF16, F32, F64],
+            "floating-point">;
+def NVGPU_Int : AnyTypeOf<[I1, I8, I16, I32, I64], "integer">;
+def NVGPU_ScalarLike : AnyTypeOf<[NVGPU_Float, NVGPU_Int]>;
+
+
+def NVGPU_MemSemanticAttr : I32EnumAttr<
+    "MemSemantic", "",
+    [
+      I32EnumAttrCase<"RELAXED", 1, "relaxed">,
+      I32EnumAttrCase<"ACQUIRE", 2, "acquire">,
+      I32EnumAttrCase<"RELEASE", 3, "release">,
+      I32EnumAttrCase<"ACQUIRE_RELEASE", 4, "acq_rel">,
+    ]> {
+  let cppNamespace = "::mlir::triton::nvgpu";
+}
+
+def NVGPU_MemSyncScopeAttr : I32EnumAttr<
+    "MemSyncScope", "",
+    [
+      I32EnumAttrCase<"GPU", 1, "gpu">,
+      I32EnumAttrCase<"CTA", 2, "cta">,
+      I32EnumAttrCase<"SYSTEM", 3, "sys">,
+    ]> {
+  let cppNamespace = "::mlir::triton::nvgpu";
+}
+
 class NVGPU_Op<string mnemonic, list<Trait> traits = []> :
     LLVM_OpBase<NVGPU_Dialect, mnemonic, traits>;
 
@@ -123,4 +150,15 @@ def NVGPU_ClusterCTAIdOp : NVGPU_Op<"cluster_id", [Pure]> {
   let assemblyFormat = "attr-dict";
 }
 
+def NVGPU_LoadAcquireOp : NVGPU_Op<"ld_acquire", [MemoryEffects<[MemRead]>]> {
+  let arguments = (
+    ins LLVM_PointerGlobal:$addr,
+    Optional<I1>:$mask,
+    NVGPU_MemSemanticAttr:$sem,
+    NVGPU_MemSyncScopeAttr:$scope
+  );
+  let results = (outs NVGPU_ScalarLike:$result);
+  let assemblyFormat = "$sem `,` $scope `,` $addr (`,` $mask^)? attr-dict `:` functional-type($addr, $result)";
+}
+
 #endif
diff --git a/third_party/nvidia/lib/NVGPUToLLVM/NVGPUToLLVMPass.cpp b/third_party/nvidia/lib/NVGPUToLLVM/NVGPUToLLVMPass.cpp
index 3758f68ed9cb..cfd3acc141ec 100644
--- a/third_party/nvidia/lib/NVGPUToLLVM/NVGPUToLLVMPass.cpp
+++ b/third_party/nvidia/lib/NVGPUToLLVM/NVGPUToLLVMPass.cpp
@@ -370,6 +370,46 @@ class LoadMatrixOpPattern
   }
 };
 
+class LoadAcquireOpPattern : public OpRewritePattern<ttn::LoadAcquireOp> {
+public:
+  using OpRewritePattern<ttn::LoadAcquireOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ttn::LoadAcquireOp op,
+                                PatternRewriter &rewriter) const override {
+    auto loc = op->getLoc();
+    Type valueTy = op.getType();
+    const unsigned valueNBits = std::max(8u, valueTy.getIntOrFloatBitWidth());
+    const size_t maxWordWidth = std::max<size_t>(32, valueNBits);
+    const size_t width = std::min((size_t)valueNBits, maxWordWidth);
+
+    const std::string writeConstraint =
+        (width == 64) ? "=l" : ((width == 32) ? "=r" : "=c");
+    PTXBuilder ptxBuilder;
+    bool init = true;
+    auto *dstOpr = ptxBuilder.newOperand(writeConstraint, init); // =r operation
+    auto *addrOpr =
+        ptxBuilder.newAddrOperand(op.getAddr(), "l", 0 /* in_off */);
+    auto &ld =
+        ptxBuilder.create<>("ld")
+            ->global()
+            .o("cta", op.getScope() == triton::nvgpu::MemSyncScope::CTA)
+            .o("gpu", op.getScope() == triton::nvgpu::MemSyncScope::GPU)
+            .o("sys", op.getScope() == triton::nvgpu::MemSyncScope::SYSTEM)
+            .o("acquire", op.getSem() == triton::nvgpu::MemSemantic::ACQUIRE)
+            .o("relaxed", op.getSem() == triton::nvgpu::MemSemantic::RELAXED)
+            .b(width);
+    ld(dstOpr, addrOpr).maybePredicate(op.getMask(), "b");
+
+    // Create inline ASM signature
+    Type retTy = IntegerType::get(getContext(), width);
+    Value ret = ptxBuilder.launch(rewriter, loc, retTy);
+    ret = bitcast(ret, op.getType());
+
+    rewriter.replaceOp(op, {ret});
+    return success();
+  }
+};
+
 class WGMMAWaitGroupOpPattern : public OpRewritePattern<ttn::WGMMAWaitGroupOp> {
 public:
   using OpRewritePattern<ttn::WGMMAWaitGroupOp>::OpRewritePattern;
@@ -608,7 +648,7 @@ class ConvertNVGPUToLLVM : public ConvertNVGPUToLLVMBase<ConvertNVGPUToLLVM> {
 
     patterns.add(context);
+                 LoadAcquireOpPattern, WGMMAWaitGroupOpPattern>(context);
 
     if (applyPatternsGreedily(mod, std::move(patterns)).failed())
       signalPassFailure();
diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp
index 593e866ea905..5189a40d618f 100644
--- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp
+++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp
@@ -1,3 +1,4 @@
+#include "Dialect/NVGPU/IR/Dialect.h"
 #include "TargetInfo.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/TypeUtilities.h"
@@ -24,6 +25,9 @@ using ::mlir::triton::gpu::getShapePerCTA;
 using ::mlir::triton::gpu::getTotalElemsPerThread;
 using ::mlir::triton::gpu::SharedEncodingAttr;
 
+// Toggle this to disable the Cooperative Grid Launch ld.acquire optimized path
+static constexpr bool disableLDAcquireLowering = false;
+
 namespace {
 
 llvm::MapVector<StringAttr, APInt> getAllFreeVarMasks(MLIRContext *ctx) {
@@ -696,6 +700,48 @@ struct AtomicRMWOpConversion
            (elementType.isF16() || elementType.isBF16() || elementType.isF32());
   }
 
+  bool isPromotableToNVPTXLD(triton::AtomicRMWOp op) const {
+    if (disableLDAcquireLowering)
+      return false;
+
+    Type valueTy =
+        getTypeConverter()->convertType(getElementTypeOrSelf(op.getType()));
+
+    if (!valueTy.isIntOrFloat())
+      return false;
+    if (op.getSem() != triton::MemSemantic::ACQUIRE &&
+        op.getSem() != triton::MemSemantic::RELAXED)
+      return false;
+    if (op.getScope() != triton::MemSyncScope::CTA &&
+        op.getScope() != triton::MemSyncScope::GPU &&
+        op.getScope() != triton::MemSyncScope::SYSTEM)
+      return false;
+
+    if (op.getAtomicRmwOp() != RMWOp::ADD && op.getAtomicRmwOp() != RMWOp::FADD)
+      return false;
+    if (isa<RankedTensorType>(op.getType()))
+      return false;
+    if (!op.getVal().getDefiningOp())
+      return false;
+    if (!isa<arith::ConstantOp>(op.getVal().getDefiningOp()))
+      return false;
+
+    auto constOp = cast<arith::ConstantOp>(op.getVal().getDefiningOp());
+    if (!isa<FloatAttr>(constOp.getValueAttr()) &&
+        !isa<IntegerAttr>(constOp.getValueAttr()))
+      return false;
+
+    if (auto attr = dyn_cast_or_null<FloatAttr>(constOp.getValueAttr()))
+      if (!attr.getValue().isZero())
+        return false;
+
+    if (auto attr = dyn_cast_or_null<IntegerAttr>(constOp.getValueAttr()))
+      if (!attr.getValue().isZero())
+        return false;
+
+    return true;
+  }
+
   LogicalResult
   matchAndRewrite(triton::AtomicRMWOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
@@ -767,6 +813,17 @@ struct AtomicRMWOpConversion
     auto packedTy = vec_ty(valueElemTy, packed);
     SmallVector<Value> resultVals(elemsPerThread);
+
+    // Lower AtomicRMWOp to an ld.acquire if possible
+    std::unordered_map<triton::MemSyncScope, triton::nvgpu::MemSyncScope>
+        ScopeMap = {
+            {triton::MemSyncScope::CTA, triton::nvgpu::MemSyncScope::CTA},
+            {triton::MemSyncScope::GPU, triton::nvgpu::MemSyncScope::GPU},
+            {triton::MemSyncScope::SYSTEM,
+             triton::nvgpu::MemSyncScope::SYSTEM}};
+    const bool doPTXLDPromotion = isPromotableToNVPTXLD(op) && vec == 1 &&
+                                  packed == 1 && ScopeMap.count(op.getScope());
+
     for (size_t i = 0; i < elemsPerThread; i += vec * packed) {
       if (auto canonicalStart = getCanonicalIndex(i, regMask);
           canonicalStart != i) {
@@ -780,6 +837,33 @@ struct AtomicRMWOpConversion
       Value rmwPtr = ptrElements[i];
       Value pred = llMask ? maybeAnd(rewriter, loc, threadPred, maskElements[i])
                           : threadPred;
+
+      if (doPTXLDPromotion) {
+        Type convertedValueTy =
+            getTypeConverter()->convertType(getElementTypeOrSelf(op.getType()));
+        auto loadAcquireOp = rewriter.create<triton::nvgpu::LoadAcquireOp>(
+            op.getLoc(), convertedValueTy, rmwPtr, pred,
+            op.getSem() == triton::MemSemantic::ACQUIRE
+                ? triton::nvgpu::MemSemantic::ACQUIRE
+                : triton::nvgpu::MemSemantic::RELAXED,
+            ScopeMap[op.getScope()]);
+
+        auto ASMReturnTy = void_ty(ctx);
+        if (!atomicNeedsSharedMemory(op.getResult())) {
+          rewriter.eraseOp(op);
+          return success();
+        }
+        Value atomPtr = LLVM::getSharedMemoryBase(loc, rewriter, targetInfo,
+                                                  op.getOperation());
+        atomPtr = bitcast(atomPtr, ptr_ty(ctx, 3));
+        // Only threads with rmwMask = True store the result
+        targetInfo.storeShared(rewriter, loc, atomPtr, loadAcquireOp, pred);
+        createBarrier(rewriter, loc, numCTAs);
+        Value ret = load(valueElemTy, atomPtr);
+        rewriter.replaceOp(op, {ret});
+        continue;
+      }
+
       std::string sTy;
       PTXBuilder ptxBuilderAtomicRMW;
       // 16-bit -> "h", 32-bit -> "r", 64-bit -> "l"
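
For context, below is a minimal Triton-level sketch of the pattern this fast path targets; the kernel and the `flag_ptr`/`out_ptr` names are illustrative assumptions, not part of this patch, and the MLIR test above remains the authoritative coverage. A scalar `tl.atomic_add(ptr, 0.0)` with acquire semantics is effectively an acquire-load, which on compute capability 90 can now lower through `nvgpu.ld_acquire` to PTX `ld.acquire` instead of an atomic add.

```python
# Illustrative sketch only (not part of this patch): reading a flag with
# acquire semantics via an atomic add of zero, the scalar case the new
# promotion recognizes during a cooperative grid launch.
import triton
import triton.language as tl


@triton.jit
def poll_flag(flag_ptr, out_ptr):
    # A scalar atomic_add of 0.0 with sem="acquire" only reads the flag; with
    # this patch it can lower to ld.global.gpu.acquire.b32 rather than an
    # atomic read-modify-write.
    v = tl.atomic_add(flag_ptr, 0.0, sem="acquire", scope="gpu")
    tl.store(out_ptr, v)
```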