[Linalg] Add conversion between bf16 and f16 #3963

Merged
merged 3 commits into from
Jan 17, 2025

AmosLewis
Collaborator

@AmosLewis AmosLewis commented Jan 16, 2025

Fixes issue #3962: 'arith.extf' op operand type 'bf16' and result type 'f16' are cast incompatible

@AmosLewis AmosLewis marked this pull request as ready for review January 16, 2025 21:35
@dan-garvey
Collaborator

If others are okay with it, we can merge this, but for our case I think we want to find out what causes this cast to be generated in the first place, because it is guaranteed to cost us in numerics. (The two datatypes differ widely in which values they can express; going from bf16 to fp16 is especially bad.)

Exponent: FP16 has a 5-bit exponent, while BF16 has an 8-bit exponent
Mantissa: FP16 has a 10-bit mantissa, while BF16 has a 7-bit mantissa
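
To make the trade-off concrete, here is a minimal sketch that derives the largest finite value and the spacing near 1.0 from the bit widths listed above. The helper names (max_finite, epsilon) are illustrative and not part of any library; the formulas assume an IEEE-style layout in which the all-ones exponent is reserved for inf/NaN.

```python
def max_finite(exp_bits: int, man_bits: int) -> float:
    """Largest finite value for an IEEE-style float with the given field widths."""
    bias = 2 ** (exp_bits - 1) - 1
    # The all-ones exponent encodes inf/NaN, so the largest usable
    # unbiased exponent is equal to the bias.
    return (2 - 2.0 ** -man_bits) * 2.0 ** bias


def epsilon(man_bits: int) -> float:
    """Gap between 1.0 and the next representable value."""
    return 2.0 ** -man_bits


FP16_MAX = max_finite(exp_bits=5, man_bits=10)  # 65504.0
BF16_MAX = max_finite(exp_bits=8, man_bits=7)   # roughly 3.39e38

print(f"fp16: max={FP16_MAX}, eps={epsilon(10)}")
print(f"bf16: max={BF16_MAX:.4e}, eps={epsilon(7)}")
```

Any bf16 value with magnitude above 65504 has no finite fp16 counterpart, while fp16 resolves values near 1.0 eight times more finely than bf16.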

Collaborator

@zjgarvey zjgarvey left a comment

I agree with Dan. arith.bitcast between f16 and bf16 is probably not going to work.

Please add an e2e test that would cover this as a start.

@zjgarvey
Collaborator

I'm not sure if this is desirable from a performance standpoint, but you can certainly get a correct conversion by doing arith.extf to f32 followed by arith.truncf to the other 16-bit float type.

Here are some examples you can compile and run to see the outputs:

#map = affine_map<(d0) -> (d0)>
module {
  func.func @convert(%arg0: tensor<1xbf16>) -> tensor<1xf16> {
    %0 = tensor.empty() : tensor<1xf16>
    %1 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%arg0 : tensor<1xbf16>) outs(%0 : tensor<1xf16>) {
    ^bb0(%in: bf16, %out: f16):
      %2 = arith.extf %in : bf16 to f32
      %3 = arith.truncf %2 : f32 to f16
      linalg.yield %3 : f16
    } -> tensor<1xf16>
    return %1 : tensor<1xf16>
  }
}

Running this with --input='1xbf16=1.0' returns

result[0]: hal.buffer_view
1xf16=1

But

#map = affine_map<(d0) -> (d0)>
module {
  func.func @convert(%arg0: tensor<1xbf16>) -> tensor<1xf16> {
    %0 = tensor.empty() : tensor<1xf16>
    %1 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%arg0 : tensor<1xbf16>) outs(%0 : tensor<1xf16>) {
    ^bb0(%in: bf16, %out: f16):
      %2 = arith.bitcast %in : bf16 to f16
      linalg.yield %2 : f16
    } -> tensor<1xf16>
    return %1 : tensor<1xf16>
  }
}

Yields

result[0]: hal.buffer_view
1xf16=1.875
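
The 1.875 result can be reproduced by hand. Below is a minimal sketch using only Python's standard struct module; the function names are illustrative, and the bf16 encoding is obtained by truncating the f32 bits, which is exact only for values bf16 can represent. The bit pattern of bf16 1.0 is 0x3F80; read as f16, the same bits decode to exponent field 15 and mantissa 0b1110000000, which is 1.875.

```python
import struct


def f32_to_bf16_bits(x: float) -> int:
    """bf16 bit pattern of x, taken as the top 16 bits of its f32 encoding (truncation)."""
    (bits32,) = struct.unpack(">I", struct.pack(">f", x))
    return bits32 >> 16


def bitcast_bf16_to_f16(bits: int) -> float:
    """Reinterpret the 16 bits as an IEEE half, mirroring what a bitcast does."""
    (value,) = struct.unpack(">e", struct.pack(">H", bits))
    return value


def convert_bf16_to_f16(bits: int) -> float:
    """Widen bf16 to f32, then round to f16, mirroring the extf + truncf route."""
    (as_f32,) = struct.unpack(">f", struct.pack(">I", bits << 16))
    # Packing with "e" rounds to half precision; Python raises OverflowError
    # for values beyond the f16 range.
    (as_f16,) = struct.unpack(">e", struct.pack(">e", as_f32))
    return as_f16


one = f32_to_bf16_bits(1.0)
print(hex(one))                  # 0x3f80
print(bitcast_bf16_to_f16(one))  # 1.875
print(convert_bf16_to_f16(one))  # 1.0
```

Only the value-preserving route returns 1.0; the bitcast keeps the bits and changes the value.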

@AmosLewis
Collaborator Author

AmosLewis commented Jan 17, 2025

@zjgarvey I tried to add an e2e test in torch-mlir, but it failed when lowering linalg to the RefBackend with the error: argument must be a memref of f32, f64, i32, i64, i8, i1, c32, c64, but got 'memref<1x1xbf16>'. This means bf16 cannot be used as an input type in the e2e tests.

*** RUNNING TEST: TensorBfloat16ToFloat16_basic ***
Compiling TensorBfloat16ToFloat16_basic...
Running TensorBfloat16ToFloat16_basic...
loc("/home/chi/src/torch-mlir/build/tools/torch-mlir/python_packages/torch_mlir/torch_mlir_e2e_test/test_suite/cast.py":165:0): error: argument must be a memref of f32, f64, i32, i64, i8, i1, c32, c64, but got 'memref<1x1xbf16>'
TORCH_VERSION_FOR_COMPARISON = 2.6.0.dev20241216
loc("/home/chi/src/torch-mlir/build/tools/torch-mlir/python_packages/torch_mlir/torch_mlir_e2e_test/test_suite/cast.py":165:0): error: argument must be a memref of f32, f64, i32, i64, i8, i1, c32, c64, but got 'memref<1x1xbf16>'
FAIL - "TensorBfloat16ToFloat16_basic"

Unexpected outcome summary: (fx_importer)

****** Failed tests - 1 tests
    FAIL - "TensorBfloat16ToFloat16_basic"
        Runtime error: Traceback (most recent call last):
          File "/home/chi/src/torch-mlir/build/tools/torch-mlir/python_packages/torch_mlir/torch_mlir_e2e_test/framework.py", line 354, in compile_and_run_test
            trace = config.run(compiled, golden_trace)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
          File "/home/chi/src/torch-mlir/build/tools/torch-mlir/python_packages/torch_mlir/torch_mlir_e2e_test/configs/fx_importer_backend.py", line 49, in run
            self._export_run(artifact, trace)
          File "/home/chi/src/torch-mlir/build/tools/torch-mlir/python_packages/torch_mlir/torch_mlir_e2e_test/configs/fx_importer_backend.py", line 132, in _export_run
            module = self._backend.compile(module)
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
          File "/home/chi/src/torch-mlir/build/tools/torch-mlir/python_packages/torch_mlir/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py", line 229, in compile
            run_pipeline_with_repro_report(
          File "/home/chi/src/torch-mlir/build/tools/torch-mlir/python_packages/torch_mlir/torch_mlir/compiler_utils.py", line 127, in run_pipeline_with_repro_report
            raise TorchMlirCompilerError(trimmed_message) from None
        torch_mlir.compiler_utils.TorchMlirCompilerError: Lowering Linalg-on-Tensors IR to LLVM with RefBackend failed with the following diagnostics:


        python exception: Failure while executing pass pipeline

The test case:

class TensorBfloat16ToFloat16(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args(
        [
            None,
            ([-1, -1], torch.bfloat16, True),
        ]
    )
    def forward(self, x):
        return x.to(torch.float16)


@register_test_case(module_factory=lambda: TensorBfloat16ToFloat16())
def TensorBfloat16ToFloat16_basic(module, tu: TestUtils):
    module.forward(torch.tensor([[1]], dtype=torch.bfloat16))

The generated linalg IR:

#map = affine_map<(d0, d1) -> (0, 0)>
#map1 = affine_map<(d0, d1) -> (d0, d1)>
module {
  func.func @TensorBfloat16ToFloat16(%arg0: tensor<1x1xbf16> loc("/home/chi/src/torch-mlir/build/tools/torch-mlir/python_packages/torch_mlir/torch_mlir_e2e_test/test_suite/cast.py":165:0)) -> tensor<1x1xf16> {
    %0 = tensor.empty() : tensor<1x1xf16> loc(#loc1)
    %1 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel"]} ins(%arg0 : tensor<1x1xbf16>) outs(%0 : tensor<1x1xf16>) {
    ^bb0(%in: bf16 loc("/home/chi/src/torch-mlir/build/tools/torch-mlir/python_packages/torch_mlir/torch_mlir_e2e_test/test_suite/cast.py":165:0), %out: f16 loc("/home/chi/src/torch-mlir/build/tools/torch-mlir/python_packages/torch_mlir/torch_mlir_e2e_test/test_suite/cast.py":165:0)):
      %2 = arith.bitcast %in : bf16 to f16 loc(#loc1)
      linalg.yield %2 : f16 loc(#loc1)
    } -> tensor<1x1xf16> loc(#loc1)
    return %1 : tensor<1x1xf16> loc(#loc1)
  } loc(#loc1)
} loc(#loc)
#loc = loc(unknown)

I also tried to add an onnx.cast test with bf16 to SHARK-TestSuite, but it also failed, with:
Failed test at stage construct_inputs with exception: Numpy doesn't support bfloat16. Please consider modifying the boundary types. nod-ai/SHARK-TestSuite#430

@AmosLewis AmosLewis force-pushed the castf16 branch 2 times, most recently from a1e685d to d3fe084 January 17, 2025 01:23
@AmosLewis AmosLewis requested a review from zjgarvey January 17, 2025 01:23
@AmosLewis AmosLewis changed the title from "[Linalg] Add arith::bitcast between same width data type convert" to "[Linalg] Add convertion between bf16 and f16" Jan 17, 2025
Collaborator

@zjgarvey zjgarvey left a comment

Hey, Chi. I think a small change in the lit test would be preferable, as would renaming the PR to fix the spelling of "conversion". Otherwise this looks good to me.

test/Conversion/TorchToLinalg/elementwise.mlir (outdated; resolved)
@AmosLewis AmosLewis changed the title from "[Linalg] Add convertion between bf16 and f16" to "[Linalg] Add conversion between bf16 and f16" Jan 17, 2025
@AmosLewis AmosLewis requested a review from zjgarvey January 17, 2025 21:48
Collaborator

@zjgarvey zjgarvey left a comment

Thanks, Chi! LGTM

@zjgarvey zjgarvey merged commit f42c7e4 into llvm:main Jan 17, 2025
3 checks passed
@AmosLewis AmosLewis deleted the castf16 branch January 20, 2025 05:25