Skip to content

Commit

Permalink
Improve getFinalValue and remove no-longer-needed folding helper functions
Browse files Browse the repository at this point in the history

Signed-off-by: Tiotto, Ettore <[email protected]>
  • Loading branch information
etiotto committed Jan 27, 2025
1 parent c77d178 commit d9fdda8
Show file tree
Hide file tree
Showing 7 changed files with 209 additions and 260 deletions.
21 changes: 10 additions & 11 deletions test/Triton/Intel/RaiseToBlockPointers/addptr_mul_value_const.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -38,16 +38,15 @@ module {
// CHECK-DAG: [[CST_1_i64:%.+]] = arith.constant 1 : i64
// CHECK-DAG: [[CST_0_i64:%.+]] = arith.constant 0 : i64
// CHECK-DAG: [[VAR_0_:%.+]] = tt.get_program_id x : i32
// CHECK-DAG: [[VAR_1_:%.+]] = arith.index_cast [[PARAM_2_]] : i32 to index
// CHECK: [[VAR_2_:%.+]] = arith.muli [[PARAM_2_]], [[CST_2048_i32]] : i32
// CHECK-DAG: [[VAR_3_:%.+]] = arith.index_cast [[VAR_1_]] : index to i64
// CHECK-DAG: [[VAR_4_:%.+]] = arith.addi [[VAR_0_]], [[VAR_2_]] : i32
// CHECK-DAG: [[VAR_5_:%.+]] = arith.addi [[VAR_3_]], [[CST_1_i64]] : i64
// CHECK-DAG: [[VAR_6_:%.+]] = arith.trunci [[VAR_5_]] : i64 to i32
// CHECK-DAG: [[VAR_7_:%.+]] = arith.divui [[VAR_4_]], [[VAR_6_]] : i32
// CHECK-DAG: [[VAR_8_:%.+]] = tt.make_tensor_ptr [[PARAM_0_]], {{\[}}[[CST_0_i64]]], {{\[}}[[VAR_5_]]], {{\[}}[[VAR_7_]]] {{.*}} : <tensor<1024xbf16>>
// CHECK-DAG: [[VAR_9_:%.+]] = tt.make_tensor_ptr [[PARAM_1_]], {{\[}}[[CST_0_i64]]], {{\[}}[[CST_1_i64]]], {{\[}}[[VAR_0_]]] {{.*}} : <tensor<1024xbf16>>
// CHECK: [[VAR_10_:%.+]] = tt.load [[VAR_8_]] : !tt.ptr<tensor<1024xbf16>>
// CHECK: tt.store [[VAR_9_]], [[VAR_10_]] : !tt.ptr<tensor<1024xbf16>>
// CHECK: [[VAR_1_:%.+]] = arith.muli [[PARAM_2_]], [[CST_2048_i32]] : i32
// CHECK-DAG: [[VAR_2_:%.+]] = arith.extsi [[PARAM_2_]] : i32 to i64
// CHECK-DAG: [[VAR_3_:%.+]] = arith.addi [[VAR_0_]], [[VAR_1_]] : i32
// CHECK-DAG: [[VAR_4_:%.+]] = arith.addi [[VAR_2_]], [[CST_1_i64]] : i64
// CHECK-DAG: [[VAR_5_:%.+]] = arith.trunci [[VAR_4_]] : i64 to i32
// CHECK-DAG: [[VAR_6_:%.+]] = arith.divui [[VAR_3_]], [[VAR_5_]] : i32
// CHECK-DAG: [[VAR_7_:%.+]] = tt.make_tensor_ptr [[PARAM_0_]], {{\[}}[[CST_0_i64]]], {{\[}}[[VAR_4_]]], {{\[}}[[VAR_6_]]] {{.*}} : <tensor<1024xbf16>>
// CHECK-DAG: [[VAR_8_:%.+]] = tt.make_tensor_ptr [[PARAM_1_]], {{\[}}[[CST_0_i64]]], {{\[}}[[CST_1_i64]]], {{\[}}[[VAR_0_]]] {{.*}} : <tensor<1024xbf16>>
// CHECK: [[VAR_9_:%.+]] = tt.load [[VAR_7_]] : !tt.ptr<tensor<1024xbf16>>
// CHECK: tt.store [[VAR_8_]], [[VAR_9_]] : !tt.ptr<tensor<1024xbf16>>
// CHECK: tt.return
// CHECK: }
Original file line number Diff line number Diff line change
Expand Up @@ -106,24 +106,28 @@ module {
// CHECK-DAG: [[CST_64_i32:%.+]] = arith.constant 64 : i32
// CHECK-DAG: [[CST_0_i32:%.+]] = arith.constant 0 : i32
// CHECK-DAG: [[CST_1_i32:%.+]] = arith.constant 1 : i32
// CHECK: [[VAR_28_:%.+]] = tt.make_tensor_ptr [[PARAM_0_]], {{\[}}[[CST_0_i64]], [[CST_0_i64]]], {{\[}}{{.*}}], {{\[}}{{.*}}] {{.*}} : <tensor<128x64xbf16>>
// CHECK: [[VAR_31_:%.+]] = arith.index_cast [[PARAM_9_]] : i32 to index
// CHECK: [[VAR_32_:%.+]] = arith.index_cast [[VAR_31_]] : index to i64
// CHECK: [[VAR_38_:%.+]] = tt.make_tensor_ptr [[PARAM_1_]], {{\[}}[[CST_0_i64]], [[CST_0_i64]]], {{\[}}{{.*}}], {{\[}}{{.*}}] {{.*}} : <tensor<64x256xbf16>>
// CHECK-DAG: [[VAR_39_:%.+]] = arith.muli [[PARAM_7_]], [[CST_64_i32]] : i32
// CHECK-DAG: [[VAR_40_:%.+]] = arith.muli [[PARAM_8_]], [[CST_64_i32]] : i32
// CHECK: [[VAR_41_:%.+]]:3 = scf.for {{.*}} iter_args([[VAR_arg13_:%.+]] = [[VAR_cst_]], [[VAR_arg14_:%.+]] = [[VAR_28_]], [[VAR_arg15_:%.+]] = [[VAR_38_]]) -> (tensor<128x256xf32>, !tt.ptr<tensor<128x64xbf16>>, !tt.ptr<tensor<64x256xbf16>>) : i32 {
// CHECK-DAG: [[VAR_54_:%.+]] = tt.load [[VAR_arg14_]] : !tt.ptr<tensor<128x64xbf16>>
// CHECK-DAG: [[VAR_55_:%.+]] = tt.load [[VAR_arg15_]] : !tt.ptr<tensor<64x256xbf16>>
// CHECK: [[VAR_56_:%.+]] = tt.dot [[VAR_54_]], [[VAR_55_]], [[VAR_cst_]], inputPrecision = tf32 : tensor<128x64xbf16> * tensor<64x256xbf16> -> tensor<128x256xf32>
// CHECK-DAG: [[VAR_57_:%.+]] = arith.addf [[VAR_arg13_]], [[VAR_56_]] : tensor<128x256xf32>
// CHECK-DAG: [[VAR_58_:%.+]] = tt.advance [[VAR_arg14_]], {{\[}}[[CST_0_i32]], [[VAR_39_]]] : <tensor<128x64xbf16>>
// CHECK-DAG: [[VAR_59_:%.+]] = tt.advance [[VAR_arg15_]], {{\[}}[[CST_0_i32]], [[VAR_40_]]] : <tensor<64x256xbf16>>
// CHECK: scf.yield [[VAR_57_]], [[VAR_58_]], [[VAR_59_]] : tensor<128x256xf32>, !tt.ptr<tensor<128x64xbf16>>, !tt.ptr<tensor<64x256xbf16>>
// CHECK: [[VAR_20_:%.+]] = arith.extsi [[PARAM_6_]] : i32 to i64
// CHECK: [[VAR_21_:%.+]] = arith.extsi [[PARAM_7_]] : i32 to i64
// CHECK: [[VAR_22_:%.+]] = arith.divui {{.*}}, [[PARAM_6_]] : i32
// CHECK: [[VAR_23_:%.+]] = tt.make_tensor_ptr [[PARAM_0_]], {{\[}}[[CST_0_i64]], [[CST_0_i64]]], {{\[}}[[VAR_20_]], [[VAR_21_]]], {{\[}}[[VAR_22_]], [[CST_0_i32]]] {{.*}} : <tensor<128x64xbf16>>
// CHECK: [[VAR_24_:%.+]] = arith.extsi [[PARAM_8_]] : i32 to i64
// CHECK: [[VAR_25_:%.+]] = arith.muli {{.*}}, [[PARAM_9_]] : i32
// CHECK: [[VAR_26_:%.+]] = arith.extsi [[PARAM_9_]] : i32 to i64
// CHECK: [[VAR_27_:%.+]] = arith.divui [[VAR_25_]], [[PARAM_9_]] : i32
// CHECK: [[VAR_28_:%.+]] = tt.make_tensor_ptr [[PARAM_1_]], {{\[}}[[CST_0_i64]], [[CST_0_i64]]], {{\[}}[[VAR_24_]], [[VAR_26_]]], {{\[}}[[CST_0_i32]], [[VAR_27_]]] {{.*}} : <tensor<64x256xbf16>>
// CHECK-DAG: [[VAR_29_:%.+]] = arith.muli [[PARAM_7_]], [[CST_64_i32]] : i32
// CHECK-DAG: [[VAR_30_:%.+]] = arith.muli [[PARAM_8_]], [[CST_64_i32]] : i32
// CHECK: [[VAR_31_:%.+]]:3 = scf.for {{.*}} iter_args([[VAR_arg13_:%.+]] = [[VAR_cst_]], [[VAR_arg14_:%.+]] = [[VAR_23_]], [[VAR_arg15_:%.+]] = [[VAR_28_]]) -> (tensor<128x256xf32>, !tt.ptr<tensor<128x64xbf16>>, !tt.ptr<tensor<64x256xbf16>>) : i32 {
// CHECK-DAG: [[VAR_40_:%.+]] = tt.load [[VAR_arg14_]] : !tt.ptr<tensor<128x64xbf16>>
// CHECK-DAG: [[VAR_41_:%.+]] = tt.load [[VAR_arg15_]] : !tt.ptr<tensor<64x256xbf16>>
// CHECK: [[VAR_42_:%.+]] = tt.dot [[VAR_40_]], [[VAR_41_]], [[VAR_cst_]], inputPrecision = tf32 : tensor<128x64xbf16> * tensor<64x256xbf16> -> tensor<128x256xf32>
// CHECK-DAG: [[VAR_43_:%.+]] = arith.addf [[VAR_arg13_]], [[VAR_42_]] : tensor<128x256xf32>
// CHECK-DAG: [[VAR_44_:%.+]] = tt.advance [[VAR_arg14_]], {{\[}}[[CST_0_i32]], [[VAR_29_]]] : <tensor<128x64xbf16>>
// CHECK-DAG: [[VAR_45_:%.+]] = tt.advance [[VAR_arg15_]], {{\[}}[[CST_0_i32]], [[VAR_30_]]] : <tensor<64x256xbf16>>
// CHECK: scf.yield [[VAR_43_]], [[VAR_44_]], [[VAR_45_]] : tensor<128x256xf32>, !tt.ptr<tensor<128x64xbf16>>, !tt.ptr<tensor<64x256xbf16>>
// CHECK: }
// CHECK-DAG: [[VAR_42_:%.+]] = arith.truncf [[VAR_41_]]#0 : tensor<128x256xf32> to tensor<128x256xbf16>
// CHECK-DAG: [[VAR_43_:%.+]] = arith.index_cast [[PARAM_10_]] : i32 to index
// CHECK: [[VAR_53_:%.+]] = tt.make_tensor_ptr [[PARAM_2_]], {{\[}}[[CST_0_i64]], [[CST_0_i64]]], {{\[}}{{.*}}], {{\[}}{{.*}}] {{.*}} : <tensor<128x256xbf16>>
// CHECK: tt.store [[VAR_53_]], [[VAR_42_]] : !tt.ptr<tensor<128x256xbf16>>
// CHECK-DAG: [[VAR_32_:%.+]] = arith.truncf [[VAR_31_]]#0 : tensor<128x256xf32> to tensor<128x256xbf16>
// CHECK: [[VAR_39_:%.+]] = tt.make_tensor_ptr [[PARAM_2_]], {{\[}}[[CST_0_i64]], [[CST_0_i64]]], {{\[}}{{.*}}], {{\[}}{{.*}}] {{.*}} : <tensor<128x256xbf16>>
// CHECK: tt.store [[VAR_39_]], [[VAR_32_]] : !tt.ptr<tensor<128x256xbf16>>
// CHECK: tt.return
// CHECK: }
Loading

0 comments on commit d9fdda8

Please sign in to comment.