[AMD] Count llvm instruction during conversion for scheduling hints (#4819)

Advanced software pipelining may require fine-grained adjustments
regarding instruction scheduling in the main `tt.dot` loop to achieve
higher performance. Such adjustments require detailed information
regarding the number of issued `v_mfma`, `ds_read`, `ds_write`, and
`global_load` instructions. This PR extends the Triton AMDGPU backend
by adding instruction counting during `TritonAMDGPUToLLVM` pass
execution.
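
To make the counting concrete, here is a small self-contained sketch (illustration only, not code from this PR): while vectorized LLVM stores are emitted for a `local_store`, a `(count, type)` pair is bumped once per issued store, mirroring the `llvmOpCount` out-parameter that the diffs below thread through `storeDistributedToShared`. The `VecType` alias and the `emitVectorStores` helper are hypothetical stand-ins for `mlir::Type` and the store-emission lambda.

```cpp
#include <cstddef>
#include <iostream>
#include <string>
#include <utility>

// Hypothetical stand-ins for mlir::Type and the store-emission lambda.
using VecType = std::string;

void emitVectorStores(size_t numVectors, const VecType &vecTy,
                      std::pair<size_t, VecType> *llvmOpCount) {
  for (size_t i = 0; i < numVectors; ++i) {
    // ... one vectorized LLVM store would be emitted here ...
    if (llvmOpCount) {
      ++llvmOpCount->first;        // one more issued store instruction
      llvmOpCount->second = vecTy; // remember the vector type being stored
    }
  }
}

int main() {
  std::pair<size_t, VecType> dsWrites{0, ""};
  emitVectorStores(/*numVectors=*/8, "vector<8xf16>", &dsWrites);
  std::cout << dsWrites.first << " x " << dsWrites.second << "\n"; // 8 x vector<8xf16>
  return 0;
}
```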

An example of instruction counting and instruction scheduling is
demonstrated in the `createCKV3Schedule` method, which implements [CK's V3
software
pipelining](https://github.com/ROCm/composable_kernel/blob/de3e3b642402eac5b4a466f6a2fa5e9f022ba680/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3.hpp#L160-L263).
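
For intuition only, the sketch below shows how such per-loop counts could be turned into an interleaving budget, e.g. how many `v_mfma` ops to place between consecutive memory ops of each kind. This is a simplified illustration rather than the actual `createCKV3Schedule` logic; the counts are simply those expected by the lit test added in this PR.

```cpp
#include <algorithm>
#include <cstddef>
#include <iostream>

// Per-iteration instruction counts of the main dot loop (A + B operands),
// taken from the expectations in test/TritonGPU/amd/amd-instruction-sched.mlir.
struct LoopInstCounts {
  size_t numMMAs;
  size_t numDsReads;
  size_t numDsWrites;
  size_t numGlobalLoads;
};

// Naive spacing rule: distribute each kind of memory op evenly across the
// v_mfma sequence (integer division, at least one mfma between two mem ops).
size_t mfmasPerMemOp(size_t numMMAs, size_t numMemOps) {
  return std::max<size_t>(1, numMMAs / std::max<size_t>(1, numMemOps));
}

int main() {
  LoopInstCounts c{/*numMMAs=*/64, /*numDsReads=*/16 + 8,
                   /*numDsWrites=*/8 + 4, /*numGlobalLoads=*/8 + 4};

  std::cout << "mfmas per ds_read:     " << mfmasPerMemOp(c.numMMAs, c.numDsReads) << "\n"
            << "mfmas per ds_write:    " << mfmasPerMemOp(c.numMMAs, c.numDsWrites) << "\n"
            << "mfmas per global_load: " << mfmasPerMemOp(c.numMMAs, c.numGlobalLoads) << "\n";
  return 0;
}
```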

This change is experimental and aims at better GEMM performance. The design
is not final and may be subject to change in the future.
ravil-mobile authored Oct 13, 2024
1 parent 8966e5c commit e87f877
Showing 21 changed files with 715 additions and 63 deletions.
2 changes: 2 additions & 0 deletions bin/RegisterTritonDialects.h
@@ -63,6 +63,8 @@ inline void registerTritonDialects(mlir::DialectRegistry &registry) {
mlir::registerTritonAMDGPUStreamPipeline();
mlir::registerTritonAMDGPUStreamPipelineV2();
mlir::registerTritonAMDGPUCanonicalizePointers();
mlir::triton::registerTritonAMDGPUInsertInstructionSchedHints();
mlir::triton::registerTritonAMDGPULowerInstructionSchedHints();

// TODO: register Triton & TritonGPU passes
registry.insert<mlir::triton::TritonDialect, mlir::cf::ControlFlowDialect,
@@ -27,15 +27,33 @@ constexpr int patternBenefitPrioritizeOverLLVMConversions = 10;
constexpr int patternBenefitClampOptimizedPattern = 20;
constexpr int patternBenefitConvertLayoutOptimizedPattern = 20;

struct BackendCallbacks {
/**
* A backend-specific callback for appending auxiliary data during
* `LocalStoreOp` conversion.
*
* @param[in] op The rewritten `LocalStoreOp`.
* @param[in] llvmOpCount The number of issued LLVM instructions.
* @param[in] llvmOpType The input type of the issued LLVM instructions.
*/
std::function<void(triton::gpu::LocalStoreOp op, size_t llvmOpCount,
Type llvmOpType)>
localStoreOpConversion = nullptr;
};

void populateElementwiseOpToLLVMPatterns(
LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
ModuleAxisInfoAnalysis &axisInfoAnalysis, const TargetInfoBase &targetInfo,
PatternBenefit benefit);

void populateMemoryOpToLLVMPattern(LLVMTypeConverter &typeConverter,
const TargetInfoBase &targetInfo,
RewritePatternSet &patterns,
PatternBenefit benefit);
// The given callback is invoked at the end of a successful rewrite. The
// callback receives 1) the current source op, 2) the number of issued LLVM
// instructions and 3) their input types. Each MLIR backend can provide a
// callback and, thus, handle backend-specific behaviors.
void populateMemoryOpToLLVMPattern(
LLVMTypeConverter &typeConverter, const TargetInfoBase &targetInfo,
RewritePatternSet &patterns, PatternBenefit benefit,
std::optional<BackendCallbacks> backendCallbacks = std::nullopt);

void populateAssertOpToLLVMPattern(LLVMTypeConverter &typeConverter,
RewritePatternSet &patterns,
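
The `BackendCallbacks` hook declared in this header might be wired up roughly as follows. This is only a sketch: the wrapper function and accumulator struct are hypothetical, the namespace qualifications are assumed from the surrounding code, and the relevant Triton/MLIR headers are assumed to be included.

```cpp
// Hypothetical accumulator owned by the backend pass; it must outlive the
// application of the patterns, since the callback fires during conversion.
struct LocalStoreStats {
  size_t llvmStoreCount = 0; // LLVM stores issued for all local_store ops
  mlir::Type llvmStoreType;  // vector type of those stores
};

// Hypothetical helper: registers the memory-op patterns with a counting
// callback, using the extended populateMemoryOpToLLVMPattern signature above.
void addMemoryOpPatternsWithCounting(mlir::LLVMTypeConverter &typeConverter,
                                     const mlir::triton::TargetInfoBase &targetInfo,
                                     mlir::RewritePatternSet &patterns,
                                     mlir::PatternBenefit benefit,
                                     LocalStoreStats &stats) {
  mlir::triton::BackendCallbacks callbacks;
  callbacks.localStoreOpConversion =
      [&stats](mlir::triton::gpu::LocalStoreOp /*op*/, size_t llvmOpCount,
               mlir::Type llvmOpType) {
        stats.llvmStoreCount += llvmOpCount; // accumulate issued stores
        stats.llvmStoreType = llvmOpType;    // record their vector type
      };

  mlir::triton::populateMemoryOpToLLVMPattern(typeConverter, targetInfo,
                                              patterns, benefit, callbacks);
}
```
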
10 changes: 5 additions & 5 deletions include/triton/Conversion/TritonGPUToLLVM/Utility.h
@@ -1366,11 +1366,11 @@ SmallVector<Value> loadSharedToDistributed(RankedTensorType dstTy,
Location loc, RewriterBase &rewriter,
const TargetInfoBase &target);

void storeDistributedToShared(MemDescType dstTy, RankedTensorType srcTy,
Type elemLlvmTy, ArrayRef<Value> srcVals,
Value smemBase, ArrayRef<Value> dstStrides,
Location loc, RewriterBase &rewriter,
const TargetInfoBase &target);
void storeDistributedToShared(
MemDescType dstTy, RankedTensorType srcTy, Type elemLlvmTy,
ArrayRef<Value> srcVals, Value smemBase, ArrayRef<Value> dstStrides,
Location loc, RewriterBase &rewriter, const TargetInfoBase &target,
std::pair<size_t, Type> *const llvmOpCount = nullptr);

inline Value getStructFromSharedMemoryObject(Location loc,
const SharedMemoryObject &smemObj,
36 changes: 25 additions & 11 deletions lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp
@@ -15,12 +15,11 @@ using namespace mlir::triton::gpu;
// blocked -> shared.
// Swizzling in shared memory to avoid bank conflict. Normally used for
// A/B operands of dots.
void lowerDistributedToShared(Location loc, Value src, Value dst,
Value adaptorSrc,
const SharedMemoryObject &smemObj,
const LLVMTypeConverter *typeConverter,
ConversionPatternRewriter &rewriter,
const TargetInfoBase &targetInfo) {
void lowerDistributedToShared(
Location loc, Value src, Value dst, Value adaptorSrc,
const SharedMemoryObject &smemObj, const LLVMTypeConverter *typeConverter,
ConversionPatternRewriter &rewriter, const TargetInfoBase &targetInfo,
std::pair<size_t, Type> *const llvmOpCount = nullptr) {
auto srcTy = cast<RankedTensorType>(src.getType());
auto dstTy = cast<MemDescType>(dst.getType());
auto outOrd = mlir::cast<SharedEncodingAttr>(dstTy.getEncoding()).getOrder();
@@ -33,7 +32,7 @@ void lowerDistributedToShared(Location loc, Value src, Value dst,
auto dstStrides = smemObj.getStrides();
auto inVals = unpackLLElements(loc, adaptorSrc, rewriter);
storeDistributedToShared(dstTy, srcTy, elemTy, inVals, smemBase, dstStrides,
loc, rewriter, targetInfo);
loc, rewriter, targetInfo, llvmOpCount);
}

struct LocalAllocOpConversion
@@ -185,12 +184,15 @@ struct LocalStoreOpConversion
public:
using ConvertOpToLLVMPattern<
triton::gpu::LocalStoreOp>::ConvertOpToLLVMPattern;
using BackendCallbackType =
decltype(BackendCallbacks::localStoreOpConversion);

LocalStoreOpConversion(const LLVMTypeConverter &converter,
const TargetInfoBase &targetInfo,
BackendCallbackType backendCallback,
PatternBenefit benefit = 1)
: ConvertOpToLLVMPattern<triton::gpu::LocalStoreOp>(converter, benefit),
targetInfo(targetInfo) {}
targetInfo(targetInfo), backendCallback(backendCallback) {}

LogicalResult
matchAndRewrite(triton::gpu::LocalStoreOp op, OpAdaptor adaptor,
@@ -200,24 +202,36 @@ struct LocalStoreOpConversion
getTypeConverter()->convertType(op.getDst().getType().getElementType());
auto smemObj = LLVM::getSharedMemoryObjectFromStruct(
op.getLoc(), adaptor.getDst(), llvmElemTy, rewriter);

std::pair<size_t, Type> llvmOpCount;
lowerDistributedToShared(op.getLoc(), op.getSrc(), op.getDst(),
adaptor.getSrc(), smemObj, getTypeConverter(),
rewriter, targetInfo);
rewriter, targetInfo, &llvmOpCount);

if (backendCallback)
(backendCallback)(op, llvmOpCount.first, llvmOpCount.second);

rewriter.eraseOp(op);
return success();
}

private:
const TargetInfoBase &targetInfo;
BackendCallbackType backendCallback;
};

} // namespace

void mlir::triton::populateMemoryOpToLLVMPattern(
LLVMTypeConverter &typeConverter, const TargetInfoBase &targetInfo,
RewritePatternSet &patterns, PatternBenefit benefit) {
RewritePatternSet &patterns, PatternBenefit benefit,
std::optional<BackendCallbacks> backendCallbacks) {
patterns.add<LocalAllocOpConversion>(typeConverter, targetInfo, benefit);
patterns.add<LocalDeallocOpConversion>(typeConverter, benefit);
patterns.add<LocalLoadOpConversion>(typeConverter, targetInfo, benefit);
patterns.add<LocalStoreOpConversion>(typeConverter, targetInfo, benefit);

auto backendCall =
backendCallbacks ? backendCallbacks->localStoreOpConversion : nullptr;
patterns.add<LocalStoreOpConversion>(typeConverter, targetInfo, backendCall,
benefit);
}
8 changes: 7 additions & 1 deletion lib/Conversion/TritonGPUToLLVM/Utility.cpp
@@ -404,7 +404,8 @@ void storeDistributedToShared(MemDescType dstTy, RankedTensorType srcTy,
Type elemLlvmTy, ArrayRef<Value> srcVals,
Value smemBase, ArrayRef<Value> dstStrides,
Location loc, RewriterBase &rewriter,
const TargetInfoBase &target) {
const TargetInfoBase &target,
std::pair<size_t, Type> *const llvmOpCount) {
bool success = emitTransferBetweenRegistersAndShared(
srcTy, dstTy, elemLlvmTy, /*maxVecElems=*/std::nullopt, smemBase,
dstStrides, loc, rewriter, target, [&](VectorType vecTy, Value vecAddr) {
@@ -418,7 +419,12 @@ void storeDistributedToShared(MemDescType dstTy, RankedTensorType srcTy,
store(vec, vecAddr)
.setAlignment(vecTy.getNumElements() *
elemLlvmTy.getIntOrFloatBitWidth() / 8);
if (llvmOpCount) {
++(llvmOpCount->first);
llvmOpCount->second = vecTy;
}
});

if (!success)
llvm::report_fatal_error("Failed to emit transfer from register to shared");
}
148 changes: 148 additions & 0 deletions test/TritonGPU/amd/amd-instruction-sched.mlir
@@ -0,0 +1,148 @@
// RUN: triton-opt %s -split-input-file -triton-amdgpu-insert-instruction-sched-hints -allocate-shared-memory -convert-scf-to-cf -convert-triton-amdgpu-to-llvm=arch=gfx942 | FileCheck %s -check-prefix=INSTR_INSERTION
// RUN: triton-opt %s -split-input-file -triton-amdgpu-insert-instruction-sched-hints -allocate-shared-memory -convert-scf-to-cf -convert-triton-amdgpu-to-llvm=arch=gfx942 -triton-amdgpu-lower-insert-instruction-sched-hints=variant="iglp0" | FileCheck %s -check-prefix=LOWER_IGLP0

#shared0_ex0 = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], hasLeadingOffset = false}>
#mma0_ex0 = #triton_gpu.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 2], instrShape = [32, 32], isTransposed = false}>

#blocked0_ex1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
#blocked1_ex1 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 4], order = [0, 1]}>
#blocked2_ex1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 16], warpsPerCTA = [4, 1], order = [1, 0]}>
#shared0_ex1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = false}>
#shared1_ex1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = false}>
#mma0_ex1 = #triton_gpu.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 2], instrShape = [32, 32], isTransposed = false}>
#dot0_ex1 = #triton_gpu.dot_op<{opIdx = 0, parent = #mma0_ex1, kWidth = 8}>
#dot1_ex1 = #triton_gpu.dot_op<{opIdx = 1, parent = #mma0_ex1, kWidth = 8}>

module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} {
// LOWER_IGLP0-LABEL: test_instruction_hints_lowering
tt.func @test_instruction_hints_lowering(
%arg0: tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0_ex0, kWidth = 16}>>,
%arg1: tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0_ex0, kWidth = 16}>>,
%arg2: tensor<32x32xf16, #mma0_ex0>) {

%c0_i32 = arith.constant 0 : i32
%c1_i32 = arith.constant 1 : i32
%c64_i32 = arith.constant 1 : i32

scf.for %arg11 = %c0_i32 to %c64_i32 step %c1_i32 iter_args() -> () : i32 {
// LOWER_IGLP0: llvm.add
// LOWER_IGLP0-NEXT: %[[OPT_LEVEL:.*]] = llvm.mlir.constant(0 : i32) : i32
// LOWER_IGLP0-NEXT: llvm.call_intrinsic "llvm.amdgcn.iglp.opt"(%[[OPT_LEVEL]]) : (i32) -> ()
%0 = tt.dot %arg0, %arg1, %arg2, inputPrecision = ieee : tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0_ex0, kWidth = 16}>> * tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0_ex0, kWidth = 16}>> -> tensor<32x32xf16, #mma0_ex0>
scf.yield
}
tt.return
}

// INSTR_INSERTION-LABEL: @test_llvm_instruction_count
tt.func public @test_llvm_instruction_count(
%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32},
%arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}
) attributes {noinline = false} {

%cst = arith.constant dense<64> : tensor<256x64xi32, #blocked0_ex1>
%cst_0 = arith.constant dense<64> : tensor<64x128xi32, #blocked1_ex1>
%c0_i32 = arith.constant 0 : i32
%c1_i32 = arith.constant 1 : i32
%c64_i32 = arith.constant 64 : i32
%c63_i32 = arith.constant 63 : i32
%c128_i32 = arith.constant 128 : i32
%c256_i32 = arith.constant 256 : i32

%19 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #blocked0_ex1}>>
%20 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2_ex1}>>
%21 = tt.splat %c256_i32 : i32 -> tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #blocked0_ex1}>>
%22 = tt.splat %c256_i32 : i32 -> tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2_ex1}>>
%23 = arith.addi %21, %19 : tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #blocked0_ex1}>>
%24 = arith.addi %22, %20 : tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2_ex1}>>

%26 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1_ex1}>>
%27 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2_ex1}>>
%28 = tt.splat %c128_i32 : i32 -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1_ex1}>>
%29 = tt.splat %c128_i32 : i32 -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2_ex1}>>
%30 = arith.addi %28, %26 : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1_ex1}>>
%31 = arith.addi %29, %27 : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2_ex1}>>
%32 = tt.expand_dims %23 {axis = 1 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #blocked0_ex1}>> -> tensor<256x1xi32, #blocked0_ex1>
%33 = tt.expand_dims %24 {axis = 1 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2_ex1}>> -> tensor<256x1xi32, #blocked2_ex1>
%34 = tt.splat %c64_i32 : i32 -> tensor<256x1xi32, #blocked0_ex1>
%35 = arith.muli %32, %34 : tensor<256x1xi32, #blocked0_ex1>
%36 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<256x1x!tt.ptr<f16>, #blocked0_ex1>
%37 = tt.addptr %36, %35 : tensor<256x1x!tt.ptr<f16>, #blocked0_ex1>, tensor<256x1xi32, #blocked0_ex1>
%38 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked0_ex1}>>
%39 = tt.expand_dims %38 {axis = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked0_ex1}>> -> tensor<1x64xi32, #blocked0_ex1>
%40 = tt.broadcast %37 : tensor<256x1x!tt.ptr<f16>, #blocked0_ex1> -> tensor<256x64x!tt.ptr<f16>, #blocked0_ex1>
%41 = tt.broadcast %39 : tensor<1x64xi32, #blocked0_ex1> -> tensor<256x64xi32, #blocked0_ex1>
%42 = tt.addptr %40, %41 : tensor<256x64x!tt.ptr<f16>, #blocked0_ex1>, tensor<256x64xi32, #blocked0_ex1>

%43 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1_ex1}>>
%44 = tt.expand_dims %43 {axis = 1 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1_ex1}>> -> tensor<64x1xi32, #blocked1_ex1>
%45 = tt.splat %arg1 : !tt.ptr<f16> -> tensor<64x1x!tt.ptr<f16>, #blocked1_ex1>
%46 = tt.addptr %45, %44 : tensor<64x1x!tt.ptr<f16>, #blocked1_ex1>, tensor<64x1xi32, #blocked1_ex1>
%47 = tt.expand_dims %30 {axis = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1_ex1}>> -> tensor<1x128xi32, #blocked1_ex1>
%48 = tt.expand_dims %31 {axis = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2_ex1}>> -> tensor<1x128xi32, #blocked2_ex1>
%49 = tt.splat %c64_i32 : i32 -> tensor<1x128xi32, #blocked1_ex1>
%50 = arith.muli %47, %49 : tensor<1x128xi32, #blocked1_ex1>
%51 = tt.broadcast %46 : tensor<64x1x!tt.ptr<f16>, #blocked1_ex1> -> tensor<64x128x!tt.ptr<f16>, #blocked1_ex1>
%52 = tt.broadcast %50 : tensor<1x128xi32, #blocked1_ex1> -> tensor<64x128xi32, #blocked1_ex1>
%53 = tt.addptr %51, %52 : tensor<64x128x!tt.ptr<f16>, #blocked1_ex1>, tensor<64x128xi32, #blocked1_ex1>

%56 = triton_gpu.local_alloc : () -> !tt.memdesc<1x256x64xf16, #shared0_ex1, #triton_gpu.shared_memory, mutable>
%57 = triton_gpu.local_alloc : () -> !tt.memdesc<1x64x128xf16, #shared1_ex1, #triton_gpu.shared_memory, mutable>

%cst_1 = arith.constant dense<0.000000e+00> : tensor<256x128xf32, #mma0_ex1>

%cc0_i1 = arith.constant 1 : i1
%59 = tt.splat %cc0_i1 : i1 -> tensor<256x64xi1, #blocked0_ex1>
%60 = tt.load %42, %59 : tensor<256x64x!tt.ptr<f16>, #blocked0_ex1>
%61 = tt.splat %cc0_i1 : i1 -> tensor<64x128xi1, #blocked1_ex1>
%62 = tt.load %53, %61 : tensor<64x128x!tt.ptr<f16>, #blocked1_ex1>

%63 = triton_gpu.memdesc_subview %56[%c0_i32, %c0_i32, %c0_i32] : !tt.memdesc<1x256x64xf16, #shared0_ex1, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<256x64xf16, #shared0_ex1, #triton_gpu.shared_memory, mutable>
triton_gpu.local_store %60, %63 : tensor<256x64xf16, #blocked0_ex1> -> !tt.memdesc<256x64xf16, #shared0_ex1, #triton_gpu.shared_memory, mutable>
%64 = triton_gpu.memdesc_subview %57[%c0_i32, %c0_i32, %c0_i32] : !tt.memdesc<1x64x128xf16, #shared1_ex1, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<64x128xf16, #shared1_ex1, #triton_gpu.shared_memory, mutable>
triton_gpu.local_store %62, %64 : tensor<64x128xf16, #blocked1_ex1> -> !tt.memdesc<64x128xf16, #shared1_ex1, #triton_gpu.shared_memory, mutable>

%66:5 = scf.for %arg11 = %c0_i32 to %c63_i32 step %c1_i32 iter_args(
%arg12 = %cst_1,
%arg13 = %42,
%arg14 = %53,
%arg16 = %63,
%arg17 = %64) -> (
tensor<256x128xf32, #mma0_ex1>,
tensor<256x64x!tt.ptr<f16>, #blocked0_ex1>,
tensor<64x128x!tt.ptr<f16>, #blocked1_ex1>,
!tt.memdesc<256x64xf16, #shared0_ex1, #triton_gpu.shared_memory, mutable>,
!tt.memdesc<64x128xf16, #shared1_ex1, #triton_gpu.shared_memory, mutable>) : i32 {

%82 = triton_gpu.local_load %arg16 : !tt.memdesc<256x64xf16, #shared0_ex1, #triton_gpu.shared_memory, mutable> -> tensor<256x64xf16, #dot0_ex1>
%83 = triton_gpu.local_load %arg17 : !tt.memdesc<64x128xf16, #shared1_ex1, #triton_gpu.shared_memory, mutable> -> tensor<64x128xf16, #dot1_ex1>

// INSTR_INSERTION: amdgpu.instruction_sched_hint
// INSTR_INSERTION-SAME: numDsReadsA = #amdgpu.InstCounter<16, vector<8xf16>>
// INSTR_INSERTION-SAME: numDsReadsB = #amdgpu.InstCounter<8, vector<8xf16>>
// INSTR_INSERTION-SAME: numDsWritesA = #amdgpu.InstCounter<8, vector<8xf16>>
// INSTR_INSERTION-SAME: numDsWritesB = #amdgpu.InstCounter<4, vector<8xf16>>
// INSTR_INSERTION-SAME: numGlobalLoadsA = #amdgpu.InstCounter<8, vector<8xf16>>
// INSTR_INSERTION-SAME: numGlobalLoadsB = #amdgpu.InstCounter<4, vector<8xf16>>
// INSTR_INSERTION-SAME: numMMAs = #amdgpu.InstCounter<64, tensor<32x32x8xf16>>

%84 = tt.dot %82, %83, %arg12 : tensor<256x64xf16, #dot0_ex1> * tensor<64x128xf16, #dot1_ex1> -> tensor<256x128xf32, #mma0_ex1>
%85 = tt.addptr %arg13, %cst : tensor<256x64x!tt.ptr<f16>, #blocked0_ex1>, tensor<256x64xi32, #blocked0_ex1>
%86 = tt.addptr %arg14, %cst_0 : tensor<64x128x!tt.ptr<f16>, #blocked1_ex1>, tensor<64x128xi32, #blocked1_ex1>
%87 = tt.load %85 : tensor<256x64x!tt.ptr<f16>, #blocked0_ex1>
%88 = tt.load %86 : tensor<64x128x!tt.ptr<f16>, #blocked1_ex1>
%89 = triton_gpu.memdesc_subview %56[%c0_i32, %c0_i32, %c0_i32] : !tt.memdesc<1x256x64xf16, #shared0_ex1, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<256x64xf16, #shared0_ex1, #triton_gpu.shared_memory, mutable>
triton_gpu.local_store %87, %89 : tensor<256x64xf16, #blocked0_ex1> -> !tt.memdesc<256x64xf16, #shared0_ex1, #triton_gpu.shared_memory, mutable>
%90 = triton_gpu.memdesc_subview %57[%c0_i32, %c0_i32, %c0_i32] : !tt.memdesc<1x64x128xf16, #shared1_ex1, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<64x128xf16, #shared1_ex1, #triton_gpu.shared_memory, mutable>
triton_gpu.local_store %88, %90 : tensor<64x128xf16, #blocked1_ex1> -> !tt.memdesc<64x128xf16, #shared1_ex1, #triton_gpu.shared_memory, mutable>

scf.yield %84, %85, %86, %89, %90 :
tensor<256x128xf32, #mma0_ex1>,
tensor<256x64x!tt.ptr<f16>, #blocked0_ex1>,
tensor<64x128x!tt.ptr<f16>, #blocked1_ex1>,
!tt.memdesc<256x64xf16, #shared0_ex1, #triton_gpu.shared_memory, mutable>,
!tt.memdesc<64x128xf16, #shared1_ex1, #triton_gpu.shared_memory, mutable>
}
tt.return
}
}