[BACKEND] Promote tl.atomic_add to PTX ld.acquire when possible (#5187)
To optimize the case `tl.atomic_add(ptr, 0)` for scalars, there is a new
path that lowers to PTX `ld.acquire` at the requested scope (`.cta`, `.gpu`, or `.sys`).

It does this by emitting `nvgpu.ld_acquire` from the TTGIR `AtomicRMW`
lowering, then lowering `NVGPU::LoadAcquireOp` to LLVM inline PTX that
issues `ld.acquire`.

The purpose is to generate better code for synchronizing groups of
threads during a cooperative thread launch.
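
For context, the kind of kernel this helps looks roughly like the sketch below (illustrative only, not part of this diff; the kernel and argument names are hypothetical):

```python
import triton
import triton.language as tl

@triton.jit
def spin_on_flag(flag_ptr, expected):
    # tl.atomic_add(ptr, 0) reads the flag without modifying it; with
    # acquire semantics at gpu scope it can now lower to ld.acquire.gpu
    # instead of a full atomic read-modify-write.
    while tl.atomic_add(flag_ptr, 0, sem="acquire", scope="gpu") != expected:
        pass
```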

- [x] I am not making a trivial change, such as fixing a typo in a
comment.
- [x] I have written a PR description following these
  [rules](https://cbea.ms/git-commit/#why-not-how).

- [x] I have run `pre-commit run --from-ref origin/main --to-ref HEAD`.

- Select one of the following.
  - [x] I have added tests.
    - `/python/test` for end-to-end tests
plotfi authored Jan 24, 2025
1 parent ad16e3d commit 29912c0
Showing 4 changed files with 192 additions and 1 deletion.
29 changes: 29 additions & 0 deletions test/Conversion/atomic_ldst.mlir
@@ -0,0 +1,29 @@
// RUN: triton-opt %s -split-input-file --allocate-shared-memory --convert-triton-gpu-to-llvm=compute-capability=90 2>&1 | FileCheck %s --check-prefix=CHECK-TTG2NVGPU
// RUN: triton-opt %s -split-input-file --allocate-shared-memory --convert-triton-gpu-to-llvm=compute-capability=90 --convert-nv-gpu-to-llvm 2>&1 | FileCheck %s --check-prefix=CHECK-NVGPU2LLVM
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} {
tt.func public @kernel_r(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}) attributes {noinline = false} {
%cst = arith.constant 0.000000e+00 : f32
%true = arith.constant true
%c128_i32 = arith.constant 128 : i32
%c512_i32 = arith.constant 512 : i32
%0 = tt.get_program_id x : i32
%1 = arith.muli %0, %c128_i32 : i32
%2 = arith.cmpi slt, %1, %c512_i32 : i32

// CHECK-TTG2NVGPU: nvgpu.ld_acquire acquire, gpu
// CHECK-NVGPU2LLVM: ld.global.gpu.acquire.b32
%3 = tt.atomic_rmw fadd, acquire, gpu, %arg0, %cst, %2 : (!tt.ptr<f32>, f32, i1) -> f32
tt.store %arg0, %3 : !tt.ptr<f32>

// CHECK-TTG2NVGPU: nvgpu.ld_acquire acquire, cta
// CHECK-NVGPU2LLVM: ld.global.cta.acquire.b32
%4 = tt.atomic_rmw fadd, acquire, cta, %arg0, %cst, %true : (!tt.ptr<f32>, f32, i1) -> f32
tt.store %arg0, %4 : !tt.ptr<f32>

// CHECK-TTG2NVGPU: nvgpu.ld_acquire acquire, sys
// CHECK-NVGPU2LLVM: ld.global.sys.acquire.b32
%5 = tt.atomic_rmw fadd, acquire, sys, %arg0, %cst, %2 : (!tt.ptr<f32>, f32, i1) -> f32
tt.store %arg0, %5 : !tt.ptr<f32>
tt.return
}
}
38 changes: 38 additions & 0 deletions third_party/nvidia/include/Dialect/NVGPU/IR/NVGPUOps.td
@@ -32,6 +32,33 @@ include "NVGPUAttrDefs.td"
def LLVM_PointerGlobal : LLVM_PointerInAddressSpace<1>;
def LLVM_PointerShared : LLVM_PointerInAddressSpace<3>;


def NVGPU_Float : AnyTypeOf<[F8E4M3FN, F8E4M3FNUZ, F8E5M2, F8E5M2FNUZ, F16, BF16, F32, F64], "floating-point">;
def NVGPU_Int : AnyTypeOf<[I1, I8, I16, I32, I64], "integer">;
def NVGPU_ScalarLike : AnyTypeOf<[NVGPU_Float, NVGPU_Int]>;


def NVGPU_MemSemanticAttr : I32EnumAttr<
"MemSemantic", "",
[
I32EnumAttrCase<"RELAXED", 1, "relaxed">,
I32EnumAttrCase<"ACQUIRE", 2, "acquire">,
I32EnumAttrCase<"RELEASE", 3, "release">,
I32EnumAttrCase<"ACQUIRE_RELEASE", 4, "acq_rel">,
]> {
let cppNamespace = "::mlir::triton::nvgpu";
}

def NVGPU_MemSyncScopeAttr : I32EnumAttr<
"MemSyncScope", "",
[
I32EnumAttrCase<"GPU", 1, "gpu">,
I32EnumAttrCase<"CTA", 2, "cta">,
I32EnumAttrCase<"SYSTEM", 3, "sys">,
]> {
let cppNamespace = "::mlir::triton::nvgpu";
}

class NVGPU_Op<string mnemonic, list<Trait> traits = []> :
LLVM_OpBase<NVGPU_Dialect, mnemonic, traits>;

@@ -123,4 +150,15 @@ def NVGPU_ClusterCTAIdOp : NVGPU_Op<"cluster_id", [Pure]> {
let assemblyFormat = "attr-dict";
}

def NVGPU_LoadAcquireOp : NVGPU_Op<"ld_acquire", [MemoryEffects<[MemRead]>]> {
let arguments = (
ins LLVM_PointerGlobal:$addr,
Optional<I1>:$mask,
NVGPU_MemSemanticAttr:$sem,
NVGPU_MemSyncScopeAttr:$scope
);
let results = (outs NVGPU_ScalarLike:$result);
let assemblyFormat = "$sem `,` $scope `,` $addr (`,` $mask^)? attr-dict `:` functional-type($addr, $result)";
}

#endif
42 changes: 41 additions & 1 deletion third_party/nvidia/lib/NVGPUToLLVM/NVGPUToLLVMPass.cpp
@@ -370,6 +370,46 @@ class LoadMatrixOpPattern
}
};

class LoadAcquireOpPattern : public OpRewritePattern<ttn::LoadAcquireOp> {
public:
using OpRewritePattern<ttn::LoadAcquireOp>::OpRewritePattern;

LogicalResult matchAndRewrite(ttn::LoadAcquireOp op,
PatternRewriter &rewriter) const override {
auto loc = op->getLoc();
Type valueTy = op.getType();
const unsigned valueNBits = std::max(8u, valueTy.getIntOrFloatBitWidth());
const size_t maxWordWidth = std::max<size_t>(32, valueNBits);
const size_t width = std::min((size_t)valueNBits, maxWordWidth);

const std::string writeConstraint =
(width == 64) ? "=l" : ((width == 32) ? "=r" : "=c");
PTXBuilder ptxBuilder;
bool init = true;
auto *dstOpr = ptxBuilder.newOperand(writeConstraint, init); // "=l"/"=r"/"=c" result
auto *addrOpr =
ptxBuilder.newAddrOperand(op.getAddr(), "l", 0 /* in_off */);
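// Assemble e.g. "ld.global.gpu.acquire.b32" depending on the op's scope and
// memory semantics.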
auto &ld =
ptxBuilder.create<>("ld")
->global()
.o("cta", op.getScope() == triton::nvgpu::MemSyncScope::CTA)
.o("gpu", op.getScope() == triton::nvgpu::MemSyncScope::GPU)
.o("sys", op.getScope() == triton::nvgpu::MemSyncScope::SYSTEM)
.o("acquire", op.getSem() == triton::nvgpu::MemSemantic::ACQUIRE)
.o("relaxed", op.getSem() == triton::nvgpu::MemSemantic::RELAXED)
.b(width);
ld(dstOpr, addrOpr).maybePredicate(op.getMask(), "b");

// Create inline ASM signature
Type retTy = IntegerType::get(getContext(), width);
Value ret = ptxBuilder.launch(rewriter, loc, retTy);
ret = bitcast(ret, op.getType());

rewriter.replaceOp(op, {ret});
return success();
}
};

class WGMMAWaitGroupOpPattern : public OpRewritePattern<ttn::WGMMAWaitGroupOp> {
public:
using OpRewritePattern<ttn::WGMMAWaitGroupOp>::OpRewritePattern;
@@ -608,7 +648,7 @@ class ConvertNVGPUToLLVM : public ConvertNVGPUToLLVMBase<ConvertNVGPUToLLVM> {

patterns.add<FenceAsyncSharedOpPattern, LoadMatrixOpPattern,
StoreMatrixOpPattern, ClusterArriveOpPattern, WGMMAOpPattern,
WGMMAWaitGroupOpPattern>(context);
LoadAcquireOpPattern, WGMMAWaitGroupOpPattern>(context);

if (applyPatternsGreedily(mod, std::move(patterns)).failed())
signalPassFailure();
84 changes: 84 additions & 0 deletions third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp
@@ -1,3 +1,4 @@
#include "Dialect/NVGPU/IR/Dialect.h"
#include "TargetInfo.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/TypeUtilities.h"
@@ -24,6 +25,9 @@ using ::mlir::triton::gpu::getShapePerCTA;
using ::mlir::triton::gpu::getTotalElemsPerThread;
using ::mlir::triton::gpu::SharedEncodingAttr;

// Toggle this to disable the optimized ld.acquire lowering path used for
// cooperative grid launches.
static constexpr bool disableLDAcquireLowering = false;

namespace {

llvm::MapVector<StringAttr, int32_t> getAllFreeVarMasks(MLIRContext *ctx) {
@@ -696,6 +700,48 @@ struct AtomicRMWOpConversion
(elementType.isF16() || elementType.isBF16() || elementType.isF32());
}

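// An atomic add of constant zero reads memory without modifying it, so with
// acquire or relaxed semantics it can be lowered to a plain PTX ld at the
// matching scope.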
bool isPromotableToNVPTXLD(triton::AtomicRMWOp op) const {
if (disableLDAcquireLowering)
return false;

Type valueTy =
getTypeConverter()->convertType(getElementTypeOrSelf(op.getType()));

if (!valueTy.isIntOrFloat())
return false;
if (op.getSem() != triton::MemSemantic::ACQUIRE &&
op.getSem() != triton::MemSemantic::RELAXED)
return false;
if (op.getScope() != triton::MemSyncScope::CTA &&
op.getScope() != triton::MemSyncScope::GPU &&
op.getScope() != triton::MemSyncScope::SYSTEM)
return false;

if (op.getAtomicRmwOp() != RMWOp::ADD && op.getAtomicRmwOp() != RMWOp::FADD)
return false;
if (isa<RankedTensorType>(op.getType()))
return false;
if (!op.getVal().getDefiningOp())
return false;
if (!isa<arith::ConstantOp>(op.getVal().getDefiningOp()))
return false;

auto constOp = cast<arith::ConstantOp>(op.getVal().getDefiningOp());
if (!isa<FloatAttr>(constOp.getValueAttr()) &&
!isa<IntegerAttr>(constOp.getValueAttr()))
return false;

if (auto attr = dyn_cast_or_null<FloatAttr>(constOp.getValueAttr()))
if (!attr.getValue().isZero())
return false;

if (auto attr = dyn_cast_or_null<IntegerAttr>(constOp.getValueAttr()))
if (!attr.getValue().isZero())
return false;

return true;
}

LogicalResult
matchAndRewrite(triton::AtomicRMWOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
@@ -767,6 +813,17 @@ struct AtomicRMWOpConversion

auto packedTy = vec_ty(valueElemTy, packed);
SmallVector<Value> resultVals(elemsPerThread);

// Lower AtomicRMWOp to a ld.acquire if possible
std::unordered_map<triton::MemSyncScope, triton::nvgpu::MemSyncScope>
ScopeMap = {
{triton::MemSyncScope::CTA, triton::nvgpu::MemSyncScope::CTA},
{triton::MemSyncScope::GPU, triton::nvgpu::MemSyncScope::GPU},
{triton::MemSyncScope::SYSTEM,
triton::nvgpu::MemSyncScope::SYSTEM}};
const bool doPTXLDPromotion = isPromotableToNVPTXLD(op) && vec == 1 &&
packed == 1 && ScopeMap.count(op.getScope());

for (size_t i = 0; i < elemsPerThread; i += vec * packed) {
if (auto canonicalStart = getCanonicalIndex(i, regMask);
canonicalStart != i) {
@@ -780,6 +837,33 @@
Value rmwPtr = ptrElements[i];
Value pred = llMask ? maybeAnd(rewriter, loc, threadPred, maskElements[i])
: threadPred;

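// The promoted path only applies to scalars (vec == 1, packed == 1), so this
// loop body executes at most once.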
if (doPTXLDPromotion) {
Type convertedValueTy =
getTypeConverter()->convertType(getElementTypeOrSelf(op.getType()));
auto loadAcquireOp = rewriter.create<triton::nvgpu::LoadAcquireOp>(
op.getLoc(), convertedValueTy, rmwPtr, pred,
op.getSem() == triton::MemSemantic::ACQUIRE
? triton::nvgpu::MemSemantic::ACQUIRE
: triton::nvgpu::MemSemantic::RELAXED,
ScopeMap[op.getScope()]);

if (!atomicNeedsSharedMemory(op.getResult())) {
rewriter.eraseOp(op);
return success();
}
Value atomPtr = LLVM::getSharedMemoryBase(loc, rewriter, targetInfo,
op.getOperation());
atomPtr = bitcast(atomPtr, ptr_ty(ctx, 3));
// Only threads with pred == true store the result
targetInfo.storeShared(rewriter, loc, atomPtr, loadAcquireOp, pred);
createBarrier(rewriter, loc, numCTAs);
Value ret = load(valueElemTy, atomPtr);
rewriter.replaceOp(op, {ret});
continue;
}

std::string sTy;
PTXBuilder ptxBuilderAtomicRMW;
// 16-bit -> "h", 32-bit -> "r", 64-bit -> "l"
