Skip to content

Commit

Permalink
Merge commit 'aa833c9d8c3df4dd10589758675342ae82371457'
Browse files Browse the repository at this point in the history
  • Loading branch information
whitneywhtsang committed Jan 15, 2025
2 parents 62e87e8 + aa833c9 commit e9b11eb
Show file tree
Hide file tree
Showing 42 changed files with 801 additions and 405 deletions.
10 changes: 4 additions & 6 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -19,27 +19,25 @@ repos:
- id: debug-statements

- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.7.1
rev: v0.9.1
hooks:
- id: ruff
files: '(^python|^third_party/proton|^third_party/amd|^benchmarks|^third_party/intel|^scripts)/.*'
files: '(^python|^third_party/proton|^third_party/amd|^third_party/nvidia|^test|^benchmarks|^third_party/intel|^scripts)/.*'
args: ["--fix", "--exit-non-zero-on-fix"]
exclude: |
(?x)(
^python/triton/runtime/.*|
^test/|
^docs/conf.py$
)
- repo: https://github.com/google/yapf
rev: "7e21823"
rev: "v0.43.0"
hooks:
- id: yapf
args: ["-p", "-i"]
exclude: "python/test/unit/language/test_line_info.py"

- repo: https://github.com/pre-commit/mirrors-clang-format
rev: v19.1.2
rev: v19.1.6
hooks:
- id: clang-format

Expand Down
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,7 @@ For detailed instructions on how to debug Triton's frontend, please refer to thi
- `MLIR_ENABLE_TIMING` dumps the timing information for each MLIR pass.
- `LLVM_ENABLE_TIMING` dumps the timing information for each LLVM pass.
- `TRITON_DEFAULT_FP_FUSION` overrides the default behavior of allowing fp fusion (mul+add->fma).
- `MLIR_ENABLE_DIAGNOSTICS` enables dumping the stack trace and the associated IR operation for each diagnostic (e.g., errors and warnings).
- `MLIR_ENABLE_REMARK` enables the performance warnings that are emitted as remarks.
- `TRITON_KERNEL_DUMP` enables the dumping of the IR from each compilation stage and the final ptx/amdgcn.
- `TRITON_DUMP_DIR` specifies the directory to save the dumped IR and ptx/amdgcn when `TRITON_KERNEL_DUMP` is set to 1.
Expand Down
2 changes: 1 addition & 1 deletion include/triton/Analysis/Allocation.h
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ class Allocation {
};

/// Op -> Scratch Buffer
using OpScratchMapT = DenseMap<Operation *, BufferT *>;
using OpScratchMapT = llvm::MapVector<Operation *, BufferT *>;
/// Value -> Explicit Buffer
using ValueBufferMapT = llvm::MapVector<Value, BufferT *>;
/// Value -> Alias Buffer
Expand Down
16 changes: 11 additions & 5 deletions include/triton/Conversion/TritonGPUToLLVM/Utility.h
Original file line number Diff line number Diff line change
Expand Up @@ -462,11 +462,6 @@ inline Value getSharedMemoryBase(Location loc, RewriterBase &rewriter,
// MXFP utilities
// -----------------------------------------------------------------------

// Convert each value, which is an int8 containing 2 packed mxfp4 values,
// into 2 standalone bf16 values
SmallVector<Value> convertMxfp4x2ToBf16x2(RewriterBase &rewriter, Location loc,
ArrayRef<Value> values);

// Scale a mxfp4 value by a given scale.
Value mxfpScaleBf16(RewriterBase &rewriter, Location loc, Value v, Value scale,
bool fastMath);
Expand Down Expand Up @@ -1107,6 +1102,17 @@ inline Value packLLVector(Location loc, ValueRange vals,
return vec;
}

/// Returns true when a shared-memory access can be lowered with the simple
/// (direct) addressing scheme. This holds when the layout is not swizzled,
/// or when it is swizzled but the tensor shape matches the allocation shape
/// (possibly with leading dimensions of the allocation dropped, provided the
/// remaining rank is at least 2).
inline bool
isSimpleSharedMemoryAccess(ArrayRef<int64_t> shape,
                           ArrayRef<int64_t> allocShape,
                           triton::gpu::SharedEncodingAttr sharedEnc) {
  const auto rank = shape.size();
  // A max phase of 1 means no swizzling is applied at all.
  const bool noSwizzle = sharedEnc.getMaxPhase() == 1;
  // Swizzled, but the view covers the whole allocation.
  const bool fullShape = shape == allocShape;
  // Swizzled and rank-reduced: the view matches the trailing dims of the
  // allocation and still has rank >= 2.
  const bool rankReducedMatch =
      rank >= 2 && shape == allocShape.take_back(rank);
  return noSwizzle || fullShape || rankReducedMatch;
}

} // namespace mlir

#endif
8 changes: 5 additions & 3 deletions include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h
Original file line number Diff line number Diff line change
Expand Up @@ -242,10 +242,12 @@ LinearLayout chooseShemLayoutForRegToRegConversion(
// bit width of the tensor in the future to support more flexible tensor
// encodings
LinearLayout chooseStMatrixLayout(MLIRContext *ctx, RankedTensorType tensorTy,
ArrayRef<unsigned> repShape,
ArrayRef<unsigned> paddedRepShape,
ArrayRef<unsigned> order,
int swizzleByteSize);

// The primary goal of this function is to efficiently load 2D tiles of a
// tensor from shared memory into registers using the `ldmatrix` instruction.
LinearLayout chooseLdMatrixLayout(MLIRContext *ctx, Attribute sharedEnc,
Attribute dotEnc, ArrayRef<int64_t> shape);
} // namespace mlir::triton::gpu

#endif // TRITON_DIALECT_TRITONGPU_IR_LINEARLAYOUTCONVERSIONS_H
7 changes: 6 additions & 1 deletion include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
Original file line number Diff line number Diff line change
Expand Up @@ -567,14 +567,19 @@ L(T) = [ {0,8} , {1,9} , {2,10}, {3,11}, {0,8} , {1, 9} , {2, 10}, {3, 11},
// Linear Layout Encoding
//===----------------------------------------------------------------------===//

def LinearLayoutParam : AttrOrTypeParameter<"LinearLayout",
"linear layout"> {
let cppAccessorType = "const LinearLayout &";
}

def LinearEncodingAttr : DistributedEncoding<"LinearEncoding", "linear_encoding"> {
let mnemonic = "linear";

let description = [{
See the docs in LinearLayout.h for the definition of linear layouts.
}];

let parameters = (ins "LinearLayout":$linearLayout);
let parameters = (ins LinearLayoutParam:$linearLayout);

let extraClassDeclaration = extraDistributedDeclaration # [{
SmallVector<unsigned> getContigPerThread() const;
Expand Down
2 changes: 1 addition & 1 deletion include/triton/Tools/LinearLayout.h
Original file line number Diff line number Diff line change
Expand Up @@ -725,4 +725,4 @@ inline std::ostream &operator<<(std::ostream &os, const LinearLayout &layout) {

} // namespace mlir::triton

#endif
#endif // TRITON_TOOLS_LINEARLAYOUT_H
51 changes: 51 additions & 0 deletions lib/Analysis/Allocation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,12 @@
#include "triton/Dialect/Triton/IR/Utility.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"

#define DEBUG_TYPE "allocation-shared-memory"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")

namespace mlir {

Expand Down Expand Up @@ -297,6 +303,10 @@ class AllocationAnalysis {
auto value = valueBufferIter.first;
auto *buffer = valueBufferIter.second;
bufferRange[buffer] = getLiveness(value);
LLVM_DEBUG({
llvm::dbgs() << "-- buffer " << buffer->id << "; value: ";
value.dump();
});
}
}

Expand Down Expand Up @@ -336,6 +346,10 @@ class AllocationAnalysis {
auto *buffer = opScratchIter.second;
bufferRange.insert({buffer, Interval(operationId.lookup(op),
operationId.lookup(op) + 1)});
LLVM_DEBUG({
llvm::dbgs() << "-- buffer " << buffer->id << "; value: ";
op->dump();
});
}
};
processScratchMemory(allocation->opScratch);
Expand Down Expand Up @@ -387,6 +401,28 @@ class AllocationAnalysis {
resolveScratchBufferLiveness(operationId);
}

/// Debug helper: prints every buffer in `bufferRange` with its id, size,
/// allocated offset, and live interval. Only meaningful inside LLVM_DEBUG
/// (callers wrap invocations in LLVM_DEBUG, and LDBG expands to a no-op in
/// release builds).
void dumpBuffers() {
  LDBG("Dump bufferRange: id size offset ---------");
  // Iterate by const reference: the original `for (auto bufferIter : ...)`
  // copied each (BufferT *, Interval) pair on every iteration.
  for (const auto &[buffer, interval] : bufferRange) {
    llvm::dbgs() << "-- " << buffer->id << " " << buffer->size << " "
                 << buffer->offset;
    llvm::dbgs() << " interval " << interval.start() << " " << interval.end()
                 << "\n";
  }
}

/// Debug helper: prints the interference graph as adjacency lists
/// ("from <id> to <id>; <id>; ..."). Only meaningful inside LLVM_DEBUG.
/// \param interference map from each buffer to the set of buffers whose live
///        ranges overlap with it.
void dumpInterferenceGraph(const GraphT &interference) {
  LDBG("\n");
  LDBG("Dump interference graph: \n");
  // Iterate by const reference: the original `for (auto edges : ...)` copied
  // the entire adjacency set of every node on each iteration.
  for (const auto &[from, neighbors] : interference) {
    llvm::dbgs() << "-- from " << from->id << " to ";
    for (const auto *node : neighbors) {
      llvm::dbgs() << node->id << "; ";
    }
    llvm::dbgs() << "\n";
  }
}

/// Computes the shared memory offsets for all related values.
/// Paper: Algorithms for Compile-Time Memory Optimization
/// (https://dl.acm.org/doi/pdf/10.5555/314500.315082)
Expand All @@ -396,6 +432,14 @@ class AllocationAnalysis {
buffers.emplace_back(bufferIter.first);
}

// Sort buffers by size in descending order to reduce the fragmentation
// on big buffers caused by smaller buffers. Big buffers have a higher
// chance to overlap with multiple other buffers, and allocating them first
// (by calculateStarts) ensures a higher chance that they will occupy a
// standalone smem slot.
llvm::stable_sort(
buffers, [&](BufferT *A, BufferT *B) { return A->size > B->size; });

calculateStarts(buffers);

// NOTE: The original paper doesn't consider interference between
Expand Down Expand Up @@ -471,6 +515,7 @@ class AllocationAnalysis {
xBuffers.erase(bufferIt);
}
}
LLVM_DEBUG(dumpBuffers());
}

/// Builds a graph of all shared memory values. Edges are created between
Expand All @@ -497,6 +542,8 @@ class AllocationAnalysis {
}
}
}

LLVM_DEBUG(dumpInterferenceGraph(interference));
}

/// Finalizes shared memory offsets considering interference.
Expand Down Expand Up @@ -524,6 +571,9 @@ class AllocationAnalysis {
}
auto it = std::find(available.begin(), available.end(), true);
colors[x] = std::distance(available.begin(), it);
LLVM_DEBUG({
llvm::dbgs() << "-- color " << x->id << " " << colors[x] << "\n";
});
}
// Finalize allocation
// color0: [0, 7), [0, 8), [0, 15) -> [0, 7), [0, 8), [0, 15)
Expand All @@ -541,6 +591,7 @@ class AllocationAnalysis {
allocation->sharedMemorySize =
std::max(allocation->sharedMemorySize, x->offset + x->size);
}
LLVM_DEBUG(dumpBuffers());
}

private:
Expand Down
1 change: 0 additions & 1 deletion lib/Analysis/AxisInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1010,7 +1010,6 @@ AxisInfoAnalysis::AxisInfoAnalysis(DataFlowSolver &solver)
visitors.append<CastOpAxisInfoVisitor<arith::ExtSIOp>,
CastOpAxisInfoVisitor<arith::ExtUIOp>,
CastOpAxisInfoVisitor<arith::TruncIOp>,
CastOpAxisInfoVisitor<arith::IndexCastOp>,
CastOpAxisInfoVisitor<triton::gpu::ConvertLayoutOp>,
CastOpAxisInfoVisitor<mlir::UnrealizedConversionCastOp>,
CastOpAxisInfoVisitor<triton::BitcastOp>>();
Expand Down
6 changes: 2 additions & 4 deletions lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -468,10 +468,8 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
scratchConfig.paddedRepShape, scratchConfig.order,
/*swizzleByteSize=*/0);
LinearLayout shmemStoreLayout =
isStMatrix ? chooseStMatrixLayout(
ctx, op.getSrc().getType(), scratchConfig.repShape,
scratchConfig.paddedRepShape, scratchConfig.order,
/*swizzleByteSize=*/0)
isStMatrix ? chooseStMatrixLayout(ctx, op.getSrc().getType(),
/*swizzleByteSize=*/0)
: srcLayout.invertAndCompose(sharedLayout);

const int shmemAllocatedNumElems =
Expand Down
30 changes: 0 additions & 30 deletions lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -458,35 +458,6 @@ struct AbsFOpConversion
return {rewriter.create<LLVM::FAbsOp>(loc, elemTy, operands[0][0])};
}
};
/// Lowers arith::index_cast as a plain integer conversion: by this point the
/// index type has already been converted to a fixed-width integer. If source
/// and target widths match, the cast is erased; a wider target sign-extends,
/// a narrower target truncates.
struct IndexCastOpLowering
    : public ElementwiseOpConversionBase<arith::IndexCastOp,
                                         IndexCastOpLowering> {
  using Base =
      ElementwiseOpConversionBase<arith::IndexCastOp, IndexCastOpLowering>;
  using Base::Base;
  using Adaptor = typename Base::OpAdaptor;

  SmallVector<Value> createDestOps(arith::IndexCastOp op, OpAdaptor adaptor,
                                   ConversionPatternRewriter &rewriter,
                                   Type elemTy, MultipleOperandsRange operands,
                                   Location loc) const {
    // Bit widths of the (already type-converted) source and target elements.
    auto srcElemTy =
        this->getTypeConverter()->convertType(getElementType(op.getIn()));
    const unsigned srcBits = srcElemTy.getIntOrFloatBitWidth();
    const unsigned dstBits = elemTy.getIntOrFloatBitWidth();
    Value src = operands[0][0];

    if (dstBits == srcBits)
      return {src}; // Same width: no conversion needed.
    if (dstBits > srcBits)
      return {rewriter.create<LLVM::SExtOp>(op.getLoc(), elemTy, src)};
    return {rewriter.create<LLVM::TruncOp>(op.getLoc(), elemTy, src)};
  }
};

struct SelectOpConversion
: ElementwiseOpConversionBase<arith::SelectOp, SelectOpConversion> {
Expand Down Expand Up @@ -705,6 +676,5 @@ void mlir::triton::populateElementwiseOpToLLVMPatterns(
patterns.add<ElementwiseInlineAsmOpConversion>(typeConverter, benefit);
patterns.add<AbsIOpConversion>(typeConverter, axisInfoAnalysis, benefit);
patterns.add<AbsFOpConversion>(typeConverter, axisInfoAnalysis, benefit);
patterns.add<IndexCastOpLowering>(typeConverter, axisInfoAnalysis, benefit);
patterns.add<SelectOpConversion>(typeConverter, axisInfoAnalysis, benefit);
}
38 changes: 1 addition & 37 deletions lib/Conversion/TritonGPUToLLVM/Utility.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -223,10 +223,7 @@ Value getSmemVecAddr(RankedTensorType registerTy,
// We propose case 2 (see comments below), which provides a more general
// solution for all swizzled shared memory scenarios, including the edge case
// mentioned above.
if (/*no swizzling*/ sharedEnc.getMaxPhase() == 1 ||
/*swizzling but same shape*/ shape == allocShape ||
/*swizzling and rank-reduced and rank >= 2*/
(shape == allocShape.take_back(rank) && rank >= 2)) { // Case 1
if (isSimpleSharedMemoryAccess(shape, allocShape, sharedEnc)) { // Case 1
// Get the address to load/store. The multi-dim address is (offsetX1, ...,
// offsetXN, block), where the offsets appear in minor-to-major order, and
// we drop_end to drop block, which we know from above will be 0.
Expand Down Expand Up @@ -871,39 +868,6 @@ SmallVector<Value> getWrappedMultiDimOffset(
return multiDimOffsetWrapped;
}

// Convert each input value — an i8 holding two packed mxfp4 (e2m1: 1 sign,
// 2 exponent, 1 mantissa bit) numbers, low nibble first — into two
// standalone bf16 values. The result has 2 * values.size() elements, in
// (low nibble, high nibble) order per input byte.
SmallVector<Value> convertMxfp4x2ToBf16x2(RewriterBase &rewriter, Location loc,
ArrayRef<Value> values) {
SmallVector<Value> results;
for (auto v : values) {
// em0/em1: the exponent+mantissa field (3 bits, sign excluded) of the low
// and high nibble respectively.
auto em0 = and_(v, i8_val(0x7));
auto em1 = and_(v, i8_val(0x70));
// Move the fields into bf16 bit positions: the exp/mant field lands on
// bits [8:6] (low exponent bits + top mantissa bit) and the nibble's sign
// bit (0x8 / 0x80) lands on bit 15.
Value v0 = or_(shl(zext(i16_ty, em0), i16_val(6)),
shl(zext(i16_ty, and_(v, i8_val(0x8))), i16_val(12)));
Value v1 = or_(shl(zext(i16_ty, em1), i16_val(2)),
shl(zext(i16_ty, and_(v, i8_val(0x80))), i16_val(8)));

// Three cases:
// 1) x is normal and non-zero (exponent bits != 0): rebias the exponent by
//    adding the bias difference between bf16 (127) and e2m1 (1), shifted
//    into the bf16 exponent position ((127 - 1) << 7).
v0 = select(icmp_ne(and_(em0, i8_val(0x6)), i8_val(0)),
add(v0, i16_val((127 - 1) << 7)), v0);
v1 = select(icmp_ne(and_(em1, i8_val(0x60)), i8_val(0)),
add(v1, i16_val((127 - 1) << 7)), v1);

// 2) x is subnormal (x == 0bs001 where s is the sign): map to +-0.5 in
// bf16 (16128 == 0x3F00 == bf16 0.5), keeping the sign bit from step 1.
v0 = bitcast(select(icmp_eq(em0, i8_val(0x1)),
or_(i16_val(16128), and_(v0, i16_val(0x8000))), v0),
bf16_ty);
v1 = bitcast(select(icmp_eq(em1, i8_val(0x10)),
or_(i16_val(16128), and_(v1, i16_val(0x8000))), v1),
bf16_ty);
// 3) x is zero, nothing to do (the bit pattern is already +-0.0).
results.push_back(v0);
results.push_back(v1);
}
return results;
}

Value mxfpScaleBf16(RewriterBase &rewriter, Location loc, Value v, Value scale,
bool fastMath) {
Value vBf16 = bitcast(v, bf16_ty);
Expand Down
Loading

0 comments on commit e9b11eb

Please sign in to comment.