diff --git a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp index 46827326abd0..c0396621bdea 100644 --- a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp +++ b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp @@ -225,28 +225,6 @@ static bool bwdFilter(Operation *op) { mlir::TypeID::get()); } -static SmallVector getTransposeOrder(int rank) { - assert(rank >= 2); - auto transOrder = llvm::to_vector<2>(llvm::seq(rank - 2)); - transOrder.push_back(rank - 1); - transOrder.push_back(rank - 2); - return transOrder; -} - -static DotOp transposeDotOp(PatternRewriter &rewriter, DotOp dotOp) { - auto rank = dotOp.getResult().getType().getRank(); - Value a = dotOp.getA(); - Value b = dotOp.getB(); - Value c = dotOp.getC(); - auto transOrder = getTransposeOrder(rank); - a = rewriter.create(a.getLoc(), a, transOrder); - b = rewriter.create(b.getLoc(), b, transOrder); - c = rewriter.create(c.getLoc(), c, transOrder); - return rewriter.create(dotOp.getLoc(), c.getType(), b, a, c, - dotOp.getInputPrecision(), - dotOp.getMaxNumImpreciseAcc()); -} - // Finds the first different bitwidth in the chain of shape-preserving // unary ops that x depends on. // There are two primary scenarios: @@ -336,7 +314,6 @@ class BlockedToMMA : public mlir::OpRewritePattern { bool aFromLoad = comesFromLoadOrBlockArg(dotOp.getA()); bool bFromLoad = comesFromLoadOrBlockArg(dotOp.getB()); - bool transpose = false; auto origDotOp = dotOp; Value a = dotOp.getA(); @@ -402,12 +379,6 @@ class BlockedToMMA : public mlir::OpRewritePattern { dotOp.getInputPrecision(), dotOp.getMaxNumImpreciseAcc()); } - if (transpose) { - auto rank = dotOp.getResult().getType().getRank(); - auto transOrder = getTransposeOrder(rank); - newDot = rewriter.create(newDot->getLoc(), newDot->getResult(0), - transOrder); - } // convert dot instruction rewriter.replaceOpWithNewOp(origDotOp, origDotOp.getType(), newDot->getResult(0));