Commit d1a240b

Merge remote-tracking branch 'origin/main' into etiotto.raise_block_ptr.9

etiotto committed Jan 28, 2025
2 parents 05e794c + e7e6fa3
Showing 86 changed files with 2,540 additions and 696 deletions.
10 changes: 10 additions & 0 deletions .github/workflows/build-test-windows.yml
@@ -9,9 +9,11 @@ on:
permissions: read-all

  env:
+   PYTHONIOENCODING: utf-8
    NEW_WORKSPACE: C:\gh${{ github.run_id }}
    ZE_PATH: C:\level_zero
    PYTEST_MAX_PROCESSES: 8
+   SKIPLIST: --skip-list scripts/skiplist/a770
    TRITON_TEST_CMD: bash -x scripts/test-triton.sh --skip-pytorch-install --skip-pip-install --skip-list scripts/skiplist/a770 --reports-dir reports --ignore-errors

jobs:
@@ -93,6 +95,14 @@ jobs:
Invoke-BatchFile "C:\Program Files (x86)\Intel\oneAPI\setvars.bat"
${{ env.TRITON_TEST_CMD }} --interpreter
- name: Run tutorials
run: |
.venv\Scripts\activate.ps1
Invoke-BatchFile "C:\Program Files (x86)\Intel\oneAPI\setvars.bat"
# Disable interactive plot window
$env:MPLBACKEND = "Agg"
${{ env.TRITON_TEST_CMD }} --tutorial
- name: Pass rate
run: |
.venv\Scripts\activate.ps1
10 changes: 10 additions & 0 deletions .github/workflows/pip-test-windows.yml
@@ -21,9 +21,11 @@ on:
permissions: read-all

  env:
+   PYTHONIOENCODING: utf-8
    NEW_WORKSPACE: C:\gh${{ github.run_id }}
    ZE_PATH: C:\level_zero
    PYTEST_MAX_PROCESSES: 8
+   SKIPLIST: --skip-list scripts/skiplist/a770
    TRITON_TEST_CMD: bash -x scripts/test-triton.sh --skip-pytorch-install --skip-pip-install --skip-list scripts/skiplist/a770 --reports-dir reports --ignore-errors

jobs:
@@ -110,6 +112,14 @@ jobs:
Invoke-BatchFile "C:\Program Files\Microsoft Visual Studio\2022\Community\VC\Auxiliary\Build\vcvarsall.bat" x64
${{ env.TRITON_TEST_CMD }} --interpreter
- name: Run tutorials
run: |
.venv\Scripts\activate.ps1
Invoke-BatchFile "C:\Program Files\Microsoft Visual Studio\2022\Community\VC\Auxiliary\Build\vcvarsall.bat" x64
# Disable interactive plot window
$env:MPLBACKEND = "Agg"
${{ env.TRITON_TEST_CMD }} --tutorial
- name: Pass rate
run: |
.venv\Scripts\activate.ps1
4 changes: 2 additions & 2 deletions .github/workflows/triton-benchmarks.yml
@@ -258,8 +258,8 @@ jobs:
mv $REPORTS/attn-performance.csv $REPORTS/attn-bwd-performance.csv
source ../../scripts/capture-hw-details.sh
- python ../../scripts/build_report.py $REPORTS/attn-bwd-performance.csv $REPORTS/attn-bwd-triton-report.csv --benchmark attn --compiler triton --param_cols "Z,H,N_CTX,D_HEAD,CAUSAL,MODE" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG
- python ../../scripts/build_report.py $REPORTS/attn-bwd-performance.csv $REPORTS/attn-bwd-xetla-report.csv --benchmark attn --compiler xetla --param_cols "Z,H,N_CTX,D_HEAD,CAUSAL,MODE" --tflops_col XeTLA-TFlops --hbm_col "XeTLA-GB/s" --tag $TAG
+ python ../../scripts/build_report.py $REPORTS/attn-bwd-performance.csv $REPORTS/attn-bwd-triton-report.csv --benchmark attn-bwd --compiler triton --param_cols "Z,H,N_CTX,D_HEAD,CAUSAL,MODE" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG
+ python ../../scripts/build_report.py $REPORTS/attn-bwd-performance.csv $REPORTS/attn-bwd-xetla-report.csv --benchmark attn-bwd --compiler xetla --param_cols "Z,H,N_CTX,D_HEAD,CAUSAL,MODE" --tflops_col XeTLA-TFlops --hbm_col "XeTLA-GB/s" --tag $TAG
- name: Run Prefix Sums kernel benchmark
if: ${{ steps.install.outcome == 'success' && !cancelled() && !contains(fromJson(inputs.skip_benchmarks || '[]'), 'prefix_sums.py') }}
2 changes: 1 addition & 1 deletion .github/workflows/wheels_v2.yml
@@ -22,7 +22,7 @@ jobs:
docker container prune -f
- name: Checkout
-       uses: actions/checkout@v3
+       uses: actions/checkout@v4

# The LATEST_DATE here should be kept in sync with the one in Patch setup.py
- id: check-version
2 changes: 1 addition & 1 deletion cmake/llvm-hash.txt
@@ -1 +1 @@
- 2fe947b47798de1ad20553be4e162e332428ad91
+ e2402615a5a76d46a433dfcc1de10b38a1263c9d
8 changes: 4 additions & 4 deletions include/triton/Conversion/MLIRTypes.h
@@ -21,10 +21,10 @@ inline Type u1Ty(MLIRContext *ctx) {
}

// Float types
- inline Type f16Ty(MLIRContext *ctx) { return FloatType::getF16(ctx); }
- inline Type f32Ty(MLIRContext *ctx) { return FloatType::getF32(ctx); }
- inline Type f64Ty(MLIRContext *ctx) { return FloatType::getF64(ctx); }
- inline Type bf16Ty(MLIRContext *ctx) { return FloatType::getBF16(ctx); }
+ inline Type f16Ty(MLIRContext *ctx) { return Float16Type::get(ctx); }
+ inline Type f32Ty(MLIRContext *ctx) { return Float32Type::get(ctx); }
+ inline Type f64Ty(MLIRContext *ctx) { return Float64Type::get(ctx); }
+ inline Type bf16Ty(MLIRContext *ctx) { return BFloat16Type::get(ctx); }

inline bool isFloat(Type type) {
return type.isF32() || type.isF64() || type.isF16() || type.isF128() ||
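This change tracks MLIR's migration from the FloatType::getF16-style factory methods to per-type get factories. A minimal standalone sketch of the new form (assuming a recent MLIR checkout; only mlir/IR headers are used):

#include <cassert>
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"

int main() {
  mlir::MLIRContext ctx;
  // Per-type factories, as used by the updated helpers above. Types are
  // uniqued in the context, so repeated calls return the same instance.
  mlir::Type f16 = mlir::Float16Type::get(&ctx);
  mlir::Type bf16 = mlir::BFloat16Type::get(&ctx);
  assert(f16 == mlir::Float16Type::get(&ctx)); // uniqued singleton
  assert(f16.isF16() && bf16.isBF16());        // same predicates as before
  return 0;
}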
6 changes: 6 additions & 0 deletions include/triton/Conversion/TritonGPUToLLVM/Utility.h
@@ -1017,6 +1017,12 @@ emitIndices(Location loc, RewriterBase &rewriter, const TargetInfoBase &target,
const TargetInfoBase &target,
std::function<void(VectorType, Value /*shmemAddr*/)> perVectorCallback);

+ [[nodiscard]] bool emitTransferBetweenRegistersAndShared(
+     LinearLayout &regLayout, triton::gpu::MemDescType sharedTy, Type elemLlvmTy,
+     std::optional<int32_t> maxVecElems, const SharedMemoryObject &smemObj,
+     Location loc, RewriterBase &rewriter, const TargetInfoBase &target,
+     std::function<void(VectorType, Value /*shmemAddr*/)> perVectorCallback);

SmallVector<Value> loadSharedToDistributed(RankedTensorType dstTy,
triton::gpu::MemDescType srcTy,
Type elemLlvmTy,
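A hypothetical call site for the new [[nodiscard]] overload, sketched only to show the callback shape; regLayout, sharedTy, elemLlvmTy, smemObj, loc, rewriter, and target are assumed to come from the surrounding lowering code:

// Sketch: walk the transfer described by regLayout and issue one vector
// load per shared-memory address the walk produces.
SmallVector<Value> loaded;
bool ok = emitTransferBetweenRegistersAndShared(
    regLayout, sharedTy, elemLlvmTy, /*maxVecElems=*/std::nullopt, smemObj,
    loc, rewriter, target, [&](VectorType vecTy, Value shmemAddr) {
      loaded.push_back(rewriter.create<LLVM::LoadOp>(loc, vecTy, shmemAddr));
    });
if (!ok) {
  // The result is [[nodiscard]] because the requested transfer may not be
  // expressible; callers are expected to handle the failure.
}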
2 changes: 1 addition & 1 deletion include/triton/Dialect/Triton/IR/TritonOps.td
@@ -1099,7 +1099,7 @@ def CallOp : TT_Op<"call", [CallOpInterface, /*MemRefsNormalizable, */DeclareOpI

/// Set the callee for this operation.
void setCalleeFromCallable(CallInterfaceCallable callee) {
-     (*this)->setAttr("callee", callee.get<SymbolRefAttr>());
+     (*this)->setAttr("callee", cast<SymbolRefAttr>(callee));
}

// Required by CallOpInterface.
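The one-line change follows LLVM's deprecation of PointerUnion::get<T>() in favor of the free function cast<T>(). A self-contained illustration of the same pattern with made-up types (Sym and Val are not from Triton):

#include "llvm/ADT/PointerUnion.h"
#include "llvm/Support/Casting.h"

struct Sym { int id; };
struct Val { int v; };

int symbolId(llvm::PointerUnion<Sym *, Val *> callee) {
  // Old form (deprecated upstream): callee.get<Sym *>()
  // New form, matching the diff above:
  Sym *s = llvm::cast<Sym *>(callee); // asserts the union holds a Sym *
  return s->id;
}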
4 changes: 2 additions & 2 deletions include/triton/Dialect/TritonGPU/IR/Dialect.h
@@ -178,11 +178,11 @@ unsigned getNumCTAs(Attribute layout);
// len(shape) == rank.
SmallVector<unsigned> getMatrixOrder(unsigned rank, bool rowMajor);

- // Return the order that represents that the dot operand is in kMajor
+ // Return the order that represents that the dot operand is in kContig
  // (contiguous in the inner dimension) or it's contiguous on the outer
  // dimension.
  SmallVector<unsigned> getOrderForDotOperand(unsigned opIdx, unsigned rank,
-                                             bool kMajor);
+                                             bool kContig);

bool isExpensiveCat(CatOp cat, Attribute targetEncoding);

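A small illustration of the renamed flag, with the result the comment above implies (values are illustrative, not verified against the implementation): for a rank-2 A operand of a dot (opIdx = 0), kContig = true asks for K, the inner dimension, to be the fastest-varying one.

// Hypothetical use: for opIdx = 0 and rank = 2, kContig = true should
// yield {1, 0}, i.e. dim 1 (K) is contiguous and dim 0 (M) is outer.
SmallVector<unsigned> order =
    getOrderForDotOperand(/*opIdx=*/0, /*rank=*/2, /*kContig=*/true);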
4 changes: 2 additions & 2 deletions include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h
@@ -243,8 +243,8 @@ LinearLayout chooseStMatrixLayout(MLIRContext *ctx, RankedTensorType tensorTy,

// The primary goal of this function is to efficiently store 2D tiles of a
// tensor into shared memory using the `ldmatrix` instruction.
- LinearLayout chooseLdMatrixLayout(MLIRContext *ctx, Attribute sharedEnc,
-                                   Attribute dotEnc, ArrayRef<int64_t> shape);
+ LinearLayout chooseLdMatrixLayout(Attribute enc, ArrayRef<int64_t> shape,
+                                   bool needTrans, int32_t elemBitWidth);
} // namespace mlir::triton::gpu

#endif // TRITON_DIALECT_TRITONGPU_IR_LINEARLAYOUTCONVERSIONS_H
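A hypothetical call under the new signature (argument values illustrative only; the MLIRContext is presumably now derived from the encoding attribute itself):

// Old call shape: chooseLdMatrixLayout(ctx, sharedEnc, dotEnc, shape)
// New call shape: one encoding plus an explicit transpose flag and the
// element bit width.
LinearLayout layout =
    chooseLdMatrixLayout(enc, shape, /*needTrans=*/false, /*elemBitWidth=*/16);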
16 changes: 16 additions & 0 deletions include/triton/Dialect/TritonGPU/Transforms/Passes.td
@@ -55,6 +55,22 @@ def TritonGPUTestPipelineScheduleLoop : Pass<"tritongpu-test-pipeline-schedule-l
"mlir::arith::ArithDialect"];
}

+ def TritonGPUFuseNestedLoops : Pass<"tritongpu-fuse-nested-loops", "mlir::ModuleOp"> {
+   let summary = "fuse nested loops for pipelining";
+
+   let description = [{
+     The `tritongpu-fuse-nested-loops` pass will analyze loop nests in the module
+     that need to be pipelined and fuse them into a single loop. This composes
+     with the pipeliner to pipeline loop nests.
+   }];
+
+   let dependentDialects = [
+     "mlir::triton::gpu::TritonGPUDialect",
+     "mlir::arith::ArithDialect",
+     "mlir::ub::UBDialect",
+   ];
+ }

def TritonGPUF32DotTC : Pass<"tritongpu-F32DotTC", "mlir::ModuleOp"> {
let summary = "3xTF32 trick";

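As a conceptual sketch of what "fuse them into a single loop" means (the general flattening technique, not necessarily the pass's actual algorithm), a depth-2 nest can be rewritten so the software pipeliner sees one flat iteration stream instead of a pipeline that drains at every outer iteration:

// Conceptual illustration only; M, N, and body are placeholders.
void fusedNest(int M, int N, void (*body)(int, int)) {
  // Original nest:
  //   for (int i = 0; i < M; ++i)
  //     for (int j = 0; j < N; ++j)
  //       body(i, j);
  // Fused form: a single loop of M * N iterations.
  for (int k = 0; k < M * N; ++k) {
    int i = k / N; // recovered outer index
    int j = k % N; // recovered inner index
    body(i, j);
  }
}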
5 changes: 0 additions & 5 deletions include/triton/Dialect/TritonGPU/Transforms/Utility.h
@@ -205,11 +205,6 @@ enum class MMALoadType {
};
MMALoadType getMMALoadType(Operation *loadOp);

- // Returns composed LinearLayout for register to shared copy
- triton::LinearLayout getRegToSharedLayout(MLIRContext *ctx,
-                                           ArrayRef<int64_t> shape,
-                                           Attribute srcEnc, Attribute dstEnc,
-                                           int elemBitWidth);
} // namespace mlir

#endif // TRITON_DIALECT_TRITONGPU_TRANSFORMS_UTILITY_H_
include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td
@@ -258,7 +258,7 @@ def TTNG_AsyncTMACopyLocalToGlobalOp : TTNG_Op<"async_tma_copy_local_to_global",
}];
}

- def TTNG_TMAStoreWait : TTNG_Op<"async_tma_store_wait"> {
+ def TTNG_TMAStoreWaitOp : TTNG_Op<"async_tma_store_wait"> {
let summary = "wait until all the inputs are read.";
let arguments = (ins I32Attr:$pendings);
let description = [{
30 changes: 9 additions & 21 deletions lib/Analysis/AxisInfo.cpp
@@ -1,5 +1,4 @@
#include "mlir/Analysis/DataFlowFramework.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"

@@ -232,13 +231,13 @@ class MakeRangeOpAxisInfoVisitor final
}
};

- template <typename OpTy>
- class ConstantOpAxisInfoVisitor final : public AxisInfoVisitorImpl<OpTy> {
+ class ConstantOpAxisInfoVisitor final
+     : public AxisInfoVisitorImpl<arith::ConstantOp> {
  public:
-   using AxisInfoVisitorImpl<OpTy>::AxisInfoVisitorImpl;
+   using AxisInfoVisitorImpl::AxisInfoVisitorImpl;

    AxisInfo
-   getAxisInfo(OpTy op,
+   getAxisInfo(arith::ConstantOp op,
ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) override {
auto intAttr = dyn_cast<IntegerAttr>(op.getValue());
auto boolAttr = dyn_cast<BoolAttr>(op.getValue());
@@ -323,8 +322,7 @@ class AddSubOpAxisInfoVisitor final : public BinaryOpVisitorImpl<OpTy> {
const AxisInfo &rhs) override {
if (lhs.getConstantValue().has_value() &&
rhs.getConstantValue().has_value()) {
-     if constexpr (std::is_same_v<OpTy, arith::AddIOp> ||
-                   std::is_same_v<OpTy, LLVM::AddOp>) {
+     if constexpr (std::is_same_v<OpTy, arith::AddIOp>) {
return {lhs.getConstantValue().value() +
rhs.getConstantValue().value()};
} else if constexpr (std::is_same_v<OpTy, arith::SubIOp>) {
@@ -1013,15 +1011,11 @@ AxisInfoAnalysis::AxisInfoAnalysis(DataFlowSolver &solver)
CastOpAxisInfoVisitor<triton::gpu::ConvertLayoutOp>,
CastOpAxisInfoVisitor<mlir::UnrealizedConversionCastOp>,
CastOpAxisInfoVisitor<triton::BitcastOp>>();
-   // TODO: Remove rules for LLVM::ConstantOp, LLVM::AddOp
-   // when scf.for supports integer induction variables
    visitors.append<MakeRangeOpAxisInfoVisitor>();
-   visitors.append<ConstantOpAxisInfoVisitor<arith::ConstantOp>,
-                   ConstantOpAxisInfoVisitor<LLVM::ConstantOp>>();
+   visitors.append<ConstantOpAxisInfoVisitor>();
    visitors.append<AddSubOpAxisInfoVisitor<triton::AddPtrOp>,
                    AddSubOpAxisInfoVisitor<arith::AddIOp>,
-                   AddSubOpAxisInfoVisitor<arith::SubIOp>,
-                   AddSubOpAxisInfoVisitor<LLVM::AddOp>>();
+                   AddSubOpAxisInfoVisitor<arith::SubIOp>>();
visitors.append<MulIOpAxisInfoVisitor>();
visitors.append<DivOpAxisInfoVisitor<arith::DivSIOp>,
DivOpAxisInfoVisitor<arith::DivUIOp>>();
@@ -1138,17 +1132,11 @@ void AxisInfo::initPessimisticStateFromFunc(int argNumber, T funcOp,

if (blockArg && blockArg.getOwner()->isEntryBlock()) {
Operation *op = blockArg.getOwner()->getParentOp();
-     if (auto fun = dyn_cast<FunctionOpInterface>(op))
-       initPessimisticStateFromFunc(blockArg.getArgNumber(), fun,
-                                    &knownContiguity, &knownDivisibility,
-                                    &knownConstancy);
-     // llvm codegen check alignment to generate vector load/store
-     // would be nice if this wasn't the case
-     else if (auto fun = dyn_cast<LLVM::LLVMFuncOp>(op))
+     if (auto fun = dyn_cast<FunctionOpInterface>(op)) {
        initPessimisticStateFromFunc(blockArg.getArgNumber(), fun,
                                     &knownContiguity, &knownDivisibility,
                                     &knownConstancy);
-     else if (isa<RegionBranchOpInterface>(op)) {
+     } else if (isa<RegionBranchOpInterface>(op)) {
// scf::ForOp, scf::IfOp, scf::WhileOp
// Control flow operations are initialized with "unknown" state:
// the maximum possible divisibility, contiguity, and constancy.
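The AddSub visitor retains its constant-folding rule while dropping the LLVM::AddOp case. A standalone sketch of that rule (illustrative names; AxisInfo's real interface differs):

#include <cstdint>
#include <optional>

// When both operands carry a known uniform constant, propagate the folded
// constant; otherwise the result's constant value stays unknown.
std::optional<int64_t> foldAddSub(std::optional<int64_t> lhs,
                                  std::optional<int64_t> rhs, bool isAdd) {
  if (lhs && rhs)
    return isAdd ? *lhs + *rhs : *lhs - *rhs;
  return std::nullopt;
}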
28 changes: 14 additions & 14 deletions lib/Conversion/TritonGPUToLLVM/CMakeLists.txt
@@ -1,25 +1,25 @@
  add_triton_library(TritonGPUToLLVM
    ConvertLayoutOpToLLVM/SharedToDotOperandFMA.cpp
    DotOpToLLVM/FMA.cpp
-   GlobalScratchMemoryAllocation.cpp
-   TypeConverter.cpp
-   Utility.cpp
-   ElementwiseOpToLLVM.cpp
-   MemoryOpToLLVM.cpp
+   AllocateSharedMemory.cpp
    AssertOpToLLVM.cpp
-   ViewOpToLLVM.cpp
-   MakeRangeOpToLLVM.cpp
+   ControlFlowOpToLLVM.cpp
+   ConvertLayoutOpToLLVM.cpp
+   DecomposeUnsupportedConversions.cpp
+   ElementwiseOpToLLVM.cpp
+   FuncOpToLLVM.cpp
+   GatherOpToLLVM.cpp
+   GlobalScratchMemoryAllocation.cpp
    HistogramOpToLLVM.cpp
-   AllocateSharedMemory.cpp
+   MakeRangeOpToLLVM.cpp
+   MemoryOpToLLVM.cpp
+   PrintOpToLLVM.cpp
    ReduceOpToLLVM.cpp
    ScanOpToLLVM.cpp
-   GatherOpToLLVM.cpp
-   ConvertLayoutOpToLLVM.cpp
-   ControlFlowOpToLLVM.cpp
-   FuncOpToLLVM.cpp
    SPMDOpToLLVM.cpp
-   DecomposeUnsupportedConversions.cpp
-   PrintOpToLLVM.cpp
+   TypeConverter.cpp
+   Utility.cpp
+   ViewOpToLLVM.cpp

DEPENDS
TritonGPUConversionPassIncGen