Skip to content

Commit

Permalink
Add bf16 to f16 lit test
Browse files Browse the repository at this point in the history
  • Loading branch information
AmosLewis committed Jan 17, 2025
1 parent d3fe084 commit 3a1153a
Showing 1 changed file with 34 additions and 0 deletions.
34 changes: 34 additions & 0 deletions test/Conversion/TorchToLinalg/elementwise.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -102,3 +102,37 @@ func.func @elementwise_sinh(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3
%0 = torch.aten.sinh %arg0 : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32>
return %0 : !torch.vtensor<[3],f32>
}

// -----

// CHECK-LABEL: func.func @elementwise_todtype_bf162f16(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,?,32,128],bf16>) -> !torch.vtensor<[1,?,32,128],f16> {
// CHECK: %[[INPUT:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,?,32,128],bf16> -> tensor<1x?x32x128xbf16>
// CHECK: %[[INT5:.*]] = torch.constant.int 5
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[CONSTANT1_1:.*]] = arith.constant 1 : index
// CHECK: %[[CONSTANT1:.*]] = arith.constant 1 : index
// CHECK: %[[DIM:.*]] = tensor.dim %[[INPUT]], %[[CONSTANT1]] : tensor<1x?x32x128xbf16>
// CHECK: %[[CONSTANT_2:.*]] = arith.constant 2 : index
// CHECK: %[[CONSTANT_32:.*]] = arith.constant 32 : index
// CHECK: %[[CONSTANT_3:.*]] = arith.constant 3 : index
// CHECK: %[[CONSTANT_128:.*]] = arith.constant 128 : index
// CHECK: %[[EMPTY:.*]] = tensor.empty(%[[DIM]]) : tensor<1x?x32x128xf16>
// CHECK: %[[GENERIC:.*]] = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[INPUT]] : tensor<1x?x32x128xbf16>) outs(%[[EMPTY]] : tensor<1x?x32x128xf16>) {
// CHECK: ^bb0(%[[LHS:.*]]: bf16, %[[RHS:.*]]: f16):
// CHECK: %[[EXTF:.*]] = arith.extf %[[LHS]] : bf16 to f32
// CHECK: %[[TRUNCF:.*]] = arith.truncf %[[EXTF]] : f32 to f16
// CHECK: linalg.yield %[[TRUNCF]] : f16
// CHECK: } -> tensor<1x?x32x128xf16>
// CHECK: %[[CAST:.*]] = tensor.cast %[[GENERIC]] : tensor<1x?x32x128xf16> to tensor<1x?x32x128xf16>
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[CAST]] : tensor<1x?x32x128xf16> -> !torch.vtensor<[1,?,32,128],f16>
// CHECK: return %[[RESULT]] : !torch.vtensor<[1,?,32,128],f16>
// CHECK: }
func.func @elementwise_todtype_bf162f16(%arg0: !torch.vtensor<[1,?,32,128],bf16>) -> !torch.vtensor<[1,?,32,128],f16> {
%int5 = torch.constant.int 5
%false = torch.constant.bool false
%none = torch.constant.none
%0 = torch.aten.to.dtype %arg0, %int5, %false, %false, %none : !torch.vtensor<[1,?,32,128],bf16>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,?,32,128],f16>
return %0 : !torch.vtensor<[1,?,32,128],f16>
}

0 comments on commit 3a1153a

Please sign in to comment.