Implement scaled_dot(mxfp8, fp8) via mma (#4795)
Initial implementation using mma.

Still to do: test that it plays ball with the pipeliner.
lezcano authored Oct 12, 2024
1 parent d39ee1f commit 4daa467
Showing 22 changed files with 719 additions and 25 deletions.
6 changes: 4 additions & 2 deletions include/triton/Conversion/TritonGPUToLLVM/Utility.h
@@ -125,16 +125,18 @@ using namespace mlir::triton;
#define array_ty(elemTy, count) LLVM::LLVMArrayType::get(elemTy, count)

// Constants
#define int_val(bitwidth, val) \
LLVM::createLLVMIntegerConstant(rewriter, loc, bitwidth, val)
#define i1_val(val) LLVM::createConstantI1(loc, rewriter, val)
#define true_val() i1_val(true)
#define false_val() i1_val(false)
#define f16_val(...) LLVM::createConstantF16(loc, rewriter, __VA_ARGS__)
#define f32_val(...) LLVM::createConstantF32(loc, rewriter, __VA_ARGS__)
#define f64_val(...) LLVM::createConstantF64(loc, rewriter, __VA_ARGS__)
#define i8_val(val) int_val(8, val)
#define i16_val(val) int_val(16, val)
#define i32_val(...) LLVM::createConstantI32(loc, rewriter, __VA_ARGS__)
#define i64_val(...) LLVM::createConstantI64(loc, rewriter, __VA_ARGS__)
#define int_val(width, val) \
LLVM::createLLVMIntegerConstant(rewriter, loc, width, val)
#define tid_val() getThreadId(rewriter, loc)

// Attributes
10 changes: 6 additions & 4 deletions include/triton/Dialect/Triton/IR/Traits.h
@@ -81,10 +81,12 @@ class DotLike : public TraitBase<ConcreteType, DotLike> {
if (aShape.size() != bShape.size() || aShape.size() != cShape.size())
return op->emitOpError("expected all operands to have the same rank");
// Check if the first two operands share a common dimension
if (aShape[aShape.size() - 1] != bShape[aShape.size() - 2])
return op->emitOpError("expected the last dimension of the first operand "
"to be equal to the second-to-last dimension of "
"the second operand");
// TODO: re-enable once there is an interface to support scaled dot.
// if (aShape[aShape.size() - 1] != bShape[aShape.size() - 2])
//   return op->emitOpError("expected the last dimension of the first operand "
//                          "to be equal to the second-to-last dimension of "
//                          "the second operand");
// Check the batch dimension
if (aShape.size() == 3 &&
(aShape[0] != cShape[0] || bShape[0] != cShape[0]))
14 changes: 14 additions & 0 deletions include/triton/Dialect/Triton/IR/TritonAttrDefs.td
@@ -119,4 +119,18 @@ def TT_InputPrecisionAttr : I32EnumAttr<
let cppNamespace = "::mlir::triton";
}

// Attribute for the F8F6F4 family of float types.
def TT_F8F6F4TypeAttr : I32EnumAttr<
"F8F6F4Type", "",
[
I32EnumAttrCase<"E4M3", 0, "e4m3">,
I32EnumAttrCase<"E5M2", 1, "e5m2">,
I32EnumAttrCase<"E2M3", 2, "e2m3">,
I32EnumAttrCase<"E3M2", 3, "e3m2">,
I32EnumAttrCase<"E2M1", 4, "e2m1">

]>{
let cppNamespace = "::mlir::triton";
}

#endif
37 changes: 37 additions & 0 deletions include/triton/Dialect/Triton/IR/TritonOps.td
@@ -673,6 +673,43 @@ def TT_DotOp : TT_Op<"dot", [Pure,
let hasVerifier = 1;
}


//
// DotScaled Op
//
def TT_DotScaledOp : TT_Op<"dot_scaled", [Pure,
DotLike,
TypesMatchWith<"result's type matches accumulator's type",
"d", "c", "$_self">]> {
let summary = "dot_scaled";

let description = [{
$d = matrix_multiply(scale($lhs, $lhs_scale), scale($rhs, $rhs_scale)) + $c,
where scale(x, s) is a function that applies the scales per block following the microscaling spec.
}];

let arguments = (
ins
// The inputs are integer tensors because they hold packed formats for which
// we currently don't have a dedicated representation.
TT_IntTensor:$lhs,
TT_IntTensor:$rhs,
TT_FloatTensor:$c,
TT_IntTensor:$lhs_scale,
Optional<TT_IntTensor>:$rhs_scale,
TT_F8F6F4TypeAttr:$lhs_type,
TT_F8F6F4TypeAttr:$rhs_type
);

let results = (outs TT_FloatTensor:$d);

// Not sure why the optional group needs to be fully specified, but otherwise the parser complains when loading the MLIR file.
let assemblyFormat = [{
$lhs `,` $lhs_scale `,` $rhs (`,`) : (`,` $rhs_scale^ `,`)? $c `lhs` `=` $lhs_type `rhs` `=` $rhs_type attr-dict
`:` type($lhs) `,` type($lhs_scale) `*` type($rhs) (`,` type($rhs_scale)^)? `->` type($d)
}];
}
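
For reference, a minimal sketch of what the textual form of this op might look like for the mxfp8 x fp8 case targeted by this commit: an e4m3 A operand with one 8-bit scale per 32-element block along K, an unscaled e4m3 B operand, and an f32 accumulator. The value names and shapes are illustrative assumptions, not code taken from the commit.

%d = tt.dot_scaled %a, %a_scale, %b, %c lhs = e4m3 rhs = e4m3
  : tensor<128x64xi8>, tensor<128x2xi8> * tensor<64x128xi8> -> tensor<128x128xf32>

Both operands are i8 tensors because the fp8 values are passed as packed bytes; with $rhs_scale omitted, the optional group reduces to the plain comma before %c.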

//
// Reduce Op
//
20 changes: 20 additions & 0 deletions include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td
@@ -256,4 +256,24 @@ def TTG_LocalStoreOp : TTG_Op<"local_store", [DeclareOpInterfaceMethods<MemoryEf
}];
}

def TTG_UpcastMXFPOp : TTG_Op<"upcast_mxfp", [Pure, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
let summary = "Convert an mxfp tensor to bf16";

let hasVerifier = 1;

let description = [{
Compute the bf16 values encoded in the given mxfp tensor, as per
https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
}];
let arguments = (ins
TT_Tensor:$src,
TT_Tensor:$scale,
TT_F8F6F4TypeAttr:$fp_type);
let results = (outs TT_Tensor:$result);

let assemblyFormat = [{
$src `,` $scale `fp_type` `=` $fp_type attr-dict `:` type($src) `,` type($scale) `->` type($result)
}];
}
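
As a rough illustration of the constraints enforced by the verifier added in lib/Dialect/TritonGPU/IR/Ops.cpp below (bf16 values in a dot-operand layout, i8 scales in a blocked layout), the op might be written as follows. The #dot_a and #scale aliases stand for a DotOperandEncodingAttr and a BlockedEncodingAttr with threadsPerWarp = [16, 2]; the shapes and the spelled-out triton_gpu prefix are assumptions.

%r = triton_gpu.upcast_mxfp %v, %s fp_type = e4m3
  : tensor<128x64xbf16, #dot_a>, tensor<128x2xi8, #scale> -> tensor<128x64xbf16, #dot_a>

For the fp8 types the inferred result keeps the shape and encoding of %v, so on this path the op essentially applies the per-block scales to values that have already been upcast to bf16.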

#endif
3 changes: 1 addition & 2 deletions include/triton/Dialect/TritonGPU/Transforms/Utility.h
@@ -28,8 +28,7 @@ class SharedEncodingAttr;
// Version = 3: <m, n, k>
SmallVector<unsigned, 3> mmaVersionToInstrShape(int version,
const ArrayRef<int64_t> &shape,
RankedTensorType type,
int numWarps);
Type type, int numWarps);

// Return true if the Load uses block pointer.
bool isLoadFromTensorPtr(triton::LoadOp op);
5 changes: 3 additions & 2 deletions lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp
@@ -553,8 +553,9 @@ void populateTritonPatterns(TritonGPUTypeConverter &typeConverter,
GenericOpPattern<triton::ExperimentalDescriptorStoreOp>,
GenericOpPattern<triton::ExperimentalTensormapCreateOp>,
GenericOpPattern<triton::ExperimentalTensormapFenceproxyAcquireOp>,
GenericOpPattern<triton::CallOp>, TritonFuncOpPattern>(typeConverter,
context);
// This assumes the right layout will be set later for tt.dot_scaled.
GenericOpPattern<triton::DotScaledOp>, GenericOpPattern<triton::CallOp>,
TritonFuncOpPattern>(typeConverter, context);
}

//
1 change: 1 addition & 0 deletions lib/Dialect/TritonGPU/IR/CMakeLists.txt
@@ -1,6 +1,7 @@
add_triton_library(TritonGPUIR
Dialect.cpp
LinearLayoutConversions.cpp
Ops.cpp
Types.cpp

DEPENDS
3 changes: 0 additions & 3 deletions lib/Dialect/TritonGPU/IR/Dialect.cpp
@@ -3425,9 +3425,6 @@ void TritonGPUDialect::initialize() {
addInterfaces<TritonGPUInferLayoutInterface>();
}

#define GET_OP_CLASSES
#include "triton/Dialect/TritonGPU/IR/Ops.cpp.inc"

// verify TritonGPU ops
LogicalResult TritonGPUDialect::verifyOperationAttribute(Operation *op,
NamedAttribute attr) {
103 changes: 103 additions & 0 deletions lib/Dialect/TritonGPU/IR/Ops.cpp
@@ -0,0 +1,103 @@
#include "mlir/IR/BuiltinTypes.h"
#include "triton/Conversion/TritonGPUToLLVM/Utility.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Attributes.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonNvidiaGPU/IR/Types.h"
#include "llvm/Support/raw_ostream.h"

#define GET_OP_CLASSES
#include "triton/Dialect/TritonGPU/IR/Ops.cpp.inc"

namespace mlir::triton::gpu {

LogicalResult UpcastMXFPOp::verify() {
auto fpType = getFpType();

auto xTy = getSrc().getType();
auto scaleTy = getScale().getType();

if (xTy.getElementType() != FloatType::getBF16(getContext())) {
return emitOpError("element type of the first operand must be bf16");
}

if (scaleTy.getElementType() != IntegerType::get(getContext(), 8)) {
return emitOpError("element type of the second operand must be uint8");
}

auto xShape = xTy.getShape();
auto scaleShape = scaleTy.getShape();

if (xShape.size() != scaleShape.size() || xShape.size() < 2) {
return emitOpError(
"operands must have the same number of dimensions, at least 2");
}

if (!(fpType == F8F6F4Type::E2M1 || fpType == F8F6F4Type::E4M3 ||
fpType == F8F6F4Type::E5M2)) {
return emitOpError("NYI: fpType must be E2M1, E4M3, or E5M2");
}

// e2m1 packs two elements per byte; the fp8 types hold one element per byte.
const auto elems_packed = fpType == F8F6F4Type::E2M1 ? 2 : 1;

if (xShape.back() != (32 / elems_packed) * scaleShape.back()) {
return emitOpError("last dimension of the first operand must be 32 times "
"(16 times for e2m1) larger than that of the second operand");
}

if (!std::equal(xShape.begin(), xShape.end() - 1, scaleShape.begin())) {
return emitOpError(
"all dimensions except the last must match between operands");
}

auto layoutX = xTy.getEncoding();
if (!layoutX || !isa<DotOperandEncodingAttr>(layoutX)) {
return emitOpError("Expected a DotOperandEncodingAttr for values");
}
auto layoutScale = scaleTy.getEncoding();
if (!layoutScale || !isa<BlockedEncodingAttr>(layoutScale)) {
return emitOpError("Expected a BlockOperandEncoding for scales");
}
auto blockedScale = cast<BlockedEncodingAttr>(layoutScale);

// Necessary to keep all of the scales of a given block of values in the same
// warp
auto threadsPerWarp = blockedScale.getThreadsPerWarp();
if (threadsPerWarp != ArrayRef<unsigned>({16, 2})) {
return emitOpError("Expected threads per warp to be {16, 2}");
}

return success();
}

LogicalResult UpcastMXFPOp::inferReturnTypes(
MLIRContext *context, std::optional<Location> location, ValueRange operands,
DictionaryAttr attributes, OpaqueProperties opaqueProperties,
RegionRange regions, SmallVectorImpl<Type> &inferredReturnTypes) {
auto xTy = cast<RankedTensorType>(operands[0].getType());
auto properties = opaqueProperties.as<const Properties *>();
auto typeEncoded = properties->fp_type.getValue();
auto xShape = xTy.getShape();

auto encoding = xTy.getEncoding();
if (!encoding) {
return emitOptionalError(location, "expected an encoding");
}
if (!mlir::isa<DotOperandEncodingAttr>(encoding)) {
return emitOptionalError(location, "expected an mma layout encoding");
}
if (xShape.size() < 2) {
return emitOptionalError(location, "tensor rank must be at least 2");
}

// For now we just return the input encoding. For fp4 we'll need to cast from
// tf32 to fp16 encoding and multiply the shape by two
assert((typeEncoded == F8F6F4Type::E4M3 || typeEncoded == F8F6F4Type::E5M2) &&
"NYI: only fp8e4m3 and fp8e5m2 are supported");

inferredReturnTypes.push_back(xTy);
return success();
}

} // namespace mlir::triton::gpu
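
To make the layout requirement on the scales concrete, a blocked encoding that satisfies the threadsPerWarp == {16, 2} check might look like the sketch below; sizePerThread, warpsPerCTA, and order are assumed values for a four-warp kernel, not taken from this commit.

#scale = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [1, 0]}>

With an e4m3 input, a tensor<128x64xbf16> operand then pairs with a tensor<128x2xi8, #scale> scale tensor (64 = 32 x 2, one scale per 32-element block), and the 16x2 thread arrangement keeps the scales of a given block in the same warp as its values, as the comment in the verifier notes.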