
Incorrect Numerics for a f32 Depthwise Conv Op #18600

Closed
zjgarvey opened this issue Sep 26, 2024 · 3 comments · Fixed by #19356
Labels
bug 🐞 Something isn't working

Comments

@zjgarvey
Contributor

What happened?

When compiled through IREE for CPU, the following depthwise convolution op (imported from an ONNX model) produces outputs that differ substantially from those of onnxruntime's CPU implementation.

module {
  func.func @main(%arg0: !torch.vtensor<[1,256,112,112],f32>, %arg1: !torch.vtensor<[256,1,3,3],f32>, %arg2: !torch.vtensor<[256],f32>) -> !torch.vtensor<[1,256,56,56],f32> attributes {torch.onnx_meta.ir_version = 10 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.producer_name = "", torch.onnx_meta.producer_version = ""} {
    %none = torch.constant.none
    %0 = torch.operator "onnx.Conv"(%arg0, %arg1, %arg2) {torch.onnx.group = 256 : si64, torch.onnx.kernel_shape = [3 : si64, 3 : si64], torch.onnx.pads = [1 : si64, 1 : si64, 1 : si64, 1 : si64], torch.onnx.strides = [2 : si64, 2 : si64]} : (!torch.vtensor<[1,256,112,112],f32>, !torch.vtensor<[256,1,3,3],f32>, !torch.vtensor<[256],f32>) -> !torch.vtensor<[1,256,56,56],f32> 
    return %0 : !torch.vtensor<[1,256,56,56],f32>
  }
}
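An independent gold reference for the op above can make mismatches easier to pin down. The following is a minimal pure-Python sketch of an ONNX-style depthwise Conv (group equal to the channel count) for batch size 1, assuming symmetric zero padding and equal strides in H and W. The function names and nested-list layout are hypothetical; this is not onnxruntime's or IREE's implementation, and it is only practical for small shapes.

```python
from typing import List

Tensor3 = List[List[List[float]]]  # [C][H][W] for a single batch element (N = 1)


def conv_out_size(in_size: int, kernel: int, pad_lo: int, pad_hi: int,
                  stride: int, dilation: int = 1) -> int:
    """Spatial output size of a convolution, using floor division."""
    effective_kernel = dilation * (kernel - 1) + 1
    return (in_size + pad_lo + pad_hi - effective_kernel) // stride + 1


def depthwise_conv2d_nchw(x: Tensor3, w: Tensor3, bias: List[float],
                          pad: int, stride: int) -> Tensor3:
    """Depthwise conv (group == channels) with symmetric zero padding.

    x: [C][H][W]; w: [C][KH][KW], i.e. the [C,1,KH,KW] ONNX weight with its
    size-1 dim dropped; bias: [C].
    """
    channels, height, width = len(x), len(x[0]), len(x[0][0])
    kh, kw = len(w[0]), len(w[0][0])
    out_h = conv_out_size(height, kh, pad, pad, stride)
    out_w = conv_out_size(width, kw, pad, pad, stride)
    out: Tensor3 = []
    for c in range(channels):
        plane = []
        for oh in range(out_h):
            row = []
            for ow in range(out_w):
                acc = bias[c]  # start the accumulator at the channel's bias
                for i in range(kh):
                    for j in range(kw):
                        ih = oh * stride + i - pad
                        iw = ow * stride + j - pad
                        # Positions outside the input read the zero padding.
                        if 0 <= ih < height and 0 <= iw < width:
                            acc += x[c][ih][iw] * w[c][i][j]
                row.append(acc)
            plane.append(row)
        out.append(plane)
    return out
```

For the op in this issue, `conv_out_size(112, 3, 1, 1, 2)` is 56, matching the `[1,256,56,56]` result type.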

On MI300, compilation succeeds but iree-run-module crashes.

Steps to reproduce your issue

Set up the SHARK Test Suite for a local IREE build (Linux):

Copy and paste the following to set up the test suite. You may need to change the Python executable name (python3.11) to whatever your system or IREE build uses.

git clone https://github.com/nod-ai/SHARK-TestSuite.git
cd SHARK-TestSuite/alt_e2eshark/
python3.11 -m venv ts.venv
source ts.venv/bin/activate
pip install --upgrade pip
pip install -r base_requirements.txt
pip install --no-deps -r torch_mlir_requirements.txt

Then set the following to the path of your IREE build (if it was built with Python bindings):

IREE_BUILD_DIR=<replace with path to iree-build> && \
source ${IREE_BUILD_DIR}/.env && export PYTHONPATH

If you do not have iree-compile and iree-run-module on your path, add them.

Run the test:

python run.py -t conv_depthwise -v -m cl-onnx-iree

Inspect the results:

The run.py script should generate a sub-directory ./test-run/conv_depthwise_stride_2/. With the mode cl-onnx-iree, this should also generate a /commands/ directory with compile and run-module commands. Inspect inference_comparison.log to see input, output, and gold output printouts.

Test on GPU

python run.py -t conv_depthwise -v -m cl-onnx-iree -b rocm -d hip -ica "iree-hip-target=gfx942"

This fails in iree-run-module (the stage called "compiled_inference").

What component(s) does this issue relate to?

Frontends, MLIR, Compiler

Version information

Local build at commit ae6e5d3

Additional context

This affects a few models, e.g. "maxvit_rmlp_base_rw_224.sw_in12k".

@zjgarvey
Contributor Author

If it helps, here is the corresponding linalg IR:

module {
  ml_program.global private mutable @global_seed(dense<0> : tensor<i64>) : tensor<i64>
  func.func @main(%arg0: tensor<1x256x112x112xf32>, %arg1: tensor<256x1x3x3xf32>, %arg2: tensor<256xf32>) -> tensor<1x256x56x56xf32> {
    %cst = arith.constant 0.000000e+00 : f32
    %padded = tensor.pad %arg0 low[0, 0, 1, 1] high[0, 0, 1, 1] {
    ^bb0(%arg3: index, %arg4: index, %arg5: index, %arg6: index):
      tensor.yield %cst : f32
    } : tensor<1x256x112x112xf32> to tensor<1x256x114x114xf32>
    %0 = tensor.empty() : tensor<1x256x56x56xf32>
    %broadcasted = linalg.broadcast ins(%arg2 : tensor<256xf32>) outs(%0 : tensor<1x256x56x56xf32>) dimensions = [0, 2, 3] 
    %collapsed = tensor.collapse_shape %arg1 [[0, 1], [2], [3]] : tensor<256x1x3x3xf32> into tensor<256x3x3xf32>
    %1 = linalg.depthwise_conv_2d_nchw_chw {dilations = dense<1> : vector<2xi64>, strides = dense<2> : vector<2xi64>} ins(%padded, %collapsed : tensor<1x256x114x114xf32>, tensor<256x3x3xf32>) outs(%broadcasted : tensor<1x256x56x56xf32>) -> tensor<1x256x56x56xf32>
    return %1 : tensor<1x256x56x56xf32>
  }
}
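In this linalg form, the bias enters only through `outs`: `linalg.broadcast` fills the output with `bias[c]`, and the named conv op accumulates into it. A minimal 1-D pure-Python sketch (helper names are hypothetical, not IREE code) of why that is equivalent to convolving into a zero-filled tensor and adding the bias afterwards, and why skipping the add-back changes the result:

```python
import math
from typing import List

Matrix = List[List[float]]  # [C][positions]: one channel per row


def depthwise_conv1d_into(x: Matrix, w: Matrix, init: Matrix, stride: int) -> Matrix:
    """Accumulate a depthwise 1-D conv into a copy of `init`; `x` is already padded."""
    out = [row[:] for row in init]
    for c, row in enumerate(out):
        for o in range(len(row)):
            for k, wk in enumerate(w[c]):
                row[o] += x[c][o * stride + k] * wk
    return out


def broadcast_bias(bias: List[float], length: int) -> Matrix:
    """1-D analogue of the linalg.broadcast above: repeat bias[c] along each row."""
    return [[b] * length for b in bias]


def allclose(a: Matrix, b: Matrix, tol: float = 1e-6) -> bool:
    return all(
        math.isclose(p, q, rel_tol=tol, abs_tol=tol)
        for ra, rb in zip(a, b)
        for p, q in zip(ra, rb)
    )


# Example: 2 channels, input of length 3 padded by 1 on each side, kernel 3, stride 2.
x = [[0.0, 1.0, 2.0, 3.0, 0.0], [0.0, -1.0, 4.0, 2.0, 0.0]]
w = [[1.0, 2.0, 1.0], [0.5, 0.5, 0.5]]
bias = [10.0, -3.0]
out_len = (len(x[0]) - len(w[0])) // 2 + 1  # = 2

bias_as_init = depthwise_conv1d_into(x, w, broadcast_bias(bias, out_len), stride=2)
zero_init = depthwise_conv1d_into(x, w, broadcast_bias([0.0, 0.0], out_len), stride=2)
zero_then_add = [[v + b for v in row] for row, b in zip(zero_init, bias)]

print(allclose(bias_as_init, zero_then_add))  # True: the two formulations agree
print(allclose(bias_as_init, zero_init))      # False: dropping the add-back loses the bias
```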

@AmosLewis
Contributor

The same numeric failure occurs for a conv op in the convnext_nano.d1h_in1k model (nod-ai/SHARK-TestSuite#403).
onnx.mlir

module {
  func.func @main(%arg0: !torch.vtensor<[1,80,72,72],f32>, %arg1: !torch.vtensor<[80,1,7,7],f32>, %arg2: !torch.vtensor<[80],f32>) -> !torch.vtensor<[1,80,72,72],f32> attributes {torch.onnx_meta.ir_version = 10 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.producer_name = "", torch.onnx_meta.producer_version = ""} {
    %none = torch.constant.none
    %0 = torch.operator "onnx.Conv"(%arg0, %arg1, %arg2) {torch.onnx.dilations = [1 : si64, 1 : si64], torch.onnx.group = 80 : si64, torch.onnx.kernel_shape = [7 : si64, 7 : si64], torch.onnx.pads = [3 : si64, 3 : si64, 3 : si64, 3 : si64], torch.onnx.strides = [1 : si64, 1 : si64]} : (!torch.vtensor<[1,80,72,72],f32>, !torch.vtensor<[80,1,7,7],f32>, !torch.vtensor<[80],f32>) -> !torch.vtensor<[1,80,72,72],f32> 
    return %0 : !torch.vtensor<[1,80,72,72],f32>
  }
}

linalg.mlir

module {
  ml_program.global private mutable @global_seed(dense<0> : tensor<i64>) : tensor<i64>
  func.func @main(%arg0: tensor<1x80x72x72xf32>, %arg1: tensor<80x1x7x7xf32>, %arg2: tensor<80xf32>) -> tensor<1x80x72x72xf32> {
    %cst = arith.constant 0.000000e+00 : f32
    %padded = tensor.pad %arg0 low[0, 0, 3, 3] high[0, 0, 3, 3] {
    ^bb0(%arg3: index, %arg4: index, %arg5: index, %arg6: index):
      tensor.yield %cst : f32
    } : tensor<1x80x72x72xf32> to tensor<1x80x78x78xf32>
    %0 = tensor.empty() : tensor<1x80x72x72xf32>
    %broadcasted = linalg.broadcast ins(%arg2 : tensor<80xf32>) outs(%0 : tensor<1x80x72x72xf32>) dimensions = [0, 2, 3] 
    %collapsed = tensor.collapse_shape %arg1 [[0, 1], [2], [3]] : tensor<80x1x7x7xf32> into tensor<80x7x7xf32>
    %1 = linalg.depthwise_conv_2d_nchw_chw {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%padded, %collapsed : tensor<1x80x78x78xf32>, tensor<80x7x7xf32>) outs(%broadcasted : tensor<1x80x72x72xf32>) -> tensor<1x80x72x72xf32>
    return %1 : tensor<1x80x72x72xf32>
  }
}

@zjgarvey
Contributor Author

zjgarvey commented Dec 3, 2024

I found the source of the numeric failure. A pass in IREE replaces the bias with a zero-initialized tensor and then reports a match failure before it can add the bias back to the result. Will update shortly.
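If the bias is replaced by a zero init and never added back, the symptom should be visible directly in the outputs: the gold output minus the IREE output equals the per-channel bias at every position. A hedged pure-Python check, assuming outputs for a single batch element with the spatial dims flattened to `[C][H*W]` (the helper name is hypothetical):

```python
import math
from typing import Sequence


def bias_dropped(actual: Sequence[Sequence[float]], expected: Sequence[Sequence[float]],
                 bias: Sequence[float], rel_tol: float = 1e-5, abs_tol: float = 1e-5) -> bool:
    """True if expected - actual == bias[c] at every position of every channel c."""
    if not (len(actual) == len(expected) == len(bias)):
        raise ValueError("actual, expected and bias must have the same channel count")
    return all(
        math.isclose(e - a, bias[c], rel_tol=rel_tol, abs_tol=abs_tol)
        for c, (a_row, e_row) in enumerate(zip(actual, expected))
        for a, e in zip(a_row, e_row)
    )
```

Note that an all-zero bias makes this check trivially true, so it is only informative when the bias is nonzero.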

zjgarvey added a commit that referenced this issue Dec 5, 2024
…ps pass (#19356)

This moves match failure checks before modifying linalg ops, and loosens
the check for identity map access to the output tensor.

### Context:

Specific depthwise convolution ops were encountering numeric failures.
See <#18600> and
<#19339>. I noticed that the bias
was not affecting the output values, and tracked down where the bias was
getting deleted.

The issue is that the pass `DetachElementwiseFromNamedOps` was
modifying the `depthwise_conv` op to use a zero-fill *before* checking
for some match failures. This resulted in a partial application of the
pattern where the original bias did not get added back to the modified
linalg op result.
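The ordering problem can be illustrated with a toy model. This is a hedged Python sketch, not IREE's C++ pattern: `ConvOp` and both rewrite functions are hypothetical, and the only precondition modeled is the output-map check. MLIR's pattern-rewriting contract expects a pattern that returns failure to leave the IR unmodified, which the first ordering violates:

```python
from dataclasses import dataclass


@dataclass
class ConvOp:
    """Toy stand-in for a linalg conv op: only the fields this sketch needs."""
    outs: str                     # SSA name of the accumulator/init tensor
    output_map_is_identity: bool  # precondition checked by the pattern


def rewrite_mutate_then_check(op: ConvOp) -> bool:
    """Buggy ordering: rewrites the op in place, then may report failure."""
    op.outs = "%zero_fill"        # bias init replaced by a zero fill
    if not op.output_map_is_identity:
        return False              # failure reported, but op.outs is NOT restored
    # On success the real pattern would also emit an add of the original bias.
    return True


def rewrite_check_then_mutate(op: ConvOp) -> bool:
    """Fixed ordering: all match checks happen before any IR is modified."""
    if not op.output_map_is_identity:
        return False              # op left untouched, bias still in outs
    op.outs = "%zero_fill"
    return True
```

With a non-identity output map, the first version leaves the op accumulating into `%zero_fill` with no bias add, which is exactly the partial application described above.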

Specifically, the depthwise conv ops failed the pattern's check for an
identity indexing map on the output tensor access.

For example:

```mlir
module {
  ml_program.global private mutable @global_seed(dense<0> : tensor<i64>) : tensor<i64>
  func.func @torch_jit(%arg0: tensor<1x96x56x56xf32>, %arg1: tensor<96x1x7x7xf32>, %arg2: tensor<96xf32>) -> tensor<1x96x56x56xf32> {
    %cst = arith.constant 0.000000e+00 : f32
    %padded = tensor.pad %arg0 low[0, 0, 3, 3] high[0, 0, 3, 3] {
    ^bb0(%arg3: index, %arg4: index, %arg5: index, %arg6: index):
      tensor.yield %cst : f32
    } : tensor<1x96x56x56xf32> to tensor<1x96x62x62xf32>
    %0 = tensor.empty() : tensor<1x96x56x56xf32>
    %broadcasted = linalg.broadcast ins(%arg2 : tensor<96xf32>) outs(%0 : tensor<1x96x56x56xf32>) dimensions = [0, 2, 3] 
    %collapsed = tensor.collapse_shape %arg1 [[0, 1], [2], [3]] : tensor<96x1x7x7xf32> into tensor<96x7x7xf32>
    %1 = linalg.depthwise_conv_2d_nchw_chw {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%padded, %collapsed : tensor<1x96x62x62xf32>, tensor<96x7x7xf32>) outs(%broadcasted : tensor<1x96x56x56xf32>) -> tensor<1x96x56x56xf32>
    return %1 : tensor<1x96x56x56xf32>
  }
}
```

generalizes to

```mlir
#map = affine_map<(d0, d1, d2, d3) -> (d1)>
#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
#map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d3, d1 + d4, d2 + d5)>
#map3 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5)>
#map4 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d3, d1, d2)>
module {
  ml_program.global private mutable @global_seed(dense<0> : tensor<i64>) : tensor<i64>
  func.func @torch_jit(%arg0: tensor<1x96x56x56xf32>, %arg1: tensor<96x1x7x7xf32>, %arg2: tensor<96xf32>) -> tensor<1x96x56x56xf32> {
    %cst = arith.constant 0.000000e+00 : f32
    %padded = tensor.pad %arg0 low[0, 0, 3, 3] high[0, 0, 3, 3] {
    ^bb0(%arg3: index, %arg4: index, %arg5: index, %arg6: index):
      tensor.yield %cst : f32
    } : tensor<1x96x56x56xf32> to tensor<1x96x62x62xf32>
    %0 = tensor.empty() : tensor<1x96x56x56xf32>
    %1 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<96xf32>) outs(%0 : tensor<1x96x56x56xf32>) {
    ^bb0(%in: f32, %out: f32):
      linalg.yield %in : f32
    } -> tensor<1x96x56x56xf32>
    %collapsed = tensor.collapse_shape %arg1 [[0, 1], [2], [3]] : tensor<96x1x7x7xf32> into tensor<96x7x7xf32>
    %2 = linalg.generic {indexing_maps = [#map2, #map3, #map4], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]} ins(%padded, %collapsed : tensor<1x96x62x62xf32>, tensor<96x7x7xf32>) outs(%1 : tensor<1x96x56x56xf32>) {
    ^bb0(%in: f32, %in_0: f32, %out: f32):
      %3 = arith.mulf %in, %in_0 : f32
      %4 = arith.addf %out, %3 : f32
      linalg.yield %4 : f32
    } -> tensor<1x96x56x56xf32>
    return %2 : tensor<1x96x56x56xf32>
  }
}
```
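To make the indexing maps concrete, here is a hedged pure-Python evaluation of the generic op above with stride 1 and dilation 1 (as in this IR). It iterates the six loop dims and applies `#map2`, `#map3` and `#map4` literally; the function name and nested-list layout are assumptions for illustration.

```python
import itertools
from typing import List

Tensor4 = List[List[List[List[float]]]]  # [N][C][H][W]
Tensor3 = List[List[List[float]]]        # [C][KH][KW]


def run_depthwise_generic(padded: Tensor4, weights: Tensor3, init: Tensor4) -> Tensor4:
    """Evaluate the generalized linalg.generic above (stride 1, dilation 1).

    Loop dims: d0=n, d1=oh, d2=ow, d3=c (parallel), d4=kh, d5=kw (reduction).
      #map2 (input):  (d0, d3, d1 + d4, d2 + d5)
      #map3 (filter): (d3, d4, d5)
      #map4 (output): (d0, d3, d1, d2)
    """
    n, c = len(init), len(init[0])
    out_h, out_w = len(init[0][0]), len(init[0][0][0])
    kh, kw = len(weights[0]), len(weights[0][0])
    # Copy the init tensor: it plays the role of outs(%1), i.e. the broadcast bias.
    out = [[[row[:] for row in plane] for plane in batch] for batch in init]
    for d0, d1, d2, d3, d4, d5 in itertools.product(
        range(n), range(out_h), range(out_w), range(c), range(kh), range(kw)
    ):
        # Region body: %out + %in * %in_0
        out[d0][d3][d1][d2] += padded[d0][d3][d1 + d4][d2 + d5] * weights[d3][d4][d5]
    return out
```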

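For some reason, the channel dim `d3` appears after the spatial dims
(`d1` and `d2`) in the loop order of this particular op, so its output
map `(d0, d3, d1, d2)` is a permutation of the parallel dims rather than
the identity.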
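The effect of that loop order on the pattern's precondition can be sketched with a toy encoding of the output map as the list of loop-dim positions it returns (`#map4` becomes `[0, 3, 1, 2]`). Both predicates are assumptions about what a strict check and a loosened check might look like, modeled loosely on MLIR's `AffineMap::isProjectedPermutation`; they are not IREE's actual code.

```python
from typing import Sequence

# Hypothetical encoding: an affine map (d0, ..., d5) -> (d0, d3, d1, d2)
# is represented by the loop-dim positions of its results: [0, 3, 1, 2].


def is_leading_identity(results: Sequence[int]) -> bool:
    """Strict check (assumed): the output indexes the first k loop dims, in order."""
    return list(results) == list(range(len(results)))


def is_projected_permutation(results: Sequence[int], num_dims: int) -> bool:
    """Looser check (assumed): each result is a distinct loop dim, in any order."""
    return len(set(results)) == len(results) and all(0 <= r < num_dims for r in results)


MAP4_DEPTHWISE = [0, 3, 1, 2]    # #map4 from the generalized depthwise conv
NCHW_CONV_OUTPUT = [0, 1, 2, 3]  # e.g. a conv_2d_nchw_fchw-style output map (d0, d1, d2, d3)
```

Under this toy model, the depthwise output map fails the strict check but passes the looser one, while an in-order output map passes both.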

---------

Signed-off-by: zjgarvey