diff --git a/include/triton/Conversion/TritonGPUToLLVM/Utility.h b/include/triton/Conversion/TritonGPUToLLVM/Utility.h index e47a025fbbf1..29b8865c03ae 100644 --- a/include/triton/Conversion/TritonGPUToLLVM/Utility.h +++ b/include/triton/Conversion/TritonGPUToLLVM/Utility.h @@ -125,16 +125,18 @@ using namespace mlir::triton; #define array_ty(elemTy, count) LLVM::LLVMArrayType::get(elemTy, count) // Constants +#define int_val(bitwidth, val) \ + LLVM::createLLVMIntegerConstant(rewriter, loc, bitwidth, val) #define i1_val(val) LLVM::createConstantI1(loc, rewriter, val) #define true_val() i1_val(true) #define false_val() i1_val(false) #define f16_val(...) LLVM::createConstantF16(loc, rewriter, __VA_ARGS__) #define f32_val(...) LLVM::createConstantF32(loc, rewriter, __VA_ARGS__) #define f64_val(...) LLVM::createConstantF64(loc, rewriter, __VA_ARGS__) +#define i8_val(val) int_val(8, val) +#define i16_val(val) int_val(16, val) #define i32_val(...) LLVM::createConstantI32(loc, rewriter, __VA_ARGS__) #define i64_val(...) LLVM::createConstantI64(loc, rewriter, __VA_ARGS__) -#define int_val(width, val) \ - LLVM::createLLVMIntegerConstant(rewriter, loc, width, val) #define tid_val() getThreadId(rewriter, loc) // Attributes diff --git a/include/triton/Dialect/Triton/IR/Traits.h b/include/triton/Dialect/Triton/IR/Traits.h index e47c023a29b0..7f0e5109e6b9 100644 --- a/include/triton/Dialect/Triton/IR/Traits.h +++ b/include/triton/Dialect/Triton/IR/Traits.h @@ -81,10 +81,12 @@ class DotLike : public TraitBase { if (aShape.size() != bShape.size() || aShape.size() != cShape.size()) return op->emitOpError("expected all operands to have the same rank"); // Check if the first two operands share a common dimension - if (aShape[aShape.size() - 1] != bShape[aShape.size() - 2]) - return op->emitOpError("expected the last dimension of the first operand " - "to be equal to the second-to-last dimension of " - "the second operand"); + // TODO: enable back with an interface to support scaled dot. + // if (aShape[aShape.size() - 1] != bShape[aShape.size() - 2]) + // return op->emitOpError("expected the last dimension of the first + // operand " + // "to be equal to the second-to-last dimension of + // " "the second operand"); // Check the batch dimension if (aShape.size() == 3 && (aShape[0] != cShape[0] || bShape[0] != cShape[0])) diff --git a/include/triton/Dialect/Triton/IR/TritonAttrDefs.td b/include/triton/Dialect/Triton/IR/TritonAttrDefs.td index e0d8c7ce35cb..f3159338bd0a 100644 --- a/include/triton/Dialect/Triton/IR/TritonAttrDefs.td +++ b/include/triton/Dialect/Triton/IR/TritonAttrDefs.td @@ -119,4 +119,18 @@ def TT_InputPrecisionAttr : I32EnumAttr< let cppNamespace = "::mlir::triton"; } +// Type for F8F6F4 kind of floats. 
+def TT_F8F6F4TypeAttr : I32EnumAttr< + "F8F6F4Type", "", + [ + I32EnumAttrCase<"E4M3", 0, "e4m3">, + I32EnumAttrCase<"E5M2", 1, "e5m2">, + I32EnumAttrCase<"E2M3", 2, "e2m3">, + I32EnumAttrCase<"E3M2", 3, "e3m2">, + I32EnumAttrCase<"E2M1", 4, "e2m1"> + + ]>{ + let cppNamespace = "::mlir::triton"; +} + #endif diff --git a/include/triton/Dialect/Triton/IR/TritonOps.td b/include/triton/Dialect/Triton/IR/TritonOps.td index 756d22f9dbef..66946c20cc96 100644 --- a/include/triton/Dialect/Triton/IR/TritonOps.td +++ b/include/triton/Dialect/Triton/IR/TritonOps.td @@ -673,6 +673,43 @@ def TT_DotOp : TT_Op<"dot", [Pure, let hasVerifier = 1; } + +// +// DotScaled Op +// +def TT_DotScaledOp : TT_Op<"dot_scaled", [Pure, + DotLike, + TypesMatchWith<"result's type matches accumulator's type", + "d", "c", "$_self">]> { + let summary = "dot_scaled"; + + let description = [{ + $d = matrix_multiply(scale($lhs, $lhs_scale), scale($rhs, $rhs_scale)) + $c. + Where scale(x, s) is a function that applies the scale per block following microscaling spec. + }]; + + let arguments = ( + ins + // inputs are integer types as they are packed types and we currently + // don't have a representation for those. + TT_IntTensor:$lhs, + TT_IntTensor:$rhs, + TT_FloatTensor:$c, + TT_IntTensor:$lhs_scale, + Optional:$rhs_scale, + TT_F8F6F4TypeAttr:$lhs_type, + TT_F8F6F4TypeAttr:$rhs_type + ); + + let results = (outs TT_FloatTensor:$d); + + // Not sure why I need to fully specify the optional group, but otherwise it complains when loading the mlir file + let assemblyFormat = [{ + $lhs `,` $lhs_scale `,` $rhs (`,`) : (`,` $rhs_scale^ `,`)? $c `lhs` `=` $lhs_type `rhs` `=` $rhs_type attr-dict + `:` type($lhs) `,` type($lhs_scale) `*` type($rhs) (`,` type($rhs_scale)^)? `->` type($d) + }]; +} + // // Reduce Op // diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td index 1361085bd5a9..a290cb20310a 100644 --- a/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td +++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td @@ -256,4 +256,24 @@ def TTG_LocalStoreOp : TTG_Op<"local_store", [DeclareOpInterfaceMethods]> { + let summary = "Convert an mxfp tensor to bf16"; + + let hasVerifier = 1; + + let description = [{ + Compute the bf16 encoded in the given mxfp number as per + https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf + }]; + let arguments = (ins + TT_Tensor:$src, + TT_Tensor:$scale, + TT_F8F6F4TypeAttr:$fp_type); + let results = (outs TT_Tensor:$result); + + let assemblyFormat = [{ + $src `,` $scale `fp_type` `=` $fp_type attr-dict `:` type($src) `,` type($scale) `->` type($result) + }]; +} + #endif diff --git a/include/triton/Dialect/TritonGPU/Transforms/Utility.h b/include/triton/Dialect/TritonGPU/Transforms/Utility.h index 6e257dbf733b..e688b52245ee 100644 --- a/include/triton/Dialect/TritonGPU/Transforms/Utility.h +++ b/include/triton/Dialect/TritonGPU/Transforms/Utility.h @@ -28,8 +28,7 @@ class SharedEncodingAttr; // Version = 3: SmallVector mmaVersionToInstrShape(int version, const ArrayRef &shape, - RankedTensorType type, - int numWarps); + Type type, int numWarps); // Return true if the Load uses block pointer. 
bool isLoadFromTensorPtr(triton::LoadOp op); diff --git a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp index 948f92603f46..bd17e2d7c8b2 100644 --- a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp +++ b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp @@ -553,8 +553,9 @@ void populateTritonPatterns(TritonGPUTypeConverter &typeConverter, GenericOpPattern, GenericOpPattern, GenericOpPattern, - GenericOpPattern, TritonFuncOpPattern>(typeConverter, - context); + // this assumes the right layout will be set later for dot scaled. + GenericOpPattern, GenericOpPattern, + TritonFuncOpPattern>(typeConverter, context); } // diff --git a/lib/Dialect/TritonGPU/IR/CMakeLists.txt b/lib/Dialect/TritonGPU/IR/CMakeLists.txt index b5dcdb5ea5b8..98831f0db8ac 100644 --- a/lib/Dialect/TritonGPU/IR/CMakeLists.txt +++ b/lib/Dialect/TritonGPU/IR/CMakeLists.txt @@ -1,6 +1,7 @@ add_triton_library(TritonGPUIR Dialect.cpp LinearLayoutConversions.cpp + Ops.cpp Types.cpp DEPENDS diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp index 728d8966aa4d..60bfd56cb00f 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -3425,9 +3425,6 @@ void TritonGPUDialect::initialize() { addInterfaces(); } -#define GET_OP_CLASSES -#include "triton/Dialect/TritonGPU/IR/Ops.cpp.inc" - // verify TritonGPU ops LogicalResult TritonGPUDialect::verifyOperationAttribute(Operation *op, NamedAttribute attr) { diff --git a/lib/Dialect/TritonGPU/IR/Ops.cpp b/lib/Dialect/TritonGPU/IR/Ops.cpp new file mode 100644 index 000000000000..b9f3d3040dd3 --- /dev/null +++ b/lib/Dialect/TritonGPU/IR/Ops.cpp @@ -0,0 +1,103 @@ +#include "mlir/IR/BuiltinTypes.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Attributes.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Types.h" +#include "llvm/Support/raw_ostream.h" + +#define GET_OP_CLASSES +#include "triton/Dialect/TritonGPU/IR/Ops.cpp.inc" + +namespace mlir::triton::gpu { + +LogicalResult UpcastMXFPOp::verify() { + auto fpType = getFpType(); + + auto xTy = getSrc().getType(); + auto scaleTy = getScale().getType(); + + if (xTy.getElementType() != FloatType::getBF16(getContext())) { + return emitOpError("element type of the first operand must be bf16"); + } + + if (scaleTy.getElementType() != IntegerType::get(getContext(), 8)) { + return emitOpError("element type of the second operand must be uint8"); + } + + auto xShape = xTy.getShape(); + auto scaleShape = scaleTy.getShape(); + + if (xShape.size() != scaleShape.size() || xShape.size() < 2) { + return emitOpError( + "operands must have the same number of dimensions, at least 2"); + } + + if (!(fpType == F8F6F4Type::E2M1 || fpType == F8F6F4Type::E4M3 || + fpType == F8F6F4Type::E5M2)) { + return emitOpError("NYI: fpType must be E2M1, E4M3, or E5M2"); + } + + // Change to support fp8 types + const auto elems_packed = fpType == F8F6F4Type::E2M1 ? 
2 : 1;
+
+  if (xShape.back() != (32 / elems_packed) * scaleShape.back()) {
+    return emitOpError("last dimension of the first operand must be 32 times "
+                       "(16 times for e2m1) larger than that of the second "
+                       "operand");
+  }
+
+  if (!std::equal(xShape.begin(), xShape.end() - 1, scaleShape.begin())) {
+    return emitOpError(
+        "all dimensions except the last must match between operands");
+  }
+
+  auto layoutX = xTy.getEncoding();
+  if (!layoutX || !isa<DotOperandEncodingAttr>(layoutX)) {
+    return emitOpError("Expected a DotOperandEncodingAttr for values");
+  }
+  auto layoutScale = scaleTy.getEncoding();
+  if (!layoutScale || !isa<BlockedEncodingAttr>(layoutScale)) {
+    return emitOpError("Expected a BlockedEncodingAttr for scales");
+  }
+  auto blockedScale = cast<BlockedEncodingAttr>(layoutScale);
+
+  // Necessary to keep all of the scales of a given block of values in the same
+  // warp
+  auto threadsPerWarp = blockedScale.getThreadsPerWarp();
+  if (threadsPerWarp != ArrayRef<unsigned>({16, 2})) {
+    return emitOpError("Expected threads per warp to be {16, 2}");
+  }
+
+  return success();
+}
+
+LogicalResult UpcastMXFPOp::inferReturnTypes(
+    MLIRContext *context, std::optional<Location> location,
+    ValueRange operands, DictionaryAttr attributes,
+    OpaqueProperties opaqueProperties, RegionRange regions,
+    SmallVectorImpl<Type> &inferredReturnTypes) {
+  auto xTy = cast<RankedTensorType>(operands[0].getType());
+  auto properties = opaqueProperties.as<const Properties *>();
+  auto typeEncoded = properties->fp_type.getValue();
+  auto xShape = xTy.getShape();
+
+  auto encoding = xTy.getEncoding();
+  if (!encoding) {
+    return emitOptionalError(location, "expected an encoding");
+  }
+  if (!mlir::isa<DotOperandEncodingAttr>(encoding)) {
+    return emitOptionalError(location, "expected an mma layout encoding");
+  }
+  if (xShape.size() < 2) {
+    return emitOptionalError(location, "tensor rank must be at least 2");
+  }
+
+  // For now we just return the input encoding. For fp4 we'll need to cast from
+  // tf32 to fp16 encoding and multiply the shape by two
+  assert((typeEncoded == F8F6F4Type::E4M3 || typeEncoded == F8F6F4Type::E5M2) &&
+         "NYI: only fp8e4m3 and fp8e5m2 are supported");
+
+  inferredReturnTypes.push_back(xTy);
+  return success();
+}
+
+} // namespace mlir::triton::gpu
diff --git a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp
index d9bbd51bd9a1..08a88ae397a7 100644
--- a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp
+++ b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp
@@ -1,15 +1,18 @@
-#include <memory>
-
+#include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/TypeUtilities.h"
+#include "mlir/IR/Types.h"
+#include "mlir/IR/Value.h"
 #include "mlir/Support/LLVM.h"
 #include "mlir/Support/LogicalResult.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "triton/Analysis/Utility.h"
+#include "triton/Dialect/Triton/IR/Dialect.h"
+#include "triton/Dialect/TritonGPU/IR/Attributes.h"
 #include "triton/Dialect/TritonGPU/IR/Dialect.h"
 #include "triton/Dialect/TritonGPU/Transforms/Passes.h"
 #include "triton/Dialect/TritonGPU/Transforms/Utility.h"
-#include "triton/Tools/Sys/GetEnv.hpp"
-#include "llvm/Support/Debug.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/SmallVector.h"
 
 namespace mlir {
 namespace triton {
@@ -242,8 +245,9 @@ class BlockedToMMA : public mlir::OpRewritePattern<triton::DotOp> {
     if (!(versionMajor >= 1 && versionMajor <= 3))
       return failure();
 
-    auto instrShape = mmaVersionToInstrShape(versionMajor, retShapePerCTA,
-                                             dotOp.getA().getType(), numWarps);
+    auto instrShape = mmaVersionToInstrShape(
+        versionMajor, retShapePerCTA, dotOp.getA().getType().getElementType(),
+        numWarps);
     // operands
     Value a = dotOp.getA();
     Value b = dotOp.getB();
@@ -380,6 +384,140 @@ static void decomposeMixedModeDotOp(ModuleOp mod, int computeCapability) {
   });
 }
 
+class ScaledBlockedToMMAv2
+    : public mlir::OpRewritePattern<triton::DotScaledOp> {
+  int computeCapability;
+
+public:
+  ScaledBlockedToMMAv2(mlir::MLIRContext *context, int computeCapability)
+      : mlir::OpRewritePattern<triton::DotScaledOp>(context),
+        computeCapability(computeCapability) {}
+
+  mlir::LogicalResult
+  matchAndRewrite(triton::DotScaledOp dotOp,
+                  mlir::PatternRewriter &rewriter) const override {
+    if (computeCapability >= 100)
+      return failure();
+
+    auto oldRetType = dotOp.getType();
+    if (!oldRetType.getEncoding() ||
+        mlir::isa<NvidiaMmaEncodingAttr>(oldRetType.getEncoding()))
+      return failure();
+    auto ctx = dotOp.getContext();
+
+    // Check that rhs scale is null
+    assert(dotOp.getRhsScale() == nullptr && "rhs scale must be null");
+
+    // operands
+    auto a = dotOp.getLhs();
+    auto b = dotOp.getRhs();
+    auto scale = dotOp.getLhsScale();
+    auto aType = dotOp.getLhsType();
+    auto bType = dotOp.getRhsType();
+
+    auto enumToType = [&rewriter](F8F6F4Type type) {
+      switch (type) {
+      case F8F6F4Type::E4M3:
+        return rewriter.getFloat8E4M3FNType();
+      case F8F6F4Type::E5M2:
+        return rewriter.getFloat8E5M2Type();
+      default:
+        llvm_unreachable("unexpected type");
+      }
+    };
+
+    assert((aType == F8F6F4Type::E4M3 || aType == F8F6F4Type::E5M2) &&
+           "lhs only supports fp8");
+    assert((bType == F8F6F4Type::E4M3 || bType == F8F6F4Type::E5M2) &&
+           "rhs only supports fp8");
+
+    // TODO run accelerate matmul on A and B first to choose their layouts
+    // Set return type
+    auto versionMajor = 2;
+    auto retShapePerCTA = getShapePerCTA(oldRetType);
+    auto mod = dotOp->getParentOfType<ModuleOp>();
+    unsigned numWarps = TritonGPUDialect::getNumWarps(mod);
+    auto instrShape = mmaVersionToInstrShape(versionMajor, retShapePerCTA,
+                                             rewriter.getBF16Type(), numWarps);
+    auto CTALayout = getCTALayout(oldRetType.getEncoding());
+    SmallVector<unsigned> warpsPerCTA = {numWarps, 1};
+    auto mmaEnc = NvidiaMmaEncodingAttr::get(ctx, /*versionMajor=*/versionMajor,
+                                             /*versionMinor=*/0, warpsPerCTA,
+                                             CTALayout, instrShape);
+    auto newRetType = RankedTensorType::get(
+        oldRetType.getShape(), oldRetType.getElementType(), mmaEnc);
+
+    // convert accumulator
+    auto oldAcc = dotOp.getOperand(2);
+    auto newAcc =
+        rewriter.create<ConvertLayoutOp>(oldAcc.getLoc(), newRetType, oldAcc);
+
+    auto toMMABf16 = [&newRetType, &rewriter, &ctx,
+                      &enumToType](TypedValue<RankedTensorType> v, int idx,
+                                   F8F6F4Type type) {
+      // MMAv2 Layout
+      auto vType = v.getType();
+      auto newVEncoding = DotOperandEncodingAttr::get(
+          ctx, idx, newRetType.getEncoding(), enumToType(type));
+      auto newVType = RankedTensorType::get(
+          v.getType().getShape(), v.getType().getElementType(), newVEncoding);
+      v = rewriter.create<ConvertLayoutOp>(v.getLoc(), newVType, v);
+
+      // Bitcast
+      auto vTypeFp8 = RankedTensorType::get(
+          vType.getShape(), rewriter.getFloat8E4M3FNType(), newVEncoding);
+      v = cast<TypedValue<RankedTensorType>>(
+          rewriter.create<BitcastOp>(v.getLoc(), vTypeFp8, v).getResult());
+
+      // Convert to bf16
+      auto vTypeBf16 = RankedTensorType::get(
+          vType.getShape(), rewriter.getBF16Type(), newVEncoding);
+      return rewriter.create<FpToFpOp>(v.getLoc(), vTypeBf16, v);
+    };
+    a = toMMABf16(a, 0, aType);
+    b = toMMABf16(b, 1, bType);
+
+    // [Note: A trick to avoid warp shuffles in the lowering]
+    // FIXME: Implement this when we can set general layouts on a tensor
+
+    // For bf16, we have 4 threads per row
+    // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#mma-16816-a-f16
+    // and each of them needs to get every scale in that row.
+    // It turns out that the layout for the output of type bf16 gives us exactly
+    // this layout when the number of mxfp vectors is equal to two (K = 64)
+    // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#mma-16816-c
+    // This can be generalised to other K with linear layouts, but the general
+    // layout cannot be represented with the predefined layouts :(
+    // With this trick, we could do the full lowering here and remove the
+    // UpcastMXFPOp altogether
+
+    assert(instrShape == ArrayRef<unsigned>({16, 8}) ||
+           instrShape == ArrayRef<unsigned>({1, 16, 8}));
+    auto shapeTileA = std::array{instrShape[0], instrShape[0]};
+    // Necessary choice to leave all the scales of the tile in the given warp
+    auto threadsPerWarp =
+        SmallVector<unsigned>{shapeTileA[0], 32 / shapeTileA[0]};
+
+    auto newScaleEncoding = triton::gpu::BlockedEncodingAttr::get(
+        ctx, {1, 1}, threadsPerWarp, warpsPerCTA, {1, 0}, CTALayout);
+
+    auto newScaleType = RankedTensorType::get(scale.getType().getShape(),
+                                              scale.getType().getElementType(),
+                                              newScaleEncoding);
+    scale =
+        rewriter.create<ConvertLayoutOp>(scale.getLoc(), newScaleType, scale);
+
+    auto scaledA = rewriter.create<UpcastMXFPOp>(
+        dotOp.getLoc(), a, scale, dotOp.getLhsType());
+
+    // convert dot instruction
+    auto newDot =
+        rewriter.create<DotOp>(dotOp.getLoc(), newRetType, scaledA, b, newAcc);
+    rewriter.replaceOpWithNewOp<ConvertLayoutOp>(dotOp, oldRetType, newDot);
+    return success();
+  }
+};
+
 #define GEN_PASS_DEF_TRITONGPUACCELERATEMATMUL
 #include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"
 
@@ -397,7 +535,8 @@ class TritonGPUAccelerateMatmulPass
     auto computeCapability = getNVIDIAComputeCapability(m);
 
     mlir::RewritePatternSet patterns(context);
-    patterns.add<BlockedToMMA>(context, computeCapability);
+    patterns.add<BlockedToMMA, ScaledBlockedToMMAv2>(context,
+                                                     computeCapability);
     if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed()) {
signalPassFailure(); } diff --git a/lib/Dialect/TritonGPU/Transforms/Utility.cpp b/lib/Dialect/TritonGPU/Transforms/Utility.cpp index db980c5fcaf8..91acba38bf59 100644 --- a/lib/Dialect/TritonGPU/Transforms/Utility.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Utility.cpp @@ -25,8 +25,7 @@ using namespace triton; SmallVector mmaVersionToInstrShape(int version, const ArrayRef &shape, - RankedTensorType type, - int numWarps) { + Type eltType, int numWarps) { if (version == 1) return {16, 16}; else if (version == 2) { @@ -36,12 +35,11 @@ SmallVector mmaVersionToInstrShape(int version, ret[rank - 2] = 16; return ret; } else if (version == 3) { - unsigned k = 256 / type.getElementTypeBitWidth(); + unsigned k = 256 / eltType.getIntOrFloatBitWidth(); if (shape[0] % 64 != 0 || shape[1] % 8 != 0) { assert(false && "type not supported"); return {0, 0, 0}; } - auto eltType = type.getElementType(); SmallVector validN; // MMAv3 with larger instruction shape is preferred. diff --git a/python/src/ir.cc b/python/src/ir.cc index 61d43a670f2c..9945c6188294 100644 --- a/python/src/ir.cc +++ b/python/src/ir.cc @@ -1,3 +1,4 @@ +#include #include #include #include @@ -204,6 +205,14 @@ void init_triton_ir(py::module &&m) { .value("IEEE", InputPrecision::IEEE) .export_values(); + py::enum_(m, "F8F6F4TY", py::module_local()) + .value("E4M3", F8F6F4Type::E4M3) + .value("E5M2", F8F6F4Type::E5M2) + .value("E2M3", F8F6F4Type::E2M3) + .value("E3M2", F8F6F4Type::E3M2) + .value("E2M1", F8F6F4Type::E2M1) + .export_values(); + py::class_(m, "context", py::module_local()) .def(py::init<>()) .def("printOpOnDiagnostic", @@ -1412,6 +1421,15 @@ void init_triton_ir(py::module &&m) { return self.create(c.getType(), a, b, c, inputPrecision, maxNumImpreciseAcc); }) + .def("create_dot_scaled", + [](TritonOpBuilder &self, mlir::Value &lhs, mlir::Value &lhs_scale, + F8F6F4Type lhs_format, mlir::Value &rhs, + std::optional &rhs_scale, F8F6F4Type rhs_format, + mlir::Value &c) -> mlir::Value { + return self.create( + c.getType(), lhs, rhs, c, lhs_scale, + rhs_scale.value_or(Value()), lhs_format, rhs_format); + }) .def("create_floor", [](TritonOpBuilder &self, Value &val) -> Value { return self.create(val); diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index da9f5e73e58d..a0b149099dc1 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -3315,6 +3315,172 @@ def kernel(X, stride_xm, stride_xk, Y, stride_yk, stride_yn, W, stride_wn, strid assert 'wgmma.mma_async.sync.aligned.m64n128k32.f32.e4m3.e4m3' in ptx +@pytest.mark.parametrize("M, N, K, col_a, col_b, type_a, type_b, num_warps", [ + (M, N, K, col_a, col_b, type_a, type_b, 4) + for M, N, K in itertools.product([32, 64, 128], [32, 64, 128], [64, 128]) + for col_a, col_b in itertools.product([True, False], repeat=2) + # We don't test e5m2 as it seems to overflow more easily + # Tested locally and it works fine other than for ~10 entries out of 10_000 + # which are of the size of 10**30 + for type_a in ["e4m3"] + for type_b in ["e4m3"] +]) +def test_scaled_dot(M, N, K, col_a, col_b, type_a, type_b, num_warps, device): + if not is_cuda(): + pytest.skip("scaled_dot only supported on CUDA") + else: + cc = torch.cuda.get_device_capability() + if cc < (8, 9): + pytest.skip("float8e4nv not supported on CUDA < 8.9") + + @triton.jit + def dot_scale_kernel(a_base, stride_a0, stride_a1, a_scale, b_base, stride_b0, stride_b1, out, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, type_a: 
tl.constexpr, + type_b: tl.constexpr): + tl.static_assert(type_b == "e4m3" or type_b == "e5m2", "type_b must be fp8") + IS_FP8: tl.constexpr = type_a == "e4m3" or type_a == "e5m2" + DIV_FACTOR: tl.constexpr = 1 if IS_FP8 else 2 + PACKED_BLOCK_K_A: tl.constexpr = BLOCK_K // DIV_FACTOR + PACKED_BLOCK_K_B: tl.constexpr = BLOCK_K + a_ptr = a_base + tl.arange(0, BLOCK_M)[:, None] * stride_a0 + tl.arange(0, + PACKED_BLOCK_K_A)[None, :] * stride_a1 + b_ptr = b_base + tl.arange(0, PACKED_BLOCK_K_B)[:, None] * stride_b0 + tl.arange(0, + BLOCK_N)[None, :] * stride_b1 + + SCALE_BLOCK_K: tl.constexpr = BLOCK_K // 32 + scale_a_ptr = a_scale + tl.arange(0, BLOCK_M)[:, None] * SCALE_BLOCK_K + tl.arange(0, SCALE_BLOCK_K)[None, :] + + a = tl.load(a_ptr) + b = tl.load(b_ptr) + a_scale = tl.load(scale_a_ptr) + c = tl.dot_scaled(a, a_scale, type_a, b, None, type_b) + out_ptr = out + tl.arange(0, BLOCK_M)[:, None] * BLOCK_N + tl.arange(0, BLOCK_N)[None, :] + tl.store(out_ptr, c) + + @triton.jit + def mxfp_to_bf16_kernel( + x_ptr, + scale_ptr, + mxfp_ptr, + N, + e_bits: tl.constexpr, + m_bits: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + ): + # x.shape == (N, 32) for fp8 or (N, 16) for fp4 + # scale.shape == (N,) + # out.shape == (N, 32) + is_fp8: tl.constexpr = e_bits + m_bits == 7 + # fp8: BLOCK_SIZE -> BLOCK_SIZE // 32, 32 + # fp4: BLOCK_SIZE // 2 -> BLOCK_SIZE // 32 , 16 + PARALLEL_DIM: tl.constexpr = BLOCK_SIZE // 32 + LAST_DIM: tl.constexpr = 32 if is_fp8 else 16 + LOAD_SIZE: tl.constexpr = LAST_DIM * PARALLEL_DIM + + offsets = (tl.program_id(0) * LOAD_SIZE + tl.arange(0, PARALLEL_DIM)[:, None] * LAST_DIM + + tl.arange(0, LAST_DIM)[None, :]) + x = tl.load(x_ptr + offsets, mask=offsets < N * LAST_DIM) + + offsets = tl.program_id(0) * PARALLEL_DIM + tl.arange(0, PARALLEL_DIM)[:, None] + scale = tl.load(scale_ptr + offsets, mask=offsets < N) + tl.static_assert(scale.dtype == tl.uint8) + tl.static_assert(x.dtype == tl.uint8) + + scale_bf16 = (scale.to(tl.uint16) << 7).to(tl.bfloat16, bitcast=True) + if is_fp8: + if e_bits == 5 and m_bits == 2: + x_f8 = x.to(tl.float8e5, bitcast=True) + x_bf16 = x_f8.to(tl.bfloat16) + # Preserve infs and nans. FIXME Fp8E5M2_to_Bf16 doesn't preserve them! 
+ non_finite_mask: tl.constexpr = ((1 << e_bits) - 1) << m_bits + non_finite_mask_bf16: tl.constexpr = ((1 << 8) - 1) << 7 + x_bf16 = tl.where( + x & non_finite_mask == non_finite_mask, + (x_bf16.to(tl.uint16, bitcast=True) | non_finite_mask_bf16).to(tl.bfloat16, bitcast=True), + x_bf16, + ) + else: + tl.static_assert(e_bits == 4 and m_bits == 3) + x_f8 = x.to(tl.float8e4nv, bitcast=True) + x_bf16 = x_f8.to(tl.bfloat16) + else: + # e2m1 + em0 = x & 0x70 + em1 = x & 0x7 + x0 = (em0.to(tl.uint16) << 2) | ((x & 0x80).to(tl.uint16) << 8) + x1 = (em1.to(tl.uint16) << (2 + 4)) | ((x & 0x8).to(tl.uint16) << (8 + 4)) + # Three cases: + # 1) x is normal and non-zero: Correct bias + x0 = tl.where((em0 & 0x60) != 0, x0 + ((127 - 1) << 7), x0) + x1 = tl.where((em1 & 0x6) != 0, x1 + ((127 - 1) << 7), x1) + # 2) x is subnormal (x == 0bs001 where s is the sign): Map to +-0.5 in bf16 + x0 = tl.where(em0 == 0x10, 16128 | (x0 & 0x8000), x0) + x1 = tl.where(em1 == 0x1, 16128 | (x1 & 0x8000), x1) + # 3) x is zero, do nothing + x_bf16 = tl.interleave(x0, x1).to(tl.bfloat16, bitcast=True) + # Multiplication preserves infs and NaNs in x_bf16 + mxfp = x_bf16 * scale_bf16 + # If scale is NaN, we encode it as an bf16 inf, so we need to correct for that + mxfp = tl.where(scale == 0xFF, float("nan"), mxfp) + + offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + tl.store(mxfp_ptr + offsets, tl.ravel(mxfp), mask=offsets < N * 32) + + def dot_scale_ref(x, scale, y, type_x, type_y): + e_bits, m_bits = {"e4m3": (4, 3), "e5m2": (5, 2)}[type_x] + type_fp8_y = {"e4m3": torch.float8_e4m3fn, "e5m2": torch.float8_e5m2}[type_y] + + # Need to implement fp4 -> fp8 cast to support 1 byte types + comp_dtype = torch.bfloat16 + out_dtype = torch.float32 + + x = x.contiguous() + x_upcast = x.new_empty(scale.shape[:-1] + (32 * scale.shape[-1], ), dtype=comp_dtype) + + N = x_upcast.numel() + BLOCK_SIZE = 512 + grid = ((N + BLOCK_SIZE - 1) // BLOCK_SIZE, ) + mxfp_to_bf16_kernel[grid](x, scale, x_upcast, scale.numel(), e_bits, m_bits, BLOCK_SIZE, num_warps=num_warps) + + y_upcast = y.view(type_fp8_y).to(comp_dtype) + return torch.matmul(x_upcast.to(out_dtype), y_upcast.to(out_dtype)) + + torch.manual_seed(0) + + def create_uint8(shape): + return torch.randint(0xff, shape, dtype=torch.uint8, device=device) + + x = create_uint8((K, M)).T if col_a else create_uint8((M, K)) + y = create_uint8((N, K)).T if col_b else create_uint8((K, N)) + scale_x = create_uint8((M, K // 32)) + + z = x.new_empty((M, N), dtype=torch.float32) + pgm = dot_scale_kernel[(1, )](x, *x.stride(), scale_x, y, *y.stride(), z, M, N, K, type_a, type_b, + num_warps=num_warps) + + z_ref = dot_scale_ref(x, scale_x, y, type_a, type_b) + + # dot_scale_ref computes the result in higher precision + # so we equalise all the non-finite values + # This also fixes a bug in our upcasting from e5m2 to bf16 where inf is not preserved + non_finite_z = ~z.isfinite() + z_ref[non_finite_z] = z[non_finite_z] + non_finite_ref = ~z_ref.isfinite() + z[non_finite_ref] = z_ref[non_finite_ref] + + # generous rtol set because the ref is more precise than the fused + # (computes in higher dtype) and we are sampling the whole range of floats + torch.testing.assert_close(z, z_ref, equal_nan=True, atol=1e-5, rtol=1e-2) + + # make sure ld/st are vectorized + ptx = pgm.asm['ptx'] + if (max(M, N) * K) // (num_warps * 32) >= 4: + assert 'ld.global.v4' in ptx + if M * N // (num_warps * 32) >= 4: + assert 'st.global.v4' in ptx + assert 
re.search(r'mma.sync.aligned.m\d+n\d+k16(?:.row.col)?.f32.bf16.bf16', ptx) + + @pytest.mark.interpreter @pytest.mark.parametrize("B, num_warps, M, N, K, BLOCK_M, BLOCK_N, in_dtype_str, out_dtype_str", [(B, num_warps, M, N, K, BLOCK_M, BLOCK_N, in_dtype_str, out_dtype_str) diff --git a/python/triton/language/__init__.py b/python/triton/language/__init__.py index 9f100c0a97e0..6502a5348f3e 100644 --- a/python/triton/language/__init__.py +++ b/python/triton/language/__init__.py @@ -54,6 +54,7 @@ device_assert, device_print, dot, + dot_scaled, dtype, expand_dims, float16, @@ -161,6 +162,7 @@ "device_print", "div_rn", "dot", + "dot_scaled", "dtype", "erf", "exp", diff --git a/python/triton/language/core.py b/python/triton/language/core.py index b73a2f08bb6b..06a15f93fd21 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -1559,6 +1559,29 @@ def dot(input, other, acc=None, input_precision=None, allow_tf32=None, max_num_i return semantic.dot(input, other, acc, input_precision, max_num_imprecise_acc, out_dtype, _builder) +@builtin +def dot_scaled(lhs, lhs_scale, lhs_format, rhs, rhs_scale, rhs_format, acc=None, out_dtype=float32, _builder=None): + """ + Returns the matrix product of two blocks in microscaling format. + lhs and rhs use microscaling formats described here: + https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf + :param lhs: The first tensor to be multiplied. + :type lhs: 2D tensor of f8, f6 or f4 format packed in int32 format. + :param lhs_scale: Scale factor for lhs tensor. + :type lhs_scale: ue8m0 float8 type (currently represented as an int8 tensor). + :param lhs_format: format of the lhs tensor, available formats: {:code:`e4m3`, :code: `e5m2`, :code:`e2m3`, :code:`e3m2`, :code:`e2m1`}. + :param rhs: The second tensor to be multiplied. + :type rhs: 2D tensor of f8, f6 or f4 format packed in int32 format. + :param rhs_scale: Scale factor for rhs tensor. + :type rhs_scale: ue8m0 float8 type (currently represented as an int8 tensor). + :param rhs_format: format of the rhs tensor, available formats: {:code:`e4m3`, :code: `e5m2`, :code:`e2m3`, :code:`e3m2`, :code:`e2m1`}. + :param acc: The accumulator tensor. If not None, the result is added to this tensor. 
+ """ + out_dtype = _constexpr_to_value(out_dtype) + assert out_dtype == float32, "Only float32 is supported for out_dtype at the moment" + return semantic.dot_scaled(lhs, lhs_scale, lhs_format, rhs, rhs_scale, rhs_format, acc, out_dtype, _builder) + + # ----------------------- # Non-Atomic Memory Operations # ----------------------- diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py index 44a8dce0d01a..1fdfbadcd290 100644 --- a/python/triton/language/semantic.py +++ b/python/triton/language/semantic.py @@ -1527,6 +1527,48 @@ def dot(lhs: tl.tensor, rhs: tl.tensor, acc: tl.tensor, input_precision: Optiona ret_ty) +def _str_to_fp_type(float_format: Optional[str]): + if float_format == 'e4m3': + return ir.F8F6F4TY.E4M3 + if float_format == 'e5m2': + return ir.F8F6F4TY.E5M2 + if float_format == 'e2m3': + return ir.F8F6F4TY.E2M3 + if float_format == 'e3m2': + return ir.F8F6F4TY.E3M2 + if float_format == 'e2m1': + return ir.F8F6F4TY.E2M1 + raise ValueError(f"Invalid float format: {float_format}.") + + +def dot_scaled(lhs: tl.tensor, lhs_scale: tl.tensor, lhs_format, rhs: tl.tensor, rhs_scale: Optional[tl.tensor], + rhs_format, acc: tl.tensor | None, out_dtype: tl.dtype, builder: ir.builder) -> tl.tensor: + assert lhs.type.is_block() and rhs.type.is_block() + #TODO: validate types. + lhs_rank = len(lhs.shape) + rhs_rank = len(rhs.shape) + assert lhs_rank == rhs_rank == 2 or lhs_rank == rhs_rank == 3, f"Both inputs must be either 2D or 3D; (lhs: {lhs.shape} vs rhs: {rhs.shape})" + M, K = lhs.type.shape[-2:] + N = rhs.type.shape[-1] + assert K == rhs.type.shape[-2], f"Reduction dimension should agree; (lhs: {lhs.shape} vs rhs: {rhs.shape})" + assert K >= 64, f"scaled_dot NYI for K < 64. Got {K=}" + B = lhs.type.shape[0] if lhs_rank == 3 else None + + ret_ty = tl.block_type(out_dtype, [B, M, N] if B else [M, N]) + _0 = builder.get_fp32(0) + if acc is None: + acc_handle = builder.create_splat(_0, [B, M, N] if B else [M, N]) + else: + acc_handle = acc.handle + assert acc.type == ret_ty + lhs_format_enum = _str_to_fp_type(lhs_format) + rhs_format_enum = _str_to_fp_type(rhs_format) + rhs_scale_handle = None if isinstance(rhs_scale, tl.constexpr) else rhs_scale.handle + return tl.tensor( + builder.create_dot_scaled(lhs.handle, lhs_scale.handle, lhs_format_enum, rhs.handle, rhs_scale_handle, + rhs_format_enum, acc_handle), ret_ty) + + # ===----------------------------------------------------------------------===// # Indexing # ===----------------------------------------------------------------------===// diff --git a/test/TritonGPU/accelerate-matmul.mlir b/test/TritonGPU/accelerate-matmul.mlir index 728fd8eadfd9..85b37f3ed3a9 100644 --- a/test/TritonGPU/accelerate-matmul.mlir +++ b/test/TritonGPU/accelerate-matmul.mlir @@ -101,6 +101,7 @@ module attributes {"triton_gpu.target" = "cuda:89", "triton_gpu.num-ctas" = 1 : #blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 1, 1], threadsPerWarp = [1, 2, 16], warpsPerCTA = [1, 4, 1], order = [0, 1, 2]}> #blocked3 = #triton_gpu.blocked<{sizePerThread = [1, 2, 2], threadsPerWarp = [1, 4, 8], warpsPerCTA = [2, 2, 1], order = [2, 1, 0]}> module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { + // CHECK: kernel_ tt.func public @kernel_() attributes {noinline = false} { %cst = arith.constant dense<0.000000e+00> : tensor<2x16x16xf32, #blocked> %cst_0 = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #blocked1> @@ -129,8 
+130,8 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : // CHECK: #mma = #triton_gpu.nvidia_mma<{versionMajor = 3, {{.*}}, instrShape = [16, 32, 16]}> #blocked = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 32], warpsPerCTA = [32, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 32 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: check_instrShape_per_warps tt.func @check_instrShape_per_warps(%arg0: !tt.ptr {tt.divisibility = 16 : i32}) { - // CHECK-LABEL: check_instrShape_per_warps %mask = arith.constant dense : tensor<128x128xi1, #blocked> %zero_f32 = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked> %a = arith.constant dense<0.000000e+00> : tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> @@ -150,6 +151,7 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : // CHECK: #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [8, 4], instrShape = [16, 8]}> #blocked = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 32], warpsPerCTA = [32, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 32 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: small_k_size tt.func @small_k_size( %a: tensor<128x16xf8E5M2, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>, %b: tensor<16x128xf8E5M2, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>) @@ -159,3 +161,24 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : tt.return %result : tensor<128x128xf32, #blocked> } } + +// ----- + +// Verify that dot_scaled (mxfp8 x fp8) decomposes as expected +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [2, 2], order = [1, 0]}> +module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: dot_scaled + tt.func @dot_scaled( + %a: tensor<128x64xi8, #blocked2>, + %scale: tensor<128x2xi8, #blocked1>, + %b: tensor<64x128xi8, #blocked>) + -> tensor<128x128xf32, #blocked> { + // CHECK: triton_gpu.upcast_mxfp + // CHECK: tt.dot + %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked> + %result = tt.dot_scaled %a, %scale, %b, %cst lhs = e4m3 rhs = e4m3 : tensor<128x64xi8, #blocked2>, tensor<128x2xi8, #blocked1> * tensor<64x128xi8, #blocked> -> tensor<128x128xf32, #blocked> + tt.return %result : tensor<128x128xf32, #blocked> + } +} diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/CMakeLists.txt b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/CMakeLists.txt index 197901d8555c..a944da1c83f1 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/CMakeLists.txt +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/CMakeLists.txt @@ -17,6 +17,7 @@ add_triton_library(TritonNVIDIAGPUToLLVM ClusterOpsToLLVM.cpp PTXAsmFormat.cpp Utility.cpp + UpcastMXFPToLLVM.cpp TargetInfo.cpp DEPENDS diff --git 
a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/PatternTritonGPUOpToLLVM.h b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/PatternTritonGPUOpToLLVM.h index c076674f51de..4060378fa42b 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/PatternTritonGPUOpToLLVM.h +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/PatternTritonGPUOpToLLVM.h @@ -36,6 +36,11 @@ void populateElementwiseOpToLLVMPatterns( ModuleAxisInfoAnalysis &axisInfoAnalysis, int computeCapability, const TargetInfo &targetInfo, PatternBenefit benefit); +void populateUpcastMXFPToLLVMPatterns(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, + const TargetInfo &targetInfo, + PatternBenefit benefit); + void populateLoadStoreOpToLLVMPatterns(LLVMTypeConverter &typeConverter, const TargetInfo &targetInfo, RewritePatternSet &patterns, diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TritonGPUToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TritonGPUToLLVM.cpp index d25bb2cc4acc..21f5b706320d 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TritonGPUToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TritonGPUToLLVM.cpp @@ -172,6 +172,8 @@ struct ConvertTritonGPUToLLVM patterns, benefit); mlir::triton::populateMakeRangeOpToLLVMPattern(typeConverter, targetInfo, patterns, benefit); + mlir::triton::NVIDIA::populateUpcastMXFPToLLVMPatterns( + typeConverter, patterns, targetInfo, benefit); if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) return signalPassFailure(); diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/UpcastMXFPToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/UpcastMXFPToLLVM.cpp new file mode 100644 index 000000000000..aeca44bb46ce --- /dev/null +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/UpcastMXFPToLLVM.cpp @@ -0,0 +1,99 @@ +#include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/TypeUtilities.h" + +#include "PatternTritonGPUOpToLLVM.h" + +#include "mlir/IR/Value.h" +#include "mlir/IR/ValueRange.h" +#include "mlir/Transforms/DialectConversion.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include + +using namespace mlir; +using namespace mlir::triton; +using namespace mlir::triton::gpu; + +namespace { +class UpcastMXFPOpPattern : public ConvertOpToLLVMPattern { +private: + const TargetInfoBase &targetInfo; + +public: + UpcastMXFPOpPattern(LLVMTypeConverter &typeConverter, + const TargetInfoBase &targetInfo, PatternBenefit benefit) + : ConvertOpToLLVMPattern(typeConverter, benefit), + targetInfo(targetInfo) {} + + LogicalResult + matchAndRewrite(UpcastMXFPOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + auto loc = op.getLoc(); + auto tyX = cast(op->getOperandTypes()[0]); + auto operands = adaptor.getOperands(); + + auto xVals = unpackLLElements(loc, operands[0], rewriter); + auto scaleVals = unpackLLElements(loc, operands[1], rewriter); + + Value tid = tid_val(); + auto mod = op->getParentOfType(); + Value warpSize = + i32_val(triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod)); + Value warpId = udiv(tid, warpSize); + Value laneId = urem(tid, warpSize); + + auto scale = [&loc, &rewriter](Value v, Value s) -> Value { + // Split bf16x2 into 2 bf16, scale each of them, and pack them back + // TODO Is it true that the bfloats are always packed as bf16x2? 
+ auto bf16_0 = bitcast(trunc(i16_ty, v), bf16_ty); + auto bf16_1 = bitcast(trunc(i16_ty, lshr(v, i32_val(16))), bf16_ty); + auto scaleIsNan = icmp_eq(s, i8_val(0xff)); + auto scaleBf16 = bitcast(shl(zext(i16_ty, s), i16_val(7)), bf16_ty); + auto scaledBf16_0 = fmul(bf16_0, scaleBf16); + auto scaledBf16_1 = fmul(bf16_1, scaleBf16); + auto i16_0 = bitcast(scaledBf16_0, i16_ty); + auto i16_1 = bitcast(scaledBf16_1, i16_ty); + auto packed = + or_(zext(i32_ty, i16_0), shl(zext(i32_ty, i16_1), i32_val(16))); + // Account for NaN in the scale as per the mxfp specification + auto packed_nan = select(scaleIsNan, i32_val(0x7fff7fff), packed); + return packed_nan; + }; + + // Each thread owns elements of 4 mxfp vectors so we need 4 scales + // Letting c = tid / 4 * 2, we need the elements from threads c, c + 1, c + + // 16, c + 17 + auto c = mul(udiv(laneId, i32_val(4)), i32_val(2)); + std::array ci = {c, add(c, i32_val(1)), add(c, i32_val(16)), + add(c, i32_val(17))}; + for (auto [i, scaleVal] : llvm::enumerate(scaleVals)) { + auto si = std::array{ + targetInfo.shuffleIdx(rewriter, loc, scaleVal, ci[0]), + targetInfo.shuffleIdx(rewriter, loc, scaleVal, ci[1]), + targetInfo.shuffleIdx(rewriter, loc, scaleVal, ci[2]), + targetInfo.shuffleIdx(rewriter, loc, scaleVal, ci[3]), + }; + // si indices for the 16 elements in x + std::array siMap = {0, 0, 2, 2, 0, 0, 2, 2, + 1, 1, 3, 3, 1, 1, 3, 3}; + for (int j = 0; j < 16; ++j) { + xVals[16 * i + j] = scale(xVals[16 * i + j], si[siMap[j]]); + } + } + + Value result = + packLLElements(loc, getTypeConverter(), xVals, rewriter, op.getType()); + rewriter.replaceOp(op, result); + return success(); + } +}; +} // anonymous namespace + +void mlir::triton::NVIDIA::populateUpcastMXFPToLLVMPatterns( + LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + const TargetInfo &targetInfo, PatternBenefit benefit) { + patterns.add(typeConverter, targetInfo, benefit); +}
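
For reference, the math that tt.dot_scaled / tl.dot_scaled is expected to compute for the mxfp8 x fp8 case exercised in the new test can be sketched in plain PyTorch: one ue8m0 exponent (value 2 ** (e - 127), with 0xFF encoding NaN) scales each block of 32 contiguous K elements of the lhs before an ordinary fp32 matmul. The sketch below is illustrative only; the helper name, shapes, and dtypes are assumptions and are not part of this change.

import torch

def dot_scaled_reference(lhs_fp8, lhs_scale_u8, rhs_fp8):
    """Dense reference for a scaled dot with an mxfp8 lhs and a plain fp8 rhs.

    lhs_fp8:      (M, K) tensor of torch.float8_e4m3fn
    lhs_scale_u8: (M, K // 32) tensor of uint8, ue8m0 scales (one per 32-wide K block)
    rhs_fp8:      (K, N) tensor of torch.float8_e4m3fn
    """
    # ue8m0 is a bare exponent with bias 127: value = 2 ** (e - 127).
    # (0xFF encodes NaN per the MX spec; ignored in this sketch.)
    scale = torch.pow(2.0, lhs_scale_u8.to(torch.float32) - 127.0)
    # Broadcast one scale over each block of 32 contiguous K elements.
    scale = scale.repeat_interleave(32, dim=1)  # (M, K)
    a = lhs_fp8.to(torch.float32) * scale
    b = rhs_fp8.to(torch.float32)
    return a @ b  # fp32 accumulation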