
Support accumulating GEMMs in TileAndFuse with intrinsic without needing c promotion #19546

Open
nirvedhmeshram opened this issue Dec 20, 2024 · 9 comments

Comments

@nirvedhmeshram
Contributor

nirvedhmeshram commented Dec 20, 2024

Currently, for accumulating GEMMs, we fail to bufferize if we don't do C promotion in the TileAndFuse pipeline when using intrinsics. See dump here. I know there are some transforms that are still in development, but I wasn't sure they would cover this case as well.

@nirvedhmeshram nirvedhmeshram changed the title Support accumulating GEMMs in TileAndFuse without needing c promotion Support accumulating GEMMs in TileAndFuse with intrinssic without needing c promotion Dec 20, 2024
@nirvedhmeshram
Contributor Author

nirvedhmeshram commented Dec 27, 2024

Here is what causes this to fail to bufferize. After GPUFuseAndHoistParallelLoopsPass,
we have the following access pattern:

%read_write_input = flow.dispatch.tensor.load ... -> tensor<32x16x32x16xi32>
%workgroup_scf_forall = scf.forall ... shared_outs(%arg2 = %read_write_input) -> (tensor<32x16x32x16xi32>) {
  %11 = tensor.empty() : tensor<4x16x4x16xi32>
  %subgroup_scf_forall = scf.forall ... shared_outs(%arg5 = %11) -> (tensor<4x16x4x16xi32>) {
    %extracted_slice = tensor.extract_slice %arg2
    %extracted_slice_0 = tensor.extract_slice %arg5
    %thread_scf_forall = scf.forall ... shared_outs(%arg7 = %extracted_slice_0) -> (tensor<2x16x2x16xi32>) {
      %acc_input = tensor.extract_slice %extracted_slice
      %some_intrinsic_compute (..., %acc_input)
      %acc_output = tensor.extract_slice %arg7

The point is that the accumulator input slice is derived from %read_write_input, but the accumulator output is written back to a slice derived from an empty tensor.

Later, after EliminateEmptyTensorsPass,
we are left with almost the same access pattern, except that we have

  %extracted_slice = tensor.extract_slice %arg2
  %subgroup_scf_forall = scf.forall ... shared_outs(%arg5 = %extracted_slice)  -> (tensor<4x16x4x16xi32>) {

At this stage, %acc_input is a slice of %arg2 while %acc_output is written to %arg7, and both are slices of %read_write_input. OneShotAnalysis perceives this in hasReadAfterWriteInterference as a RaW conflict and introduces a copy, from which we can't recover. As a hack, I bypassed the logic in the analysis to report no conflict; things then work out, with a few unnecessary subviews that I assume lower-level codegen will take care of, and I verified that I got correct numerics with that.
So I see the following possible solutions, though I'm not sure any of them is a good one:

  1. During GPUFuseAndHoistParallelLoopsPass, add pattern(s) so that we directly pass %arg2 to %arg5 and make %acc_input take %arg7
  2. Same as 1, but after or during EliminateEmptyTensorsPass
  3. Improve the hasReadAfterWriteInterference logic so that it understands this case is not a RaW conflict
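The intuition behind option 3 can be illustrated with a toy model. This is a hypothetical sketch, not how hasReadAfterWriteInterference actually works: it only captures the idea that each thread reads its accumulator from the same tile it writes, so the aliasing between the %arg2-derived read and the %arg7-derived write need not be a cross-thread hazard. Region, ThreadAccess, the two check functions, and makeTiling are illustrative names, and the 2x2 thread grid over a 4x4 tile only loosely mirrors the 4x16x4x16 subgroup and 2x16x2x16 thread tiles above.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Toy model (hypothetical): a rectangular region of the accumulator buffer,
// measured in intrinsic tiles.
struct Region {
  int row, col, rows, cols;
};

// True if the two regions share at least one element.
bool overlaps(const Region &a, const Region &b) {
  return a.row < b.row + b.rows && b.row < a.row + a.rows &&
         a.col < b.col + b.cols && b.col < a.col + a.cols;
}

// What one thread of the innermost scf.forall touches.
struct ThreadAccess {
  Region read;   // slice read as the accumulator input (%acc_input)
  Region write;  // slice written as the accumulator output (through %arg7)
};

// Conservative check: any read overlapping any write counts as a RaW
// conflict, including a thread's read-modify-write of its own tile.
bool conservativeConflict(const std::vector<ThreadAccess> &threads) {
  for (const ThreadAccess &r : threads)
    for (const ThreadAccess &w : threads)
      if (overlaps(r.read, w.write))
        return true;
  return false;
}

// Refined check: a thread may update its own tile in place; only reading a
// tile written by a *different* parallel thread is a conflict.
bool crossThreadConflict(const std::vector<ThreadAccess> &threads) {
  for (std::size_t i = 0; i < threads.size(); ++i)
    for (std::size_t j = 0; j < threads.size(); ++j)
      if (i != j && overlaps(threads[i].read, threads[j].write))
        return true;
  return false;
}

// 2x2 threads, each writing a 2x2 tile of a 4x4 tile. With readShift == 0,
// every thread reads its accumulator from exactly the tile it writes.
std::vector<ThreadAccess> makeTiling(int readShift) {
  std::vector<ThreadAccess> threads;
  for (int ty = 0; ty < 2; ++ty)
    for (int tx = 0; tx < 2; ++tx) {
      Region write{2 * ty, 2 * tx, 2, 2};
      Region read{2 * ty, 2 * tx + readShift, 2, 2};
      threads.push_back({read, write});
    }
  assert(threads.size() == 4);
  return threads;
}
```

With matching read and write tiles, the conservative check reports a conflict while the cross-thread check does not; shifting the read by one tile column makes a thread read its neighbor's output, which is a real hazard.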

@nirvedhmeshram nirvedhmeshram changed the title Support accumulating GEMMs in TileAndFuse with intrinssic without needing c promotion Support accumulating GEMMs in TileAndFuse with intrinsic without needing c promotion Dec 27, 2024
@hanhanW
Contributor

hanhanW commented Jan 6, 2025

I spot an issue in your dump. The issue that I spot in the dump is that the output binding is ReadOnly. This is because your input program writes the result to the function argument. IREE is not smart enough to create a global buffer for the output tensor, and maybe it should not happen; I can't interpret the meaning of writing the result into an input argument. It is a tensor, not a pointer.

  func.func @bmm(%arg0: tensor<512x128xi8>, %arg1: tensor<512x128xi8>, %arg2: tensor<512x512xi32>) -> tensor<512x512xi32> {
    %0 = linalg.matmul_transpose_b ins(%arg0, %arg1 : tensor<512x128xi8>, tensor<512x128xi8>) outs(%arg2 : tensor<512x512xi32>) -> tensor<512x512xi32>
    return %0 : tensor<512x512xi32>
  }

I think it is better to have a dump with tensor.empty variant. E.g.,

  func.func @bmm(%arg0: tensor<512x128xi8>, %arg1: tensor<512x128xi8>) -> tensor<512x512xi32> {
    %arg2 = tensor.empty() : tensor<512x512xi32>
    %0 = linalg.matmul_transpose_b ins(%arg0, %arg1 : tensor<512x128xi8>, tensor<512x128xi8>) outs(%arg2 : tensor<512x512xi32>) -> tensor<512x512xi32>
    return %0 : tensor<512x512xi32>
  }

It is clearer because we explicitly ask IREE to create a global buffer for the tensor and output the result at the end. I did not run the example myself because I don't know what the compilation command is.

@hanhanW
Copy link
Contributor

hanhanW commented Jan 6, 2025

If the issue comes from tests/e2e/matmul/, we should probably just fix the generated input programs to the tensor.empty form.

@hanhanW
Copy link
Contributor

hanhanW commented Jan 6, 2025

The other solution might be running something like a RemoveArgOutsDependency pattern at the global level, similar to the RemoveCstOutsDependency pattern in ConvertToDestinationPassingStylePass:

struct RemoveCstOutsDependency
    : public OpInterfaceRewritePattern<linalg::LinalgOp> {
  using OpInterfaceRewritePattern<linalg::LinalgOp>::OpInterfaceRewritePattern;
  LogicalResult matchAndRewrite(linalg::LinalgOp op,
                                PatternRewriter &rewriter) const override {
    rewriter.startOpModification(op);
    bool modifiedOutput = false;
    Location loc = op.getLoc();
    // Look for DPS init (outs) operands that are splat constants.
    for (OpOperand &opOperand : op.getDpsInitsMutable()) {
      ElementsAttr attr;
      if (!matchPattern(opOperand.get(), m_Constant(&attr)))
        continue;
      if (!attr.isSplat())
        continue;
      auto type = llvm::dyn_cast<RankedTensorType>(attr.getType());
      if (!type)
        continue;
      TypedAttr scalarAttr = attr.getValues<TypedAttr>()[0];
      modifiedOutput = true;
      // Rebuild the splat init as tensor.empty + linalg.fill, so the op's
      // destination is no longer a constant.
      Value emptyTensor = rewriter.create<tensor::EmptyOp>(
          loc, type.getShape(), type.getElementType());
      Value cstOp = rewriter.create<arith::ConstantOp>(loc, scalarAttr);
      Value fillOp =
          rewriter.create<linalg::FillOp>(loc, cstOp, emptyTensor).result();
      op->setOperand(opOperand.getOperandNumber(), fillOp);
    }
    // Nothing matched: roll back the in-place modification and fail.
    if (!modifiedOutput) {
      rewriter.cancelOpModification(op);
      return failure();
    }
    rewriter.finalizeOpModification(op);
    return success();
  }
};

@nirvedhmeshram
Contributor Author

@MaheshRavishankar WDYT of the two suggestions from @hanhanW above? Based on our previous conversations, we want to support accumulating GEMMs without transforming them into a non-accumulating GEMM + elementwise add. That said, to make progress on the TileAndFuse side, where we have the codegen issues described above, we decided we could go down the elementwise-add path for that pipeline only until we have a proper solution, so that's what I tried in this PR
#19587

@qedawkins
Contributor

My memory of the IR we looked at was that conversion to DPS was not the issue; the issue was a RaW conflict arising because folders kicked in on the read but not on the writes.

I think it is better to have a dump with tensor.empty variant. E.g.,

func.func @bmm(%arg0: tensor<512x128xi8>, %arg1: tensor<512x128xi8>) -> tensor<512x512xi32> {
  %arg2 = tensor.empty() : tensor<512x512xi32>
  %0 = linalg.matmul_transpose_b ins(%arg0, %arg1 : tensor<512x128xi8>, tensor<512x128xi8>) outs(%arg2 : tensor<512x512xi32>) -> tensor<512x512xi32>
  return %0 : tensor<512x512xi32>
}

linalg.matmul_transpose_b accumulates into its destination (so it does both a read and a "write"), so the output of this program is just some math on uninitialized memory (i.e., undefined).

The issue that I spot in the dump is that the output binding is ReadOnly

I see the readwrite binding in the dump

%2 = hal.interface.binding.subspan layout(<bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(2) alignment(64) offset(%c0) flags(Indirect) : !flow.dispatch.tensor<readwrite:tensor<512x512xi32>>

Maybe the binding layout looks confusing because it lists the layouts of all bindings rather than just the one being accessed. In this case, the last entry, #hal.pipeline.binding<storage_buffer, Indirect>, is the one that represents the access permissions for binding(2), and it looks right to me.
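To make the indexing explicit, here is a small sketch of reading such a layout: the flag list covers every binding, and the subspan's own permissions are the entry at its binding(N) index. isReadOnlyBinding is a hypothetical helper operating on plain strings, not an IREE API.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical helper: layoutFlags holds the flags of *every* binding in the
// layout, in order; the access of a given subspan is the entry at its
// binding(N) index, not the first entry that happens to be listed.
bool isReadOnlyBinding(const std::vector<std::string> &layoutFlags,
                       std::size_t bindingIndex) {
  assert(bindingIndex < layoutFlags.size());
  return layoutFlags[bindingIndex].find("ReadOnly") != std::string::npos;
}
```

For the subspan in the dump, the flags are "ReadOnly|Indirect", "ReadOnly|Indirect", "Indirect", and binding(2) selects the last one, which is not ReadOnly.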

@hanhanW
Contributor

hanhanW commented Jan 7, 2025

Ah, I see. I did not notice that it lists all the layouts.

@MaheshRavishankar
Contributor

I spot an issue in your dump. The issue that I spot in the dump is that the output binding is ReadOnly. This is because your input program writes the result to the function argument. IREE is not smart enough to create a global buffer for the output tensor, and maybe it should not happen; I can't interpret the meaning of writing the result into an input argument. It is a tensor, not a pointer.

func.func @bmm(%arg0: tensor<512x128xi8>, %arg1: tensor<512x128xi8>, %arg2: tensor<512x512xi32>) -> tensor<512x512xi32> {
  %0 = linalg.matmul_transpose_b ins(%arg0, %arg1 : tensor<512x128xi8>, tensor<512x128xi8>) outs(%arg2 : tensor<512x512xi32>) -> tensor<512x512xi32>
  return %0 : tensor<512x512xi32>
}

This is a perfectly valid program and should be supported correctly in IREE (I think it is today). The semantics are that the initial value of the matmul_transpose_b is passed in as an input, and we compute D = C + A * B and return D. This program is the same as

func.func @bmm(%arg0: memref<512x128xi8>, %arg1: memref<512x128xi8>, %arg2: memref<512x512xi32>) {
  linalg.matmul_transpose_b ins(%arg0, %arg1 : memref<512x128xi8>, memref<512x128xi8>) outs(%arg2 : memref<512x512xi32>)
  return
}
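The D = C + A * B reading can be made concrete with a plain reference loop. This is a sketch, not IREE code: matmulTransposeBAcc is a hypothetical helper, it assumes row-major buffers and the default signed extension of the i8 operands to i32, and it does not model i32 wraparound. The only difference between this program and the tensor.empty variant is where C comes from; with tensor.empty, C is undefined.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Reference semantics of an accumulating linalg.matmul_transpose_b on
// row-major buffers: D[m][n] = C[m][n] + sum_k A[m][k] * B[n][k].
// A is MxK, B is NxK (hence "transpose_b"), C and D are MxN.
std::vector<int32_t> matmulTransposeBAcc(const std::vector<int8_t> &A,
                                         const std::vector<int8_t> &B,
                                         const std::vector<int32_t> &C,
                                         int M, int N, int K) {
  assert(A.size() == static_cast<std::size_t>(M) * K);
  assert(B.size() == static_cast<std::size_t>(N) * K);
  assert(C.size() == static_cast<std::size_t>(M) * N);
  std::vector<int32_t> D = C;  // the initial value comes from the outs operand
  for (int m = 0; m < M; ++m)
    for (int n = 0; n < N; ++n) {
      int32_t acc = D[m * N + n];
      for (int k = 0; k < K; ++k)
        // Sign-extend both i8 operands to i32 before multiplying.
        acc += static_cast<int32_t>(A[m * K + k]) *
               static_cast<int32_t>(B[n * K + k]);
      D[m * N + n] = acc;
    }
  return D;
}
```

For example, with 2x2 operands A = {1, 2, 3, 4}, B = {5, 6, 7, 8} and C = {10, 20, 30, 40}, the result is {27, 43, 69, 93}; only the products (17, 23, 39, 53) come from A and B, the rest is the incoming accumulator.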

I think it is better to have a dump with tensor.empty variant. E.g.,

func.func @bmm(%arg0: tensor<512x128xi8>, %arg1: tensor<512x128xi8>) -> tensor<512x512xi32> {
  %arg2 = tensor.empty() : tensor<512x512xi32>
  %0 = linalg.matmul_transpose_b ins(%arg0, %arg1 : tensor<512x128xi8>, tensor<512x128xi8>) outs(%arg2 : tensor<512x512xi32>) -> tensor<512x512xi32>
  return %0 : tensor<512x512xi32>
}
It is clearer because we explicitly ask IREE to create a global buffer for the tensor and output the result at the end. I did not run the example myself because I don't know what the compilation command is.

Yeah, like Quinn said, this is just going to produce garbage output (there is a flag in IREE that forces zero-initialization of undefined tensors like this, which just happens to make it work).

@MaheshRavishankar WDYT of the two suggestions from @hanhanW above? Based on our previous conversations, we want to support accumulating GEMMs without transforming them into a non-accumulating GEMM + elementwise add. That said, to make progress on the TileAndFuse side, where we have the codegen issues described above, we decided we could go down the elementwise-add path for that pipeline only until we have a proper solution, so that's what I tried in this PR #19587

I don't think those work. This is really a pipeline issue and needs a local fix in the pipeline itself. We shouldn't try to work around it at the full-program level, because those workarounds will be hard to control. Converting GEMM to fill + GEMM + accumulate is kind of a hack, but it's something for us to fix in the long run. For now, this is fine.

nirvedhmeshram added a commit that referenced this issue Jan 7, 2025
Converts dispatches with accumulating GEMMs that are doing in place
read/write to GEMM + elementwise add.
This is needed for the TileAndFuse path until we find a more permanent
fix for #19546

---------

Signed-off-by: Nirvedh Meshram <[email protected]>
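The commit above rewrites an in-place accumulating GEMM into a GEMM on a fresh, zero-filled accumulator followed by an elementwise add of the original C. A minimal sketch of why the two forms agree for integer accumulators, under stated assumptions: gemmZeroInit, gemmThenAdd and gemmInPlace are hypothetical helpers on row-major buffers, and i32 wraparound is not modeled.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Non-accumulating matmul_transpose_b into a zero-filled accumulator (the
// role a linalg.fill of 0 plays). A is MxK, B is NxK, the result is MxN.
std::vector<int32_t> gemmZeroInit(const std::vector<int8_t> &A,
                                  const std::vector<int8_t> &B,
                                  int M, int N, int K) {
  assert(A.size() == static_cast<std::size_t>(M) * K);
  assert(B.size() == static_cast<std::size_t>(N) * K);
  std::vector<int32_t> D(static_cast<std::size_t>(M) * N, 0);
  for (int m = 0; m < M; ++m)
    for (int n = 0; n < N; ++n)
      for (int k = 0; k < K; ++k)
        D[m * N + n] += static_cast<int32_t>(A[m * K + k]) *
                        static_cast<int32_t>(B[n * K + k]);
  return D;
}

// Workaround shape: GEMM into a zero-filled buffer, then add the original
// accumulator C elementwise.
std::vector<int32_t> gemmThenAdd(const std::vector<int8_t> &A,
                                 const std::vector<int8_t> &B,
                                 const std::vector<int32_t> &C,
                                 int M, int N, int K) {
  std::vector<int32_t> D = gemmZeroInit(A, B, M, N, K);
  assert(D.size() == C.size());
  for (std::size_t i = 0; i < D.size(); ++i)
    D[i] += C[i];
  return D;
}

// Original shape: accumulate directly on top of C (in-place read/write).
std::vector<int32_t> gemmInPlace(const std::vector<int8_t> &A,
                                 const std::vector<int8_t> &B,
                                 const std::vector<int32_t> &C,
                                 int M, int N, int K) {
  assert(C.size() == static_cast<std::size_t>(M) * N);
  std::vector<int32_t> D = C;
  for (int m = 0; m < M; ++m)
    for (int n = 0; n < N; ++n)
      for (int k = 0; k < K; ++k)
        D[m * N + n] += static_cast<int32_t>(A[m * K + k]) *
                        static_cast<int32_t>(B[n * K + k]);
  return D;
}
```

Integer addition is associative, so adding C before or after the reduction yields the same values; the cost is an extra accumulator buffer and a separate elementwise pass, which is why the thread treats this as a stopgap rather than the proper fix.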
@hanhanW
Contributor

hanhanW commented Jan 8, 2025

This is a perfectly valid program and should be supported correctly in IREE (I think it is today). The semantics are that the initial value of the matmul_transpose_b is passed in as an input, and we compute D = C + A * B and return D. This program is the same as

func.func @bmm(%arg0: memref<512x128xi8>, %arg1: memref<512x128xi8>, %arg2: memref<512x512xi32>) {
  linalg.matmul_transpose_b ins(%arg0, %arg1 : memref<512x128xi8>, memref<512x128xi8>) outs(%arg2 : memref<512x512xi32>)
  return
}

I see what you meant. It makes sense to me, thanks!

Labels
None yet
Projects
None yet
Development

No branches or pull requests

4 participants