Skip to content

Commit

Permalink
[AMD] Attach variant to the scheduling hint op (#5808)
Browse files Browse the repository at this point in the history
This PR refactors the implementation of instruction scheduling
infrastructure. A particular "sched" variant becomes a part of the
instruction and gets added during the insertion pass. The instruction
carries this meta-information across many passes, allowing the same
mechanism to be reused in other places.
  • Loading branch information
ravil-mobile authored Feb 6, 2025
1 parent 87187d1 commit 94643b2
Show file tree
Hide file tree
Showing 13 changed files with 105 additions and 79 deletions.
10 changes: 5 additions & 5 deletions test/TritonGPU/amd/amd-instruction-sched.mlir
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
// RUN: triton-opt %s -split-input-file -triton-amdgpu-insert-instruction-sched-hints -triton-amdgpu-lower-insert-instruction-sched-hints='variant=llvm_iglp_0' -verify-diagnostics | FileCheck %s -check-prefix=INSERT_IGLP0
// RUN: triton-opt %s -split-input-file -triton-amdgpu-insert-instruction-sched-hints -triton-amdgpu-lower-insert-instruction-sched-hints='variant=llvm_iglp_1' -verify-diagnostics | FileCheck %s -check-prefix=INSERT_IGLP1
// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu='target=hip:gfx942 num-ctas=1 num-warps=4 threads-per-warp=64' -tritongpu-coalesce -tritonamdgpu-accelerate-matmul='arch-generation-name=gfx942 matrix-instruction-size=32 kPack=1' -tritongpu-remove-layout-conversions -tritonamdgpu-stream-pipeline='num_stages=1' -triton-amdgpu-insert-instruction-sched-hints -decompose-unsupported-amd-conversions -optimize-amd-lds-usage='target-arch=gfx942' -convert-scf-to-cf -convert-index-to-llvm -allocate-shared-memory -convert-triton-amdgpu-to-llvm='arch=gfx942' -verify-diagnostics | FileCheck %s -check-prefix=INSTR_COUNT_NS1
// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu='target=hip:gfx942 num-ctas=1 num-warps=4 threads-per-warp=64' -tritongpu-coalesce -tritonamdgpu-accelerate-matmul='arch-generation-name=gfx942 matrix-instruction-size=32 kPack=1' -tritongpu-remove-layout-conversions -tritonamdgpu-stream-pipeline='num_stages=2' -triton-amdgpu-insert-instruction-sched-hints -decompose-unsupported-amd-conversions -optimize-amd-lds-usage='target-arch=gfx942' -convert-scf-to-cf -convert-index-to-llvm -allocate-shared-memory -convert-triton-amdgpu-to-llvm='arch=gfx942' -verify-diagnostics | FileCheck %s -check-prefix=INSTR_COUNT_NS2
// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu='target=hip:gfx942 num-ctas=1 num-warps=4 threads-per-warp=64' -tritongpu-coalesce -tritonamdgpu-accelerate-matmul='arch-generation-name=gfx942 matrix-instruction-size=16 kPack=1' -tritongpu-remove-layout-conversions -tritonamdgpu-stream-pipeline='num_stages=2' -triton-amdgpu-insert-instruction-sched-hints -decompose-unsupported-amd-conversions -optimize-amd-lds-usage='target-arch=gfx942' -convert-scf-to-cf -convert-index-to-llvm -allocate-shared-memory -convert-triton-amdgpu-to-llvm='arch=gfx942' -triton-amdgpu-lower-insert-instruction-sched-hints='variant=local_prefetch arch=gfx942 num_stages=2' -debug-only='lower-insert-instruction-sched-hints' -verify-diagnostics 2>&1 | FileCheck %s -check-prefix=USE_LOCAL_PREFETCH_GLOBAL_LOAD
// RUN: triton-opt %s -split-input-file -triton-amdgpu-insert-instruction-sched-hints='variant=llvm_iglp_0' -triton-amdgpu-lower-insert-instruction-sched-hints -verify-diagnostics | FileCheck %s -check-prefix=INSERT_IGLP0
// RUN: triton-opt %s -split-input-file -triton-amdgpu-insert-instruction-sched-hints='variant=llvm_iglp_1' -triton-amdgpu-lower-insert-instruction-sched-hints -verify-diagnostics | FileCheck %s -check-prefix=INSERT_IGLP1
// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu='target=hip:gfx942 num-ctas=1 num-warps=4 threads-per-warp=64' -tritongpu-coalesce -tritonamdgpu-accelerate-matmul='arch-generation-name=gfx942 matrix-instruction-size=32 kPack=1' -tritongpu-remove-layout-conversions -tritonamdgpu-stream-pipeline='num_stages=1' -triton-amdgpu-insert-instruction-sched-hints='variant=local_prefetch' -decompose-unsupported-amd-conversions -optimize-amd-lds-usage='target-arch=gfx942' -convert-scf-to-cf -convert-index-to-llvm -allocate-shared-memory -convert-triton-amdgpu-to-llvm='arch=gfx942' -verify-diagnostics | FileCheck %s -check-prefix=INSTR_COUNT_NS1
// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu='target=hip:gfx942 num-ctas=1 num-warps=4 threads-per-warp=64' -tritongpu-coalesce -tritonamdgpu-accelerate-matmul='arch-generation-name=gfx942 matrix-instruction-size=32 kPack=1' -tritongpu-remove-layout-conversions -tritonamdgpu-stream-pipeline='num_stages=2' -triton-amdgpu-insert-instruction-sched-hints='variant=local_prefetch' -decompose-unsupported-amd-conversions -optimize-amd-lds-usage='target-arch=gfx942' -convert-scf-to-cf -convert-index-to-llvm -allocate-shared-memory -convert-triton-amdgpu-to-llvm='arch=gfx942' -verify-diagnostics | FileCheck %s -check-prefix=INSTR_COUNT_NS2
// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu='target=hip:gfx942 num-ctas=1 num-warps=4 threads-per-warp=64' -tritongpu-coalesce -tritonamdgpu-accelerate-matmul='arch-generation-name=gfx942 matrix-instruction-size=16 kPack=1' -tritongpu-remove-layout-conversions -tritonamdgpu-stream-pipeline='num_stages=2' -triton-amdgpu-insert-instruction-sched-hints='variant=local_prefetch' -decompose-unsupported-amd-conversions -optimize-amd-lds-usage='target-arch=gfx942' -convert-scf-to-cf -convert-index-to-llvm -allocate-shared-memory -convert-triton-amdgpu-to-llvm='arch=gfx942' -triton-amdgpu-lower-insert-instruction-sched-hints='arch=gfx942 num_stages=2' -debug-only='lower-insert-instruction-sched-hints' -verify-diagnostics 2>&1 | FileCheck %s -check-prefix=USE_LOCAL_PREFETCH_GLOBAL_LOAD
// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu='target=hip:gfx942 num-ctas=1 num-warps=4 threads-per-warp=64' -tritongpu-coalesce -tritongpu-remove-layout-conversions -tritonamdgpu-stream-pipeline='num_stages=1' | FileCheck %s -check-prefix=LABELING_PS_1
// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu='target=hip:gfx942 num-ctas=1 num-warps=4 threads-per-warp=64' -tritongpu-coalesce -tritongpu-remove-layout-conversions -tritonamdgpu-stream-pipeline='num_stages=2' | FileCheck %s -check-prefix=LABELING_PS_2

Expand Down
5 changes: 2 additions & 3 deletions third_party/amd/backend/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ def make_ttgir(mod, metadata, options):
amd.passes.ttgpuir.add_stream_pipeline(pm, options.num_stages, stream_prefetch)
passes.common.add_canonicalizer(pm)
if options.instruction_sched_variant.lower() != "none":
amd.passes.ttgpuir.insert_instruction_sched_hints(pm)
amd.passes.ttgpuir.insert_instruction_sched_hints(pm, options.instruction_sched_variant)
passes.ttgpuir.add_optimize_dot_operands(pm, True)
passes.ttgpuir.add_remove_layout_conversions(pm)
passes.ttgpuir.add_reduce_data_duplication(pm)
Expand Down Expand Up @@ -276,8 +276,7 @@ def make_llir(src, metadata, options):
passes.common.add_cse(pm)
passes.common.add_symbol_dce(pm)
if options.instruction_sched_variant.lower() != "none":
amd.passes.ttgpuir.lower_instruction_sched_hints(pm, options.arch, options.num_stages,
options.instruction_sched_variant)
amd.passes.ttgpuir.lower_instruction_sched_hints(pm, options.arch, options.num_stages)
if os.environ.get("TRITON_DISABLE_LINE_INFO", "0") == "0":
passes.llvmir.add_di_scope(pm)
amd.passes.ttgpuir.add_builtin_func_to_llvmir(pm, __HIP_FTZ)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ def InstructionSchedHint : TT_AMDGPU_Op<"instruction_sched_hint", []> {
}];

let arguments = (ins
TritonAMDGPU_SchedHintVariantAttr:$variant,
TritonAMDGPU_InstCounter:$numDsReadsA,
TritonAMDGPU_InstCounter:$numDsReadsB,
TritonAMDGPU_InstCounter:$numDsWritesA,
Expand All @@ -143,11 +144,11 @@ def InstructionSchedHint : TT_AMDGPU_Op<"instruction_sched_hint", []> {
);

let builders = [
OpBuilder<(ins), [{
OpBuilder<(ins "amdgpu::SchedHint":$variant), [{
auto ctx = $_state.getContext();
auto noneType = NoneType::get(ctx);
auto emptyAttr = amdgpu::InstCounterAttr::get(ctx, 0, noneType);
build($_builder, $_state, emptyAttr, emptyAttr, emptyAttr, emptyAttr,
build($_builder, $_state, variant, emptyAttr, emptyAttr, emptyAttr, emptyAttr,
emptyAttr, emptyAttr, false, false, emptyAttr);
}]>
];
Expand Down
11 changes: 11 additions & 0 deletions third_party/amd/include/Dialect/TritonAMDGPU/Utility/CommonUtils.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
#ifndef TRITON_THIRD_PARTY_AMD_INCLUDE_DIALECT_TRITONAMDGPU_UTILITY_COMMONUTILS_H_
#define TRITON_THIRD_PARTY_AMD_INCLUDE_DIALECT_TRITONAMDGPU_UTILITY_COMMONUTILS_H_

#include "mlir/Dialect/SCF/IR/SCF.h"
#include "triton/Dialect/Triton/IR/Dialect.h"

namespace mlir::triton::AMD {
/// Returns the innermost `scf.for` ops inside `funcOp`, i.e. the loops
/// whose bodies contain no further nested `scf.for`.
SmallVector<scf::ForOp> getLeafForOps(triton::FuncOp funcOp);
} // namespace mlir::triton::AMD

#endif // TRITON_THIRD_PARTY_AMD_INCLUDE_DIALECT_TRITONAMDGPU_UTILITY_COMMONUTILS_H_
5 changes: 2 additions & 3 deletions third_party/amd/include/TritonAMDGPUToLLVM/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,10 @@ createConvertTritonAMDGPUToLLVMPass(StringRef targetArch, bool ftz);
std::unique_ptr<OperationPass<ModuleOp>>
createConvertBuiltinFuncToLLVMPass(bool ftz);
std::unique_ptr<OperationPass<ModuleOp>>
createTritonAMDGPUInsertInstructionSchedHintsPass();
createTritonAMDGPUInsertInstructionSchedHintsPass(StringRef variant);
std::unique_ptr<OperationPass<ModuleOp>>
createTritonAMDGPULowerInstructionSchedHintsPass(StringRef arch,
int32_t numStages,
StringRef variant);
int32_t numStages);

#define GEN_PASS_REGISTRATION
#include "TritonAMDGPUToLLVM/Passes.h.inc"
Expand Down
12 changes: 7 additions & 5 deletions third_party/amd/include/TritonAMDGPUToLLVM/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -60,15 +60,20 @@ def ConvertBuiltinFuncToLLVM : Pass<"convert-builtin-func-to-llvm", "mlir::Modul

def TritonAMDGPUInsertInstructionSchedHints : Pass<"triton-amdgpu-insert-instruction-sched-hints", "mlir::ModuleOp"> {
let summary = "Insert instruction scheduling hints after the dot ops in the main loop";
let constructor = "mlir::triton::createTritonAMDGPUInsertInstructionSchedHintsPass()";
let constructor = "mlir::triton::createTritonAMDGPUInsertInstructionSchedHintsPass(/*variant=*/\"\")";

let dependentDialects = ["mlir::LLVM::LLVMDialect",
"mlir::triton::amdgpu::TritonAMDGPUDialect"];

let options = [
Option<"variant", "variant", "std::string", /*default*/"\"none\"",
"instruction scheduling variant">,
];
}

def TritonAMDGPULowerInstructionSchedHints : Pass<"triton-amdgpu-lower-insert-instruction-sched-hints", "mlir::ModuleOp"> {
let summary = "Lower instruction scheduling hints to LLVM intrinsics";
let constructor = "mlir::triton::createTritonAMDGPULowerInstructionSchedHintsPass(/*arch=*/\"\",/*numStages=*/2, /*variant=*/\"\")";
let constructor = "mlir::triton::createTritonAMDGPULowerInstructionSchedHintsPass(/*arch=*/\"\",/*numStages=*/2)";

let dependentDialects = ["mlir::LLVM::LLVMDialect",
"mlir::ROCDL::ROCDLDialect",
Expand All @@ -79,10 +84,7 @@ def TritonAMDGPULowerInstructionSchedHints : Pass<"triton-amdgpu-lower-insert-in
"gfx target device architecture, e.g., gfx942">,
Option<"numStages", "num_stages", "int32_t", /*default*/"2",
"number of pipeline stages">,
Option<"variant", "variant", "std::string", /*default*/"\"none\"",
"instruction scheduling variant">,
];
}


#endif
1 change: 1 addition & 0 deletions third_party/amd/lib/Dialect/TritonAMDGPU/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
add_subdirectory(IR)
add_subdirectory(Utility)
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# Utility library shared by AMD Triton passes; currently provides the
# helpers in CommonUtils.cpp (e.g. leaf-loop discovery).
add_triton_library(TritonAMDUtils
  CommonUtils.cpp

  LINK_LIBS PUBLIC
  MLIRLLVMDialect
  TritonIR
  TritonGPUIR
)
17 changes: 17 additions & 0 deletions third_party/amd/lib/Dialect/TritonAMDGPU/Utility/CommonUtils.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
#include "third_party/amd/include/Dialect/TritonAMDGPU/Utility/CommonUtils.h"

namespace mlir::triton::AMD {
/// Collects every `scf.for` in `funcOp` and returns only the innermost
/// ones, i.e. loops whose bodies contain no further `scf.for`.
SmallVector<scf::ForOp> getLeafForOps(triton::FuncOp funcOp) {
  // First gather all loops in the function, at any nesting depth.
  SmallVector<scf::ForOp> candidates;
  funcOp->walk([&](scf::ForOp loop) { candidates.push_back(loop); });

  // A candidate is a leaf iff walking its body finds no nested scf.for.
  SmallVector<scf::ForOp> innermost;
  for (scf::ForOp loop : candidates) {
    bool hasNestedLoop =
        loop.getBody()
            ->walk([](scf::ForOp) { return WalkResult::interrupt(); })
            .wasInterrupted();
    if (!hasNestedLoop)
      innermost.push_back(loop);
  }
  return innermost;
}
} // namespace mlir::triton::AMD
82 changes: 41 additions & 41 deletions third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -242,25 +242,10 @@ struct InstructionSchedHintsRewriter
: public OpRewritePattern<triton::amdgpu::InstructionSchedHint> {

InstructionSchedHintsRewriter(MLIRContext *ctx, StringRef arch,
int32_t numStages, std::string variant)
int32_t numStages)
: OpRewritePattern(ctx), numStages(numStages) {

this->machineDescr = MachineDescr::get(arch);

this->schedHint = mlir::triton::amdgpu::SchedHint::none;
if (this->numStages < 2) {
LDBG("ignoring instruction scheduling due to a very low num. "
"stages value. Must be >= 2");
return;
}

std::transform(variant.begin(), variant.end(), variant.begin(),
[](unsigned char c) { return std::tolower(c); });
if (auto maybeSchedHint = triton::amdgpu::symbolizeSchedHint(variant))
this->schedHint = maybeSchedHint.value();
else
LDBG("ignoring instruction scheduling because "
"unknown instruction scheduling variant has been provided");
}

// The following is inspired by ROCm Composable Kernel library's V3 pipelining
Expand Down Expand Up @@ -422,7 +407,8 @@ struct InstructionSchedHintsRewriter
LogicalResult
matchAndRewrite(triton::amdgpu::InstructionSchedHint instructionSchedHint,
PatternRewriter &rewriter) const override {
if (this->schedHint == mlir::triton::amdgpu::SchedHint::none) {
auto schedVariant = instructionSchedHint.getVariant();
if (schedVariant == mlir::triton::amdgpu::SchedHint::none) {
rewriter.eraseOp(instructionSchedHint);
return success();
}
Expand All @@ -432,7 +418,7 @@ struct InstructionSchedHintsRewriter
// not supposed to be used together with IGLP OPT according to the AMDGPU
// backend documentation.
const bool limitSchedulingRange =
this->schedHint == mlir::triton::amdgpu::SchedHint::local_prefetch;
schedVariant == mlir::triton::amdgpu::SchedHint::local_prefetch;
;
Location loc = instructionSchedHint->getLoc();
Block *block = instructionSchedHint->getBlock();
Expand All @@ -444,10 +430,10 @@ struct InstructionSchedHintsRewriter

rewriter.setInsertionPoint(block, std::prev(block->end()));

switch (this->schedHint) {
switch (schedVariant) {
case mlir::triton::amdgpu::SchedHint::llvm_iglp_0:
case mlir::triton::amdgpu::SchedHint::llvm_iglp_1:
createIglpOpt(rewriter, loc, static_cast<int>(this->schedHint) - 1);
createIglpOpt(rewriter, loc, static_cast<int>(schedVariant) - 1);
break;
case mlir::triton::amdgpu::SchedHint::local_prefetch:
createLocalPrefetchSchedule(rewriter, loc, instructionSchedHint);
Expand All @@ -467,7 +453,6 @@ struct InstructionSchedHintsRewriter

private:
int32_t numStages;
mlir::triton::amdgpu::SchedHint schedHint;
std::unique_ptr<MachineDescr> machineDescr;
};

Expand All @@ -476,11 +461,9 @@ struct TritonAMDGPULowerInstructionSchedHints
TritonAMDGPULowerInstructionSchedHints> {

explicit TritonAMDGPULowerInstructionSchedHints(StringRef arch,
int32_t numStages,
StringRef variant) {
int32_t numStages) {
this->arch = std::move(arch.str());
this->numStages = numStages;
this->variant = std::move(variant.str());
}

void runOnOperation() override {
Expand All @@ -497,7 +480,7 @@ struct TritonAMDGPULowerInstructionSchedHints
RewritePatternSet patterns(ctx);

patterns.add<InstructionSchedHintsRewriter>(ctx, this->arch,
this->numStages, this->variant);
this->numStages);

if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns)))) {
Expand All @@ -511,35 +494,52 @@ struct TritonAMDGPUInsertInstructionSchedHints
: public triton::impl::TritonAMDGPUInsertInstructionSchedHintsBase<
TritonAMDGPUInsertInstructionSchedHints> {

explicit TritonAMDGPUInsertInstructionSchedHints(StringRef variant) {
this->variant = std::move(variant.str());
}

void runOnOperation() override {
MLIRContext *ctx = &getContext();
ModuleOp mod = getOperation();

mod.walk([this, ctx](scf::ForOp forOp) {
// Note, instruction schedule barriers are inserted only in the case of
// a single `tt.dot` op in a `scf::ForOp` scope in the current
// implementation.
if (auto dotOp = getSingleDotOpIfExists(forOp)) {
OpBuilder rewriter(ctx);
rewriter.setInsertionPointAfter(dotOp);
rewriter.create<triton::amdgpu::InstructionSchedHint>(dotOp->getLoc());
}
});
auto schedHint = mlir::triton::amdgpu::SchedHint::none;
std::transform(variant.begin(), variant.end(), variant.begin(),
[](unsigned char c) { return std::tolower(c); });
if (auto maybeSchedHint = triton::amdgpu::symbolizeSchedHint(variant))
schedHint = maybeSchedHint.value();
else {
LDBG("ignoring instruction scheduling because "
"unknown instruction scheduling variant has been provided");
return;
}

if (schedHint != mlir::triton::amdgpu::SchedHint::none) {
mod.walk([&](scf::ForOp forOp) {
// Note, instruction schedule barriers are inserted only in the case of
// a single `tt.dot` op in a `scf::ForOp` scope in the current
// implementation.
if (auto dotOp = getSingleDotOpIfExists(forOp)) {
OpBuilder rewriter(ctx);
rewriter.setInsertionPointAfter(dotOp);
rewriter.create<triton::amdgpu::InstructionSchedHint>(dotOp->getLoc(),
schedHint);
}
});
}
}
};
} // namespace

namespace mlir::triton {
std::unique_ptr<OperationPass<ModuleOp>>
createTritonAMDGPULowerInstructionSchedHintsPass(StringRef arch,
int32_t numStages,
StringRef variant) {
return std::make_unique<TritonAMDGPULowerInstructionSchedHints>(
arch, numStages, variant);
int32_t numStages) {
return std::make_unique<TritonAMDGPULowerInstructionSchedHints>(arch,
numStages);
}

std::unique_ptr<OperationPass<ModuleOp>>
createTritonAMDGPUInsertInstructionSchedHintsPass() {
return std::make_unique<TritonAMDGPUInsertInstructionSchedHints>();
createTritonAMDGPUInsertInstructionSchedHintsPass(StringRef variant) {
return std::make_unique<TritonAMDGPUInsertInstructionSchedHints>(variant);
}
} // namespace mlir::triton
1 change: 1 addition & 0 deletions third_party/amd/lib/TritonAMDGPUTransforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ add_triton_library(TritonAMDGPUTransforms
TritonAMDGPUIR
TritonAMDGPUTransformsIncGen
TritonGPUIR
TritonAMDUtils
)

target_include_directories(TritonAMDGPUTransforms PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/../../include)
Expand Down
Loading

0 comments on commit 94643b2

Please sign in to comment.