[DispatchCreation] Drop fusion restriction for stride != 1 conv (#19634)
There was a restriction on aggressive fusion that prevented fusion with
non-unit-stride convs due to lack of support in the backend. The backend
now supports such cases, so we can re-enable those fusions.
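
For illustration, here is a minimal sketch of the pattern the restriction used to block: a stride-2 linalg.conv_2d_nhwc_hwcf followed by an elementwise consumer. The function name, shapes, and the linalg.abs consumer are hypothetical and not part of this change; the dropped isVectorizedAlways check previously returned false for the non-unit-stride conv, keeping the two ops in separate dispatches under aggressive fusion.

util.func public @conv_stride_2_consumer(%input : tensor<1x17x17x8xf32>,
    %filter : tensor<3x3x8x8xf32>) -> tensor<1x8x8x8xf32> {
  %cst = arith.constant 0.0 : f32
  %empty = tensor.empty() : tensor<1x8x8x8xf32>
  %fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<1x8x8x8xf32>) -> tensor<1x8x8x8xf32>
  // Non-unit strides: this conv used to be excluded from aggressive fusion.
  %conv = linalg.conv_2d_nhwc_hwcf
      {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>}
      ins(%input, %filter : tensor<1x17x17x8xf32>, tensor<3x3x8x8xf32>)
      outs(%fill : tensor<1x8x8x8xf32>) -> tensor<1x8x8x8xf32>
  // Elementwise consumer that can now be fused into the conv's dispatch.
  %abs = linalg.abs ins(%conv : tensor<1x8x8x8xf32>)
      outs(%empty : tensor<1x8x8x8xf32>) -> tensor<1x8x8x8xf32>
  util.return %abs : tensor<1x8x8x8xf32>
}

The updated @mixed_conv_stride_2 test below checks the same behavior on a real conv + truncf generic pair, expecting a single flow.dispatch.workgroups op.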
qedawkins authored Jan 8, 2025
1 parent c75b686 commit 9b4906e
Showing 2 changed files with 6 additions and 24 deletions.
@@ -425,21 +425,6 @@ static bool canUseInOperandAsInitOperand(OpOperand *inOperand,
   return true;
 }
 
-/// All operations in a dispatch should be vectorized, which isnt the case today
-/// This is an explicit list of operations that arent vectorized for now
-/// requiring special handling for now in dispatch region formation to avoid
-/// large stack allocations.
-static bool isVectorizedAlways(Operation *producer) {
-  // TODO(#17155) : This is a black list of operations that are not vectorized
-  // today (under the aggressive fusion flag). Remove this blacklist to return
-  // true always.
-  if (auto convOp = dyn_cast<linalg::Conv2DNhwcHwcfOp>(producer)) {
-    auto strides = convOp.getStrides();
-    return strides.isSplat() && strides.getSplatValue<int64_t>() == 1;
-  }
-  return true;
-}
-
 /// Returns true if this is a fusable use, while fusing a root with its
 /// consumer.
 static bool
@@ -554,9 +539,7 @@ isFusableWithConsumer(OpOperand &fusedOperand,
   // Under aggressive fusion assume that the dispatches are vectorized. In which
   // case we dont need to account for the subsequent stack allocation condition.
   if (options.aggressiveFusion) {
-    if (isVectorizedAlways(producer)) {
-      return true;
-    }
+    return true;
   }
 
   // While fusing with consumer, the result of the root might not be the final
@@ -2268,7 +2268,7 @@ util.func @mixed_conv(%arg0 : tensor<2x130x130x16xf16>, %arg1 : tensor<3x3x16x32
 
 // -----
 
-util.func @mixed_conv_unsupported(%arg0 : tensor<2x130x130x320xf16>, %arg1 : tensor<3x3x320x320xf16>) -> tensor<2x64x64x320xf16> {
+util.func @mixed_conv_stride_2(%arg0 : tensor<2x130x130x320xf16>, %arg1 : tensor<3x3x320x320xf16>) -> tensor<2x64x64x320xf16> {
   %empty = tensor.empty() : tensor<2x64x64x320xf32>
   %cst = arith.constant 0.0 : f32
   %fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<2x64x64x320xf32>) -> tensor<2x64x64x320xf32>
@@ -2288,16 +2288,15 @@ util.func @mixed_conv_unsupported(%arg0 : tensor<2x130x130x320xf16>, %arg1 : ten
   } -> tensor<2x64x64x320xf16>
   util.return %truncf : tensor<2x64x64x320xf16>
 }
-// CHECK-LABEL: func public @mixed_conv_unsupported(
-//       CHECK:   %[[DISPATCH0:.+]] = flow.dispatch.workgroups
+// CHECK-LABEL: func public @mixed_conv_stride_2(
+//       CHECK:   %[[DISPATCH:.+]] = flow.dispatch.workgroups
 //       CHECK:     %[[FILL:.+]] = linalg.fill
 //       CHECK:     %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf
 //  CHECK-SAME:       outs(%[[FILL]] :
-//       CHECK:     flow.dispatch.tensor.store %[[CONV]]
-//       CHECK:   %[[DISPATCH1:.+]] = flow.dispatch.workgroups
 //       CHECK:     %[[GENERIC:.+]] = linalg.generic
 //  CHECK-SAME:       ins(%[[CONV]]
 //       CHECK:     flow.dispatch.tensor.store %[[GENERIC]]
-//       CHECK:   util.return %[[DISPATCH1]]
+//       CHECK:   util.return %[[DISPATCH]]
 
 // -----
