Skip to content

Commit

Permalink
[BACKEND] Canonicalize ReshapeOp even if not allowing reorder (#5752)
Browse files Browse the repository at this point in the history
The `getAllowReorder` check was added in #2676, but the
canonicalizations are value-preserving, so the check is not required.

Specifically:
- `reshape(splat) -> splat`, order is irrelevant for splat
- `reshape(reshape) -> reshape`, a reshape essentially treats its input as
a flattened 1-D tensor, so the inner reshape has no effect on the result.
  • Loading branch information
peterbell10 authored Jan 30, 2025
1 parent c75c6b0 commit a5d90d0
Show file tree
Hide file tree
Showing 5 changed files with 66 additions and 23 deletions.
2 changes: 1 addition & 1 deletion include/triton/Dialect/Triton/IR/TritonOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -481,7 +481,7 @@ def TT_BroadcastOp : TT_Op<"broadcast", [Pure,

let assemblyFormat = "$src attr-dict `:` type($src) `->` type($result)";

let hasCanonicalizeMethod = 1;
let hasCanonicalizer = 1;
let hasFolder = 1;
let hasVerifier = 1;
}
Expand Down
5 changes: 5 additions & 0 deletions lib/Dialect/Triton/IR/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
# Run mlir-tablegen's rewriter generator over Canonicalize.td to produce
# TritonCanonicalize.inc, and expose that generation step as the
# TritonCanonicalizeIncGen target so libraries can depend on it.
set(LLVM_TARGET_DEFINITIONS Canonicalize.td)
mlir_tablegen(TritonCanonicalize.inc -gen-rewriters)
add_public_tablegen_target(TritonCanonicalizeIncGen)

add_triton_library(TritonIR
Dialect.cpp
Ops.cpp
Expand All @@ -7,6 +11,7 @@ add_triton_library(TritonIR

DEPENDS
TritonTableGen
TritonCanonicalizeIncGen

LINK_LIBS PUBLIC
MLIRIR
Expand Down
17 changes: 17 additions & 0 deletions lib/Dialect/Triton/IR/Canonicalize.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
#ifndef TT_PATTERNS
#define TT_PATTERNS

include "mlir/IR/PatternBase.td"
include "triton/Dialect/Triton/IR/TritonOps.td"

// broadcast(splat(x)) -> splat(x)
//
// Matches a TT_BroadcastOp whose operand is produced by a TT_SplatOp of $x
// and replaces the broadcast with a TT_SplatOp built directly from $x.
// NOTE(review): the result type of the new splat is presumably taken from the
// matched broadcast by the generated rewriter -- confirm against the DRR docs.
def BroadcastSplatPattern :
Pat<(TT_BroadcastOp (TT_SplatOp $x)),
(TT_SplatOp $x)>;

// broadcast(broadcast(x)) -> broadcast(x)
//
// Matches a TT_BroadcastOp whose operand is itself produced by a
// TT_BroadcastOp of $x and replaces the outer op with a single TT_BroadcastOp
// applied directly to $x.
// NOTE(review): the result type of the new broadcast is presumably taken from
// the matched outer broadcast by the generated rewriter -- confirm.
def BroadcastBroadcastPattern :
Pat<(TT_BroadcastOp (TT_BroadcastOp $x)),
(TT_BroadcastOp $x)>;

#endif
37 changes: 18 additions & 19 deletions lib/Dialect/Triton/IR/Ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/Triton/IR/Types.h"
#include "triton/Dialect/Triton/IR/Utility.h"
#include "triton/Tools/LinearLayout.h"
#include "llvm/Support/ErrorHandling.h"

namespace mlir {
Expand All @@ -33,6 +32,8 @@ void LoadOp::getEffects(
// enum attribute definitions
#include "triton/Dialect/Triton/IR/OpsEnums.cpp.inc"

#include "TritonCanonicalize.inc"

namespace mlir {
namespace triton {

Expand Down Expand Up @@ -648,23 +649,27 @@ OpFoldResult ExpandDimsOp::fold(FoldAdaptor adaptor) {
}

//-- ReshapeOp --
template <typename OpType>
LogicalResult canonicalizeViewOrBroadcast(OpType op,
PatternRewriter &rewriter) {
LogicalResult ReshapeOp::canonicalize(ReshapeOp op, PatternRewriter &rewriter) {
if (op.getEfficientLayout())
return failure();

auto definingOp = op.getSrc().getDefiningOp();
if (!definingOp) {
return failure();
}

// view(view) -> view
if (auto parentView = dyn_cast<OpType>(definingOp)) {
rewriter.replaceOpWithNewOp<OpType>(op, TypeRange({op.getType()}),
parentView->getOperands(),
parentView->getAttrs());
// reshape(reshape) -> reshape
if (auto parentReshape = dyn_cast<ReshapeOp>(definingOp)) {
// Allow reorder if either reshape allowed it
const bool allowReorder =
(op.getAllowReorder() || parentReshape.getAllowReorder());
rewriter.replaceOpWithNewOp<ReshapeOp>(op, op.getType(),
parentReshape.getSrc(), allowReorder,
op.getEfficientLayout());
return success();
}

// view(splat) -> splat
// reshape(splat) -> splat
if (auto splat = dyn_cast<SplatOp>(definingOp)) {
rewriter.replaceOpWithNewOp<SplatOp>(op, op.getType(), splat.getSrc());
return success();
Expand All @@ -673,12 +678,6 @@ LogicalResult canonicalizeViewOrBroadcast(OpType op,
return failure();
}

LogicalResult ReshapeOp::canonicalize(ReshapeOp op, PatternRewriter &rewriter) {
if (!op.getAllowReorder() || op.getEfficientLayout())
return failure();
return canonicalizeViewOrBroadcast(op, rewriter);
}

OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
if (getType() == getSrc().getType() && !getAllowReorder()) {
// no-op
Expand Down Expand Up @@ -763,9 +762,9 @@ LogicalResult FpToFpOp::verify() {
}

//-- BroadcastOp --
LogicalResult BroadcastOp::canonicalize(BroadcastOp op,
PatternRewriter &rewriter) {
return canonicalizeViewOrBroadcast(op, rewriter);
void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<BroadcastSplatPattern, BroadcastBroadcastPattern>(context);
}

OpFoldResult BroadcastOp::fold(FoldAdaptor adaptor) {
Expand Down
28 changes: 25 additions & 3 deletions test/Triton/combine.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -279,9 +279,8 @@ tt.func @test_canonicalize_expand_dims(%arg0: tensor<f32>, %arg1: tensor<1xf32>)
tt.return %ed, %bc2 : tensor<1x8xf32>, tensor<8x8xf32>
}


// CHECK-LABEL: @test_canonicalize_view
tt.func @test_canonicalize_view(%arg0: tensor<8xf32>, %arg1: tensor<f32>) -> (tensor<4x2xf32>, tensor<2x2x2xf32>, tensor<8xf32>) {
tt.func @test_canonicalize_view(%arg0: tensor<8xf32>, %arg1: tensor<f32>) -> (tensor<4x2xf32>, tensor<2x2x2xf32>, tensor<8xf32>, tensor<2x2x2xf32>) {
%view0 = tt.reshape %arg0 allow_reorder : tensor<8xf32> -> tensor<2x4xf32>
// CHECK: %{{.*}} = tt.reshape %arg0 allow_reorder : tensor<8xf32> -> tensor<4x2xf32>
%view1 = tt.reshape %view0 allow_reorder : tensor<2x4xf32> -> tensor<4x2xf32>
Expand All @@ -294,7 +293,30 @@ tt.func @test_canonicalize_view(%arg0: tensor<8xf32>, %arg1: tensor<f32>) -> (te
// CHECK: %{{.*}} = arith.addf %arg0, %arg0 : tensor<8xf32>
%add = arith.addf %view3, %arg0 : tensor<8xf32>

tt.return %view1, %view2, %add : tensor<4x2xf32>, tensor<2x2x2xf32>, tensor<8xf32>
// CHECK: %{{.*}} = tt.reshape %arg0 allow_reorder : tensor<8xf32> -> tensor<2x2x2xf32>
%reshape = tt.reshape %view0 : tensor<2x4xf32> -> tensor<2x2x2xf32>

tt.return %view1, %view2, %add, %reshape : tensor<4x2xf32>, tensor<2x2x2xf32>, tensor<8xf32>, tensor<2x2x2xf32>
}

// Exercises tt.reshape canonicalization where the inner reshape (%reshape0)
// does not carry allow_reorder. Each case below pairs an input op with the
// output line expected after canonicalization.
// CHECK-LABEL: @test_canonicalize_reshape
tt.func @test_canonicalize_reshape(%arg0: tensor<8xf32>, %arg1: tensor<f32>) -> (tensor<4x2xf32>, tensor<2x2x2xf32>, tensor<8xf32>, tensor<2x2x2xf32>) {
// Case 1: reshape(reshape(x)) without allow_reorder is expected to become a
// single reshape that reads %arg0 directly.
%reshape0 = tt.reshape %arg0 : tensor<8xf32> -> tensor<2x4xf32>
// CHECK: %{{.*}} = tt.reshape %arg0 : tensor<8xf32> -> tensor<4x2xf32>
%reshape1 = tt.reshape %reshape0 : tensor<2x4xf32> -> tensor<4x2xf32>

// Case 2: reshape(splat(x)) is expected to become a splat of %arg1 that
// produces the reshaped result type.
%splat = tt.splat %arg1 : tensor<f32> -> tensor<8xf32>
// CHECK: %{{.*}} = tt.splat %arg1 : tensor<f32> -> tensor<2x2x2xf32>
%reshape2 = tt.reshape %splat : tensor<8xf32> -> tensor<2x2x2xf32>

// Case 3: a reshape whose source and result types are identical is expected
// to disappear, so the add consumes %arg0 for both operands.
%reshape3 = tt.reshape %arg0 : tensor<8xf32> -> tensor<8xf32>
// CHECK: %{{.*}} = arith.addf %arg0, %arg0 : tensor<8xf32>
%add = arith.addf %reshape3, %arg0 : tensor<8xf32>

// Case 4: an allow_reorder reshape of a plain reshape is expected to merge
// into one reshape of %arg0 that keeps allow_reorder.
// CHECK: %{{.*}} = tt.reshape %arg0 allow_reorder : tensor<8xf32> -> tensor<2x2x2xf32>
%view = tt.reshape %reshape0 allow_reorder : tensor<2x4xf32> -> tensor<2x2x2xf32>

tt.return %reshape1, %reshape2, %add, %view : tensor<4x2xf32>, tensor<2x2x2xf32>, tensor<8xf32>, tensor<2x2x2xf32>
}

// CHECK-LABEL: @test_canonicalize_broadcast
Expand Down

0 comments on commit a5d90d0

Please sign in to comment.