From 6f3eb95904c315da67215f89f90a9002df589a65 Mon Sep 17 00:00:00 2001 From: Emilio Cota Date: Fri, 31 Jan 2025 17:05:23 -0800 Subject: [PATCH] [xla:cpu] kernel_api_ir_builder: expose helpers to get KernelParams This paves the way for upcoming work. PiperOrigin-RevId: 721954340 --- xla/backends/cpu/codegen/emitters/ir/BUILD | 6 + .../cpu/codegen/emitters/transforms/BUILD | 3 + .../cpu/codegen/kernel_api_ir_builder.cc | 24 +- .../cpu/codegen/kernel_api_ir_builder.h | 8 + .../cpu/codegen/kernel_api_ir_builder_test.cc | 28 ++ .../gpu/codegen/emitters/emitter_base.cc | 22 +- xla/backends/gpu/codegen/emitters/ir/BUILD | 7 + .../tests/reduce_column_small/f32_2.hlo | 2 +- .../tests/reduce_column_small/f32_32_v2.hlo | 2 +- .../tests/reduce_column_small/f32_8_v2.hlo | 2 +- .../reduce_column_small/s8_f32_32_v4.hlo | 2 +- .../gpu/codegen/emitters/transforms/BUILD | 11 - .../gpu/codegen/emitters/transforms/passes.h | 10 - .../gpu/codegen/emitters/transforms/passes.td | 193 ------------ .../emitters/transforms/tests/inlining.mlir | 295 ------------------ xla/backends/gpu/codegen/triton/BUILD | 1 + .../gpu/codegen/triton/fusion_emitter.cc | 3 +- xla/codegen/emitters/BUILD | 1 + xla/codegen/emitters/ir/BUILD | 6 + xla/codegen/emitters/transforms/BUILD | 21 ++ .../transforms/convert_pure_call_ops.cc} | 8 +- .../transforms/erase_dead_functions.cc | 10 +- .../emitters/transforms/lower_xla_to_scf.cc} | 65 ++-- .../merge_pointers_to_same_slice.cc | 6 +- xla/codegen/emitters/transforms/passes.h | 11 + xla/codegen/emitters/transforms/passes.td | 228 ++++++++++++-- .../codegen/emitters/transforms/peel_loops.cc | 10 +- .../transforms/propagate_slice_indices.cc | 8 +- .../emitters/transforms/simplify_affine.cc | 9 +- .../emitters/transforms/simplify_arith.cc | 10 +- .../tests/convert_pure_calls_ops.mlir} | 4 +- .../tests/lower_xla_loops_to_scf.mlir} | 2 +- .../transforms/tests/lower_xla_to_scf.mlir} | 2 +- .../tests/merge_pointers_to_same_slice.mlir | 2 +- .../emitters/transforms/tests/peel_loops.mlir | 2 +- .../tests/propagate_slice_indices.mlir | 2 +- .../transforms/tests/simplify_affine.mlir | 4 +- .../transforms/tests/simplify_arith.mlir | 2 +- .../transforms/tests/unswitch_loops.mlir | 2 +- .../emitters/transforms/unswitch_loops.cc | 6 +- xla/tsl/concurrency/BUILD | 2 + 41 files changed, 410 insertions(+), 632 deletions(-) delete mode 100644 xla/backends/gpu/codegen/emitters/transforms/tests/inlining.mlir rename xla/{backends/gpu/codegen/emitters/transforms/convert_xla_gpu_pure_call_ops.cc => codegen/emitters/transforms/convert_pure_call_ops.cc} (91%) rename xla/{backends/gpu => }/codegen/emitters/transforms/erase_dead_functions.cc (92%) rename xla/{backends/gpu/codegen/emitters/transforms/lower_xla_gpu_to_scf.cc => codegen/emitters/transforms/lower_xla_to_scf.cc} (89%) rename xla/{backends/gpu => }/codegen/emitters/transforms/merge_pointers_to_same_slice.cc (97%) rename xla/{backends/gpu => }/codegen/emitters/transforms/peel_loops.cc (96%) rename xla/{backends/gpu => }/codegen/emitters/transforms/propagate_slice_indices.cc (92%) rename xla/{backends/gpu => }/codegen/emitters/transforms/simplify_affine.cc (97%) rename xla/{backends/gpu => }/codegen/emitters/transforms/simplify_arith.cc (98%) rename xla/{backends/gpu/codegen/emitters/transforms/tests/convert_xla_gpu_pure_calls.mlir => codegen/emitters/transforms/tests/convert_pure_calls_ops.mlir} (90%) rename xla/{backends/gpu/codegen/emitters/transforms/tests/lower_xla_gpu_loops_to_scf.mlir => codegen/emitters/transforms/tests/lower_xla_loops_to_scf.mlir} 
(98%)
 rename xla/{backends/gpu/codegen/emitters/transforms/tests/lower_xla_gpu_to_scf.mlir => codegen/emitters/transforms/tests/lower_xla_to_scf.mlir} (99%)
 rename xla/{backends/gpu => }/codegen/emitters/transforms/tests/merge_pointers_to_same_slice.mlir (98%)
 rename xla/{backends/gpu => }/codegen/emitters/transforms/tests/peel_loops.mlir (97%)
 rename xla/{backends/gpu => }/codegen/emitters/transforms/tests/propagate_slice_indices.mlir (94%)
 rename xla/{backends/gpu => }/codegen/emitters/transforms/tests/simplify_affine.mlir (98%)
 rename xla/{backends/gpu => }/codegen/emitters/transforms/tests/simplify_arith.mlir (99%)
 rename xla/{backends/gpu => }/codegen/emitters/transforms/tests/unswitch_loops.mlir (94%)
 rename xla/{backends/gpu => }/codegen/emitters/transforms/unswitch_loops.cc (97%)

diff --git a/xla/backends/cpu/codegen/emitters/ir/BUILD b/xla/backends/cpu/codegen/emitters/ir/BUILD
index f1afa0fcc7bf3..32df97b461b85 100644
--- a/xla/backends/cpu/codegen/emitters/ir/BUILD
+++ b/xla/backends/cpu/codegen/emitters/ir/BUILD
@@ -1,4 +1,6 @@
 load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library")
+load("//tensorflow:tensorflow.google.bzl", "get_compatible_with_portable")
+load("//xla/tsl/platform:rules_cc.bzl", "cc_library")
 
 package(
     # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
@@ -16,6 +18,7 @@ package_group(
 td_library(
     name = "xla_cpu_td_files",
     srcs = glob(["*.td"]),
+    compatible_with = get_compatible_with_portable(),
     includes = ["."],
     deps = [
         "@llvm-project//mlir:BuiltinDialectTdFiles",
@@ -25,6 +28,7 @@ td_library(
 
 gentbl_cc_library(
     name = "xla_cpu_dialect_inc_gen",
+    compatible_with = get_compatible_with_portable(),
     strip_include_prefix = ".",
     tbl_outs = [
         (
@@ -43,6 +47,7 @@ gentbl_cc_library(
 
 gentbl_cc_library(
     name = "xla_cpu_types_inc_gen",
+    compatible_with = get_compatible_with_portable(),
     strip_include_prefix = ".",
     tbl_outs = [
         (
@@ -67,6 +72,7 @@ gentbl_cc_library(
 
 gentbl_cc_library(
     name = "xla_cpu_ops_inc_gen",
+    compatible_with = get_compatible_with_portable(),
     strip_include_prefix = ".",
     tbl_outs = [
         (
diff --git a/xla/backends/cpu/codegen/emitters/transforms/BUILD b/xla/backends/cpu/codegen/emitters/transforms/BUILD
index 1dc0eeeb23a61..070c6b9324250 100644
--- a/xla/backends/cpu/codegen/emitters/transforms/BUILD
+++ b/xla/backends/cpu/codegen/emitters/transforms/BUILD
@@ -1,4 +1,6 @@
 load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library")
+load("//tensorflow:tensorflow.google.bzl", "get_compatible_with_portable")
+load("//xla/tsl/platform:rules_cc.bzl", "cc_library")
 
 package(
     # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
@@ -15,6 +17,7 @@ package_group(
 
 gentbl_cc_library(
     name = "passes_inc_gen",
+    compatible_with = get_compatible_with_portable(),
     tbl_outs = [
         (
             [
diff --git a/xla/backends/cpu/codegen/kernel_api_ir_builder.cc b/xla/backends/cpu/codegen/kernel_api_ir_builder.cc
index 0028b86f4725e..af1f531a70b25 100644
--- a/xla/backends/cpu/codegen/kernel_api_ir_builder.cc
+++ b/xla/backends/cpu/codegen/kernel_api_ir_builder.cc
@@ -243,39 +243,39 @@ absl::StatusOr<BufferAllocation::Slice> GetUniqueSlice(
   return buffer_assignment->GetUniqueSlice(instruction, index);
 }
 
+}  // namespace
+
 absl::StatusOr<std::vector<KernelApiIrBuilder::KernelParameter>>
-GetKernelArgumentsParameters(const HloInstruction* instruction,
-                             const BufferAssignment* buffer_assignment) {
-  std::vector<KernelApiIrBuilder::KernelParameter> arguments;
+KernelApiIrBuilder::GetKernelArgumentsParameters(
+    const HloInstruction* instruction,
+    const BufferAssignment* buffer_assignment) {
+  std::vector<KernelParameter> arguments;
   for (HloInstruction* operand : instruction->operands()) {
     for (auto& indexed : ShapeUtil::GetLeafShapes(operand->shape())) {
       TF_ASSIGN_OR_RETURN(
           BufferAllocation::Slice slice,
           GetUniqueSlice(buffer_assignment, operand, indexed.index));
-      arguments.push_back(
-          KernelApiIrBuilder::KernelParameter{indexed.shape, slice});
+      arguments.push_back(KernelParameter{indexed.shape, slice});
     }
   }
   return arguments;
 }
 
 absl::StatusOr<std::vector<KernelApiIrBuilder::KernelParameter>>
-GetKernelResultsParameters(const HloInstruction* instruction,
-                           const BufferAssignment* buffer_assignment) {
-  std::vector<KernelApiIrBuilder::KernelParameter> results;
+KernelApiIrBuilder::GetKernelResultsParameters(
+    const HloInstruction* instruction,
+    const BufferAssignment* buffer_assignment) {
+  std::vector<KernelParameter> results;
   for (auto& indexed : ShapeUtil::GetLeafShapes(instruction->shape())) {
     TF_ASSIGN_OR_RETURN(
         BufferAllocation::Slice slice,
         GetUniqueSlice(buffer_assignment, instruction, indexed.index));
-    results.push_back(
-        KernelApiIrBuilder::KernelParameter{indexed.shape, slice});
+    results.push_back(KernelParameter{indexed.shape, slice});
   }
   return results;
 }
 
-}  // namespace
-
 auto KernelApiIrBuilder::Options::FromHloModuleConfig(
     const HloModuleConfig& config) -> Options {
   return KernelApiIrBuilder::Options{
diff --git a/xla/backends/cpu/codegen/kernel_api_ir_builder.h b/xla/backends/cpu/codegen/kernel_api_ir_builder.h
index 08fce19aba82e..0bd1338457fad 100644
--- a/xla/backends/cpu/codegen/kernel_api_ir_builder.h
+++ b/xla/backends/cpu/codegen/kernel_api_ir_builder.h
@@ -113,6 +113,14 @@ class KernelApiIrBuilder {
   static std::unique_ptr<llvm::Module> CreateModule(absl::string_view name,
                                                     llvm::LLVMContext& context);
 
+  static absl::StatusOr<std::vector<KernelParameter>>
+  GetKernelArgumentsParameters(const HloInstruction* instruction,
+                               const BufferAssignment* buffer_assignment);
+
+  static absl::StatusOr<std::vector<KernelParameter>>
+  GetKernelResultsParameters(const HloInstruction* instruction,
+                             const BufferAssignment* buffer_assignment);
+
  private:
   ThreadDims EmitKernelThreadDims(llvm::IRBuilderBase& builder,
                                   llvm::Value* call_frame);
diff --git a/xla/backends/cpu/codegen/kernel_api_ir_builder_test.cc b/xla/backends/cpu/codegen/kernel_api_ir_builder_test.cc
index 04b25ec25c5fa..9867624530492 100644
--- a/xla/backends/cpu/codegen/kernel_api_ir_builder_test.cc
+++ b/xla/backends/cpu/codegen/kernel_api_ir_builder_test.cc
@@ -19,6 +19,7 @@ limitations under the License.
 
 #include 
 #include 
+#include <memory>
 #include 
 
 #include "absl/status/statusor.h"
 #include "absl/strings/string_view.h"
@@ -294,5 +295,32 @@ TEST_F(KernelApiIrBuilderTest, MixedBuffers) {
   EXPECT_TRUE(prototype.invariant_arguments.contains(0));
 }
 
+TEST_F(KernelApiIrBuilderTest, GetKernelParams) {
+  llvm::LLVMContext context;
+  auto module = std::make_unique<llvm::Module>("test", context);
+  constexpr absl::string_view hlo_text = R"(
+    HloModule m
+    ENTRY main {
+      p0 = f32[2,2] parameter(0)
+      p1 = f32[2,2] parameter(1)
+      ROOT add.0 = f32[2,2] add(p0, p1)
+    })";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto hlo, ParseAndReturnUnverifiedModule(hlo_text));
+  TF_ASSERT_OK_AND_ASSIGN(auto buffer_assignment, RunBufferAssignment(*hlo));
+  const auto* root = hlo->entry_computation()->root_instruction();
+  TF_ASSERT_OK_AND_ASSIGN(auto args,
+                          KernelApiIrBuilder::GetKernelArgumentsParameters(
+                              root, buffer_assignment.get()));
+  EXPECT_EQ(args.size(), 2);
+  EXPECT_THAT(args[0].shape.dimensions(), ::testing::ElementsAre(2, 2));
+  EXPECT_THAT(args[1].shape.dimensions(), ::testing::ElementsAre(2, 2));
+  TF_ASSERT_OK_AND_ASSIGN(auto results,
+                          KernelApiIrBuilder::GetKernelResultsParameters(
+                              root, buffer_assignment.get()));
+  EXPECT_EQ(results.size(), 1);
+  EXPECT_THAT(results[0].shape.dimensions(), ::testing::ElementsAre(2, 2));
+}
+
 }  // namespace
 }  // namespace xla::cpu
diff --git a/xla/backends/gpu/codegen/emitters/emitter_base.cc b/xla/backends/gpu/codegen/emitters/emitter_base.cc
index b18fd31f19a75..c9c176194c976 100644
--- a/xla/backends/gpu/codegen/emitters/emitter_base.cc
+++ b/xla/backends/gpu/codegen/emitters/emitter_base.cc
@@ -574,17 +574,17 @@ absl::Status EmitterBase::RunPassPipeline(
 }
 
 void AddXlaGpuOpsOptimizationPasses(mlir::OpPassManager& pm) {
-  pm.addNestedPass(CreateSimplifyArithPass());
+  pm.addNestedPass(emitters::CreateSimplifyArithPass());
   pm.addPass(mlir::createCanonicalizerPass());
   pm.addPass(mlir::createCSEPass());
-  pm.addPass(CreateEraseDeadFunctionsPass());
+  pm.addPass(emitters::CreateEraseDeadFunctionsPass());
   pm.addPass(mlir::createCSEPass());
 }
 
 void AddLoopTransformationPasses(mlir::OpPassManager& pm,
                                  const se::DeviceDescription& device) {
   pm.addNestedPass(
-      CreateLowerXlaGpuToScfPass(device.threads_per_warp()));
+      emitters::CreateLowerXlaToScfPass(device.threads_per_warp()));
   pm.addNestedPass(CreateFuseLoopsPass());
   pm.addPass(mlir::createInlinerPass({}, [&](mlir::OpPassManager& pm) {
     // CSE after inlining because inlining can introduce duplicates.
@@ -592,15 +592,15 @@ void AddLoopTransformationPasses(mlir::OpPassManager& pm,
   }));
   pm.addPass(mlir::createCanonicalizerPass());
   pm.addPass(mlir::createCSEPass());
-  pm.addNestedPass(CreatePeelLoopsPass());
-  pm.addNestedPass(CreateLowerXlaGpuLoopsToScfPass());
+  pm.addNestedPass(emitters::CreatePeelLoopsPass());
+  pm.addNestedPass(emitters::CreateLowerXlaLoopsToScfPass());
   pm.addPass(mlir::mhlo::createConvertToSignlessPass());
-  pm.addPass(CreatePropagateSliceIndicesPass());
+  pm.addPass(emitters::CreatePropagateSliceIndicesPass());
   pm.addPass(emitters::CreateFlattenTensorsPass());
   // We need LICM before unswitching loops, because our loop unswitcher only
   // detects for loops with a single if inside them.
   pm.addPass(mlir::createLoopInvariantCodeMotionPass());
-  pm.addNestedPass(CreateUnswitchLoopsPass());
+  pm.addNestedPass(emitters::CreateUnswitchLoopsPass());
   // We need LICM again after unswitching, because that can introduce new
   // opportunities for LICM. This would not be necessary if LICM also moved
   // instructions over ifs.
@@ -613,17 +613,17 @@ void AddLoopTransformationPasses(mlir::OpPassManager& pm, void AddLoweringPasses(mlir::OpPassManager& pm, const se::DeviceDescription& device) { - pm.addNestedPass(CreateConvertPureCallOpsPass()); + pm.addNestedPass(emitters::CreateConvertPureCallOpsPass()); pm.addPass(emitters::CreateLowerTensorsPass(device)); pm.addPass(mlir::createConvertComplexToStandardPass()); - pm.addPass(CreateMergePointersToSameSlicePass()); + pm.addPass(emitters::CreateMergePointersToSameSlicePass()); // LowerTensors creates new affine.apply ops. Fold and CSE them so // simplify-affine has maximally folded expressions to work with. pm.addPass(mlir::createCanonicalizerPass()); pm.addPass(mlir::createCSEPass()); - pm.addNestedPass(CreateSimplifyArithPass()); - pm.addPass(CreateSimplifyAffinePass()); + pm.addNestedPass(emitters::CreateSimplifyArithPass()); + pm.addPass(emitters::CreateSimplifyAffinePass()); pm.addPass(CreateConvertIndexTypePass()); // simplify-affine lowers most affine.apply ops, but if it can't prove a // division or modulo is unsigned, affine.apply ops will remain. diff --git a/xla/backends/gpu/codegen/emitters/ir/BUILD b/xla/backends/gpu/codegen/emitters/ir/BUILD index 2cfab11eea7d6..39fbc1709fc00 100644 --- a/xla/backends/gpu/codegen/emitters/ir/BUILD +++ b/xla/backends/gpu/codegen/emitters/ir/BUILD @@ -1,4 +1,6 @@ load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library") +load("//tensorflow:tensorflow.google.bzl", "get_compatible_with_portable") +load("//xla/tsl/platform:rules_cc.bzl", "cc_library") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -16,6 +18,7 @@ package_group( td_library( name = "xla_gpu_td_files", srcs = glob(["*.td"]), + compatible_with = get_compatible_with_portable(), includes = ["."], deps = [ "//xla/codegen/emitters/ir:xla_td_files", @@ -30,6 +33,7 @@ td_library( gentbl_cc_library( name = "xla_gpu_dialect_inc_gen", + compatible_with = get_compatible_with_portable(), strip_include_prefix = ".", tbl_outs = [ ( @@ -48,6 +52,7 @@ gentbl_cc_library( gentbl_cc_library( name = "xla_gpu_ops_inc_gen", + compatible_with = get_compatible_with_portable(), strip_include_prefix = ".", tbl_outs = [ ( @@ -66,6 +71,7 @@ gentbl_cc_library( gentbl_cc_library( name = "xla_gpu_attrs_inc_gen", + compatible_with = get_compatible_with_portable(), strip_include_prefix = ".", tbl_outs = [ ( @@ -98,6 +104,7 @@ gentbl_cc_library( gentbl_cc_library( name = "xla_gpu_types_inc_gen", + compatible_with = get_compatible_with_portable(), strip_include_prefix = ".", tbl_outs = [ ( diff --git a/xla/backends/gpu/codegen/emitters/tests/reduce_column_small/f32_2.hlo b/xla/backends/gpu/codegen/emitters/tests/reduce_column_small/f32_2.hlo index 81c6bebf03f68..530d8a10f8939 100644 --- a/xla/backends/gpu/codegen/emitters/tests/reduce_column_small/f32_2.hlo +++ b/xla/backends/gpu/codegen/emitters/tests/reduce_column_small/f32_2.hlo @@ -1,5 +1,5 @@ // RUN: test_correctness %s --bijection_inputs=reduce:0 --bijection_outputs=reduce -// RUN: fusion_to_mlir %s | emitters_opt -cse -xla-gpu-simplify-arith -canonicalize | FileCheck %s --dump-input=always +// RUN: fusion_to_mlir %s | emitters_opt -cse -xla-simplify-arith -canonicalize | FileCheck %s --dump-input=always add { %p0 = f32[] parameter(0) diff --git a/xla/backends/gpu/codegen/emitters/tests/reduce_column_small/f32_32_v2.hlo b/xla/backends/gpu/codegen/emitters/tests/reduce_column_small/f32_32_v2.hlo index 923f19b24e196..a0c2cb27189fd 100644 --- 
a/xla/backends/gpu/codegen/emitters/tests/reduce_column_small/f32_32_v2.hlo +++ b/xla/backends/gpu/codegen/emitters/tests/reduce_column_small/f32_32_v2.hlo @@ -1,5 +1,5 @@ -// RUN: fusion_to_mlir %s | emitters_opt -cse -xla-gpu-simplify-arith -canonicalize | FileCheck %s --dump-input=always +// RUN: fusion_to_mlir %s | emitters_opt -cse -xla-simplify-arith -canonicalize | FileCheck %s --dump-input=always add { %p0 = f32[] parameter(0) diff --git a/xla/backends/gpu/codegen/emitters/tests/reduce_column_small/f32_8_v2.hlo b/xla/backends/gpu/codegen/emitters/tests/reduce_column_small/f32_8_v2.hlo index 175ced82445ca..8c4ed106ae6a9 100644 --- a/xla/backends/gpu/codegen/emitters/tests/reduce_column_small/f32_8_v2.hlo +++ b/xla/backends/gpu/codegen/emitters/tests/reduce_column_small/f32_8_v2.hlo @@ -1,4 +1,4 @@ -// RUN: fusion_to_mlir %s | emitters_opt -cse -xla-gpu-simplify-arith -canonicalize | FileCheck %s --dump-input=always +// RUN: fusion_to_mlir %s | emitters_opt -cse -xla-simplify-arith -canonicalize | FileCheck %s --dump-input=always // RUN: test_correctness %s --bijection_inputs=reduce:0 --bijection_outputs=reduce add { diff --git a/xla/backends/gpu/codegen/emitters/tests/reduce_column_small/s8_f32_32_v4.hlo b/xla/backends/gpu/codegen/emitters/tests/reduce_column_small/s8_f32_32_v4.hlo index 781449df049c8..5605fbbdb902e 100644 --- a/xla/backends/gpu/codegen/emitters/tests/reduce_column_small/s8_f32_32_v4.hlo +++ b/xla/backends/gpu/codegen/emitters/tests/reduce_column_small/s8_f32_32_v4.hlo @@ -1,4 +1,4 @@ -// RUN: fusion_to_mlir %s | emitters_opt -cse -xla-gpu-simplify-arith -canonicalize | FileCheck %s --dump-input=always +// RUN: fusion_to_mlir %s | emitters_opt -cse -xla-simplify-arith -canonicalize | FileCheck %s --dump-input=always add { %p0 = f32[] parameter(0) diff --git a/xla/backends/gpu/codegen/emitters/transforms/BUILD b/xla/backends/gpu/codegen/emitters/transforms/BUILD index 75e77387da1f7..313afb411ce5a 100644 --- a/xla/backends/gpu/codegen/emitters/transforms/BUILD +++ b/xla/backends/gpu/codegen/emitters/transforms/BUILD @@ -40,17 +40,8 @@ cc_library( srcs = [ "convert_float_nvidia.cc", "convert_index_type.cc", - "convert_xla_gpu_pure_call_ops.cc", - "erase_dead_functions.cc", "fuse_loops.cc", - "lower_xla_gpu_to_scf.cc", - "merge_pointers_to_same_slice.cc", "optimize_loops.cc", - "peel_loops.cc", - "propagate_slice_indices.cc", - "simplify_affine.cc", - "simplify_arith.cc", - "unswitch_loops.cc", "vectorize_loads_stores.cc", ], hdrs = ["passes.h"], @@ -61,7 +52,6 @@ cc_library( "//xla:util", "//xla:xla_data_proto_cc", "//xla/backends/gpu/codegen/emitters/ir:xla_gpu", - "//xla/codegen/emitters:elemental_hlo_to_mlir", "//xla/codegen/emitters/ir:xla", "//xla/codegen/emitters/transforms:atomic_rmw_utils", "//xla/hlo/analysis:indexing_analysis", @@ -74,7 +64,6 @@ cc_library( "//xla/stream_executor:semantic_version", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/strings", diff --git a/xla/backends/gpu/codegen/emitters/transforms/passes.h b/xla/backends/gpu/codegen/emitters/transforms/passes.h index 8f6a674bab3c6..0a0c86c819ca0 100644 --- a/xla/backends/gpu/codegen/emitters/transforms/passes.h +++ b/xla/backends/gpu/codegen/emitters/transforms/passes.h @@ -37,18 +37,8 @@ std::unique_ptr CreateConvertFloatNvidiaPass(); std::optional> MaybeCreateConvertFloatNvidiaPass( const se::DeviceDescription& 
device_description); std::unique_ptr CreateConvertIndexTypePass(); -std::unique_ptr CreateConvertPureCallOpsPass(); -std::unique_ptr CreateEraseDeadFunctionsPass(); -std::unique_ptr CreateLowerXlaGpuToScfPass(int64_t warp_size = 32); -std::unique_ptr CreateLowerXlaGpuLoopsToScfPass(); -std::unique_ptr CreateMergePointersToSameSlicePass(); std::unique_ptr CreateOptimizeLoopsPass(); std::unique_ptr CreateFuseLoopsPass(); -std::unique_ptr CreatePeelLoopsPass(); -std::unique_ptr CreatePropagateSliceIndicesPass(); -std::unique_ptr CreateSimplifyAffinePass(); -std::unique_ptr CreateSimplifyArithPass(); -std::unique_ptr CreateUnswitchLoopsPass(); std::unique_ptr CreateVectorizeLoadsAndStoresPass( const std::string& gpu_device_info = ""); std::unique_ptr CreateVectorizeLoadsAndStoresPass( diff --git a/xla/backends/gpu/codegen/emitters/transforms/passes.td b/xla/backends/gpu/codegen/emitters/transforms/passes.td index 8f759d1bb2222..716e1fb676e14 100644 --- a/xla/backends/gpu/codegen/emitters/transforms/passes.td +++ b/xla/backends/gpu/codegen/emitters/transforms/passes.td @@ -18,88 +18,6 @@ limitations under the License. include "mlir/Pass/PassBase.td" -def PropagateSliceIndicesPass : - Pass<"xla-gpu-propagate-slice-indices", "mlir::ModuleOp"> { - let summary = "Propagates slice indices from the entry function to all callees."; - - let description = [{ - Propagates xla.slice_index attributes from the function with the xla.entry - attribute to all other functions. - }]; - - let dependentDialects = [ - "mlir::func::FuncDialect" - ]; - - let constructor = "CreatePropagateSliceIndicesPass()"; -} - -def ConvertPureCallOpsPass - : Pass<"xla-gpu-convert-pure-call-ops", "mlir::func::FuncOp"> { - let summary = "Converts xla_gpu.pure_call to func.call"; - let description = [{ - We use xla_gpu.pure_call ops for calls to enable CSE and other - transformations (e.g. LICM). This pass rewrites our custom ops to standard - ops. - }]; - let dependentDialects = [ - "mlir::func::FuncDialect", - "xla::XlaDialect" - ]; - let constructor = "CreateConvertPureCallOpsPass()"; -} - -def MergePointersToSameSlicePass : - Pass<"xla-gpu-merge-pointers", "mlir::ModuleOp"> { - let summary = "Merges pointers that share slices."; - - let description = [{ - When a function has multiple pointer arguments with the same slice index, - merges them. - }]; - - let dependentDialects = [ - "mlir::func::FuncDialect" - ]; - - let constructor = "CreateMergePointersToSameSlicePass()"; -} - -def SimplifyArithPass : Pass<"xla-gpu-simplify-arith", "mlir::func::FuncOp"> { - let summary = "Simplifies arith using XLA's range-aware simplifier."; - - let description = [{ - We often emit bounds checks that are statically known to be satisfied. - This pass removes them. - }]; - - let dependentDialects = [ - "mlir::arith::ArithDialect", - "mlir::func::FuncDialect", - ]; - - let constructor = "CreateSimplifyArithPass()"; -} - -def SimplifyAffinePass : Pass<"xla-gpu-simplify-affine", "mlir::ModuleOp"> { - let summary = "Simplifies affine.apply using XLA's range-aware simplifier."; - - let description = [{ - The standard affine canonicalizer cannot simplify all expressions, since - it is unaware of range information. This pass uses `xla.range` attributes - on arguments and ops for simplification. It also lowers floordiv and mod - to simpler expressions than lower-affine. This pass only works for - expressions for which we can prove the LHS of mod and div is nonnegative. 
- }]; - - let dependentDialects = [ - "mlir::affine::AffineDialect", "mlir::func::FuncDialect", - "mlir::scf::SCFDialect", - ]; - - let constructor = "CreateSimplifyAffinePass()"; -} - def ConvertIndexTypePass : Pass<"xla-gpu-convert-index-type", "mlir::ModuleOp"> { let summary = "Converts index types to module data layout index type."; @@ -129,58 +47,6 @@ def ConvertFloatNvidiaPass : Pass<"xla-gpu-convert-float-nvidia", "mlir::ModuleO let constructor = "CreateConvertFloatNvidiaPass()"; } -def LowerXlaGpuToScfPass : - Pass<"xla-gpu-lower-xla-gpu-to-scf", "mlir::func::FuncOp"> { - let summary = "Lowers xla_gpu to SCF."; - - let dependentDialects = [ - "mlir::gpu::GPUDialect", "mlir::LLVM::LLVMDialect", "mlir::scf::SCFDialect", - "mlir::tensor::TensorDialect", "xla::gpu::XlaGpuDialect", - "xla::XlaDialect", "mlir::vector::VectorDialect", - ]; - - let options = [ - Option<"warp_size", "warp_size", "int64_t", /*default=*/"32", "Warp size.">, - ]; - let constructor = "CreateLowerXlaGpuToScfPass()"; -} - -def LowerXlaGpuLoopsToScfPass : Pass< - "xla-gpu-lower-xla-gpu-loops-to-scf", "mlir::func::FuncOp"> { - let summary = "Lowers xla_gpu.loop to SCF."; - - let description = [{ - This pass is separate from lower-xla-gpu-to-scf because - lower-xla-gpu-to-scf, inliner, peeling and lower-xla-gpu-loops-to-scf - have to run in that order. - }]; - - let dependentDialects = [ - "mlir::scf::SCFDialect", - "mlir::tensor::TensorDialect", - "xla::gpu::XlaGpuDialect", - "xla::XlaDialect", - ]; - - let constructor = "CreateLowerXlaGpuLoopsToScfPass()"; -} - -def EraseDeadFunctionsPass : Pass<"xla-erase-dead-functions", "mlir::ModuleOp"> { - let summary = "Deletes unused functions"; - - let description = [{ - Deletes functions that are not called. - }]; - - let dependentDialects = [ - "mlir::func::FuncDialect", - "xla::gpu::XlaGpuDialect", - "xla::XlaDialect", - ]; - - let constructor = "CreateEraseDeadFunctionsPass()"; -} - def VectorizeLoadsAndStoresPass : Pass<"xla-gpu-vectorize-loads-stores", "mlir::func::FuncOp"> { let summary = "Vectorizes loads and stores."; @@ -235,17 +101,6 @@ def FuseLoopsPass : Pass<"xla-gpu-fuse-loops", "mlir::func::FuncOp"> { let constructor = "CreateFuseLoopsPass()"; } -def PeelLoopsPass : Pass<"xla-gpu-peel-loops", "mlir::func::FuncOp"> { - let summary = "Peels xla_gpu.loop."; - let description = [{ - Attempts to split each loop dimension [0, NUM_ITERATIONS) - as [0, NUM_ITERATIONS - 1) and [NUM_ITERATIONS - 1, NUM_ITERATIONS) - if it removes a constraint. 
- }]; - let dependentDialects = ["xla::gpu::XlaGpuDialect", "xla::XlaDialect"]; - let constructor = "CreatePeelLoopsPass()"; -} - def OptimizeLoopsPass : Pass<"xla-gpu-optimize-loops", "mlir::func::FuncOp"> { let summary = "Unrolls and pipelines loops."; @@ -264,52 +119,4 @@ def OptimizeLoopsPass : let constructor = "CreateOptimizeLoopsPass()"; } -def UnswitchLoopsPass : - Pass<"xla-gpu-unswitch-loops", "mlir::func::FuncOp"> { - let summary = "Swaps scf.if and scf.for."; - - let description = [{ - Extracts `scf.if` ops with conditions that are independent of the loop - variable from `scf.for` by doing the following rewrite: - - Before: - - %cond = some_cond() : i1 - %results = scf.for { - %some_val = scf.if %cond { - } else { - } - scf.yield %some_val - } - - After: - - %cond = some_cond() : i1 - %results = scf.if %cond { - %results = scf.for { - %some_val = scf.if %true { - } else { - } - } - yield %results - } else { - %results = scf.for { - %some_val = scf.if %false { - } else { - } - } - yield %results - } - - This only triggers if there is a single `scf.if` op in the loop body (and - nothing else). - }]; - - let dependentDialects = [ - "mlir::func::FuncDialect", "mlir::scf::SCFDialect" - ]; - - let constructor = "CreateUnswitchLoopsPass()"; -} - #endif // XLA_SERVICE_GPU_FUSIONS_TRANSFORMS_PASSES_TD_ diff --git a/xla/backends/gpu/codegen/emitters/transforms/tests/inlining.mlir b/xla/backends/gpu/codegen/emitters/transforms/tests/inlining.mlir deleted file mode 100644 index fac0ff321778f..0000000000000 --- a/xla/backends/gpu/codegen/emitters/transforms/tests/inlining.mlir +++ /dev/null @@ -1,295 +0,0 @@ -// RUN: emitters_opt %s -split-input-file -xla-erase-dead-functions -inline | FileCheck %s - -module { - func.func private @mul(%a: f32, %b: f32) -> f32 { - %ret = arith.mulf %a, %b : f32 - return %ret : f32 - } - - func.func private @add(%a: f32, %b: f32) -> f32 { - %add = arith.addf %a, %b : f32 - %ret = xla.pure_call @mul(%add, %add) : (f32, f32) -> (f32) - return %ret : f32 - } - - func.func @caller(%a: f32, %b: f32) -> f32 { - %ret = xla.pure_call @add(%a, %b) : (f32, f32) -> (f32) - return %ret : f32 - } -} - -// CHECK-LABEL: module { -// CHECK: @caller -// CHECK-NOT: xla.pure_call @add -// CHECK: arith.addf -// CHECK-NOT: xla.pure_call @mul -// CHECK: arith.mulf - -// ----- - -module { - func.func @fused_computation(%arg0: tensor<2xf32> {xla.slice_index = 0 : index}, %arg1: tensor<2xf32> {xla.slice_index = 1 : index}, %arg2: tensor<2xf32> {xla.slice_index = 2 : index}) -> tensor<2xf32> attributes {xla.entry} { - %0 = gpu.thread_id x {xla.range = [0 : index, 1 : index]} - %1 = xla.pure_call @fused_computation_atan2(%arg0, %arg1, %0) : (tensor<2xf32>, tensor<2xf32>, index) -> f32 - %inserted = tensor.insert %1 into %arg2[%0] : tensor<2xf32> - return %inserted : tensor<2xf32> - } - func.func private @fused_computation_atan2(%arg0: tensor<2xf32>, %arg1: tensor<2xf32>, %arg2: index {xla.range = [0 : index, 1 : index]}) -> f32 attributes {llvm.linkage = #llvm.linkage} { - %extracted = tensor.extract %arg0[%arg2] : tensor<2xf32> - %extracted_0 = tensor.extract %arg1[%arg2] : tensor<2xf32> - %0 = arith.addf %extracted, %extracted_0 : f32 - %1 = arith.subf %extracted, %extracted_0 : f32 - %2 = arith.mulf %0, %1 : f32 - %3 = arith.divf %0, %1 : f32 - %4 = math.atan2 %2, %3 : f32 - return %4 : f32 - } -} - -// CHECK-LABEL: module { -// CHECK: @fused_computation -// CHECK-NOT: xla.pure_call @add -// CHECK: gpu.thread_id -// CHECK-NEXT: tensor.extract -// CHECK-NEXT: tensor.extract -// 
CHECK-NEXT: arith.addf -// CHECK-NEXT: arith.subf -// CHECK-NEXT: arith.mulf -// CHECK-NEXT: arith.divf -// CHECK-NEXT: math.atan2 -// CHECK-NEXT: tensor.insert - -// ----- - -module { - // Do not inline this function as it has two callers. Even if the callers are - // in different functions at the start, after inlining the two callers are in - // the same function. - func.func private @large(%a: f32, %b: f32) -> f32 { - %mul = arith.mulf %a, %b : f32 - %add = arith.addf %a, %mul : f32 - %div = arith.divf %add, %b : f32 - %sub = arith.subf %div, %a : f32 - %atan2 = math.atan2 %b, %sub : f32 - %neg = arith.negf %atan2 : f32 - %zero = arith.constant 0.0 : f32 - %comp = arith.cmpf olt, %neg, %zero : f32 - %ret = arith.select %comp, %zero, %neg : f32 - return %ret : f32 - } - - func.func private @add(%a: f32, %b: f32) -> f32 { - %add = arith.addf %a, %b : f32 - %ret = xla.pure_call @large(%add, %add) : (f32, f32) -> (f32) - return %ret : f32 - } - - func.func @caller(%a: f32, %b: f32) -> f32 { - %add = xla.pure_call @add(%a, %b) : (f32, f32) -> (f32) - %ret = xla.pure_call @large(%add, %add) : (f32, f32) -> (f32) - return %ret : f32 - } -} - -// CHECK-LABEL: module { -// CHECK: @caller -// CHECK: arith.addf -// CHECK: xla.pure_call @large -// CHECK: xla.pure_call @large - -// ----- - -module { - func.func private @add(%a: f32, %b: f32) -> f32 { - %ret = arith.addf %a, %b : f32 - return %ret : f32 - } - - func.func @caller(%a: f32, %b: f32) -> f32 { - %add = xla.pure_call @add(%a, %b) : (f32, f32) -> (f32) - %ret = xla.pure_call @add(%add, %add) : (f32, f32) -> (f32) - return %ret : f32 - } -} - -// CHECK-LABEL: module { -// CHECK: @caller -// CHECK-NOT: xla.pure_call -// CHECK: arith.addf -// CHECK: arith.addf - -// ----- - -module { - func.func private @fib0(%start : f32) -> f32 { - %zero = arith.constant 0.0 : f32 - return %zero : f32 - } - func.func private @fib1(%start : f32) -> f32 { - return %start : f32 - } - func.func private @fib2(%start : f32) -> f32 { - %a = xla.pure_call @fib0(%start) : (f32) -> (f32) - %b = xla.pure_call @fib1(%start) : (f32) -> (f32) - %ret = arith.addf %a, %b : f32 - return %ret : f32 - } - func.func private @fib3(%start : f32) -> f32 { - %a = xla.pure_call @fib1(%start) : (f32) -> (f32) - %b = xla.pure_call @fib2(%start) : (f32) -> (f32) - %ret = arith.addf %a, %b : f32 - return %ret : f32 - } - func.func private @fib4(%start : f32) -> f32 { - %a = xla.pure_call @fib2(%start) : (f32) -> (f32) - %b = xla.pure_call @fib3(%start) : (f32) -> (f32) - %ret = arith.addf %a, %b : f32 - return %ret : f32 - } - // When inlining the other functions into @fib5, this function exceeds the - // threshold for inlining. - func.func private @fib5(%start : f32) -> f32 { - %a = xla.pure_call @fib3(%start) : (f32) -> (f32) - %b = xla.pure_call @fib4(%start) : (f32) -> (f32) - %ret = arith.addf %a, %b : f32 - return %ret : f32 - } - // As we do not inline @fib5 into @fib6, this function stays below the - // threshold for inlining. 
- func.func private @fib6(%start : f32) -> f32 { - %a = xla.pure_call @fib4(%start) : (f32) -> (f32) - %b = xla.pure_call @fib5(%start) : (f32) -> (f32) - %ret = arith.addf %a, %b : f32 - return %ret : f32 - } - func.func private @fib7(%start : f32) -> f32 { - %a = xla.pure_call @fib5(%start) : (f32) -> (f32) - %b = xla.pure_call @fib6(%start) : (f32) -> (f32) - %ret = arith.addf %a, %b : f32 - return %ret : f32 - } - - func.func @caller(%a: f32) -> f32 { - %ret = xla.pure_call @fib7(%a) : (f32) -> (f32) - return %ret : f32 - } -} - -// CHECK-LABEL: module { -// CHECK: @caller -// CHECK: arith.constant 0.000000e+00 -// CHECK: xla.pure_call @fib5 -// CHECK: arith.addf -// CHECK: arith.addf -// CHECK: arith.addf -// CHECK: arith.addf -// CHECK: xla.pure_call @fib5 -// CHECK: arith.addf -// CHECK: arith.addf - -// ----- - -module { - func.func private @complex(%a: f32, %b: f32) -> complex { - %ret = complex.create %a, %b : complex - return %ret : complex - } - - func.func @caller(%a: f32, %b: f32) -> complex { - %ret = xla.pure_call @complex(%a, %b) : (f32, f32) -> (complex) - return %ret : complex - } -} - -// CHECK-LABEL: module { -// CHECK: @caller -// CHECK-NEXT: complex.create - -// ----- - -module { - func.func private @callee2(%a: f32) -> f32 { - %ret = arith.addf %a, %a : f32 - return %ret : f32 - } - - func.func private @callee1(%a: f32) -> f32 { - %c1 = xla.pure_call @callee2(%a) : (f32) -> (f32) - %b0 = arith.addf %a, %a : f32 - %b1 = arith.addf %b0, %a : f32 - %b2 = arith.addf %b1, %a : f32 - %b3 = arith.addf %b2, %a : f32 - %b4 = arith.addf %b3, %a : f32 - %b5 = arith.addf %b4, %a : f32 - %b6 = arith.addf %b5, %a : f32 - %b7 = arith.addf %b6, %a : f32 - %c2 = xla.pure_call @callee2(%b7) : (f32) -> (f32) - %ret = arith.addf %c1, %c2 : f32 - return %ret : f32 - } - - func.func private @dead(%a: f32) -> f32 { - %ret = xla.pure_call @callee1(%a) : (f32) -> (f32) - return %ret : f32 - } - - func.func @caller(%a: f32, %b: f32) -> f32 { - %ret = xla.pure_call @callee1(%a) : (f32) -> (f32) - return %ret : f32 - } -} - -// CHECK-LABEL: module { -// CHECK-NOT: func.func -// CHECK: func.func @caller -// CHECK-NOT: xla.pure_call -// CHECK-NOT: func.func - -// ----- - -module { - func.func private @callee1(%a: f32) -> f32 { - %b0 = arith.addf %a, %a : f32 - %b1 = arith.addf %b0, %a : f32 - %b2 = arith.addf %b1, %a : f32 - %b3 = arith.addf %b2, %a : f32 - %b4 = arith.addf %b3, %a : f32 - %b5 = arith.addf %b4, %a : f32 - %b6 = arith.addf %b5, %a : f32 - %b7 = arith.addf %b6, %a : f32 - %b8 = arith.addf %b7, %a : f32 - %b9 = arith.addf %b8, %a : f32 - %b10 = arith.addf %b9, %a : f32 - %b11 = arith.addf %b10, %a : f32 - return %b11 : f32 - } - - func.func private @callee2(%a: f32) -> f32 { - %call = xla.pure_call @callee1(%a) : (f32) -> (f32) - %b0 = arith.addf %a, %a : f32 - %b1 = arith.addf %b0, %a : f32 - %b2 = arith.addf %b1, %a : f32 - %b3 = arith.addf %b2, %a : f32 - %b4 = arith.addf %b3, %a : f32 - %b5 = arith.addf %b4, %a : f32 - %b6 = arith.addf %b5, %a : f32 - %b7 = arith.addf %b6, %a : f32 - %b8 = arith.addf %b7, %a : f32 - %b9 = arith.addf %b8, %a : f32 - %ret = arith.addf %call, %b9 : f32 - return %ret : f32 - } - - func.func @caller(%a: f32, %b: f32) -> f32 { - %call1 = xla.pure_call @callee2(%a) : (f32) -> (f32) - %call2 = xla.pure_call @callee1(%a) : (f32) -> (f32) - %ret = arith.addf %call1, %call2 : f32 - return %ret : f32 - } -} - -// CHECK-LABEL: module { -// CHECK: func.func private @callee1 -// CHECK-NOT: callee2 -// CHECK: func.func @caller -// CHECK-COUNT-2: pure_call 
@callee1 diff --git a/xla/backends/gpu/codegen/triton/BUILD b/xla/backends/gpu/codegen/triton/BUILD index 2a1e3ebf88d4e..7878238fae0f3 100644 --- a/xla/backends/gpu/codegen/triton/BUILD +++ b/xla/backends/gpu/codegen/triton/BUILD @@ -200,6 +200,7 @@ cc_library( "//xla/codegen:emitter_loc_op_builder", "//xla/codegen/emitters:elemental_hlo_to_mlir", "//xla/codegen/emitters/ir:xla", + "//xla/codegen/emitters/transforms:passes", "//xla/hlo/analysis:indexing_analysis", "//xla/hlo/ir:hlo", "//xla/hlo/translate/hlo_to_mhlo:hlo_function_importer", diff --git a/xla/backends/gpu/codegen/triton/fusion_emitter.cc b/xla/backends/gpu/codegen/triton/fusion_emitter.cc index 0f5a99648ece5..9703836ce0795 100644 --- a/xla/backends/gpu/codegen/triton/fusion_emitter.cc +++ b/xla/backends/gpu/codegen/triton/fusion_emitter.cc @@ -91,6 +91,7 @@ limitations under the License. #include "xla/codegen/emitter_loc_op_builder.h" #include "xla/codegen/emitters/elemental_hlo_to_mlir.h" #include "xla/codegen/emitters/ir/xla_ops.h" +#include "xla/codegen/emitters/transforms/passes.h" #include "xla/hlo/analysis/indexing_analysis.h" #include "xla/hlo/analysis/indexing_map.h" #include "xla/hlo/ir/hlo_casting_utils.h" @@ -1290,7 +1291,7 @@ absl::StatusOr CompileTritonToLLVM( pm.addPass(mlir::createLowerAffinePass()); // Lower xla_gpu.apply_indexing into arithmetic ops. - pm.addPass(CreateSimplifyAffinePass()); + pm.addPass(emitters::CreateSimplifyAffinePass()); pm.addPass(CreateConvertIndexTypePass()); mlir::triton::nvidia_gpu::ClusterInfo cluster_info; diff --git a/xla/codegen/emitters/BUILD b/xla/codegen/emitters/BUILD index 478e7cf6a707b..8050c3b313dec 100644 --- a/xla/codegen/emitters/BUILD +++ b/xla/codegen/emitters/BUILD @@ -1,4 +1,5 @@ load("//xla:xla.bzl", "xla_cc_test") +load("//xla/tsl/platform:rules_cc.bzl", "cc_library") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], diff --git a/xla/codegen/emitters/ir/BUILD b/xla/codegen/emitters/ir/BUILD index 18323597ff83b..71ff9e93ccf3c 100644 --- a/xla/codegen/emitters/ir/BUILD +++ b/xla/codegen/emitters/ir/BUILD @@ -1,5 +1,7 @@ load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library") +load("//tensorflow:tensorflow.google.bzl", "get_compatible_with_portable") load("//xla/tests:build_defs.bzl", "xla_test") +load("//xla/tsl/platform:rules_cc.bzl", "cc_library") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -17,6 +19,7 @@ package_group( td_library( name = "xla_td_files", srcs = glob(["*.td"]), + compatible_with = get_compatible_with_portable(), includes = ["."], deps = [ "@llvm-project//mlir:BuiltinDialectTdFiles", @@ -30,6 +33,7 @@ td_library( gentbl_cc_library( name = "xla_dialect_inc_gen", + compatible_with = get_compatible_with_portable(), strip_include_prefix = ".", tbl_outs = [ ( @@ -48,6 +52,7 @@ gentbl_cc_library( gentbl_cc_library( name = "xla_ops_inc_gen", + compatible_with = get_compatible_with_portable(), strip_include_prefix = ".", tbl_outs = [ ( @@ -66,6 +71,7 @@ gentbl_cc_library( gentbl_cc_library( name = "xla_attrs_inc_gen", + compatible_with = get_compatible_with_portable(), strip_include_prefix = ".", tbl_outs = [ ( diff --git a/xla/codegen/emitters/transforms/BUILD b/xla/codegen/emitters/transforms/BUILD index b5ca6bbb548d2..acee4df1cdfa1 100644 --- a/xla/codegen/emitters/transforms/BUILD +++ b/xla/codegen/emitters/transforms/BUILD @@ -1,4 +1,6 @@ load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library") +load("//tensorflow:tensorflow.google.bzl", 
"get_compatible_with_portable") +load("//xla/tsl/platform:rules_cc.bzl", "cc_library") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -29,6 +31,7 @@ cc_library( gentbl_cc_library( name = "passes_inc_gen", + compatible_with = get_compatible_with_portable(), tbl_outs = [ ( [ @@ -47,10 +50,19 @@ gentbl_cc_library( cc_library( name = "passes", srcs = [ + "convert_pure_call_ops.cc", + "erase_dead_functions.cc", "expand_float_ops.cc", "flatten_tensors.cc", "lower_tensors.cc", "lower_to_llvm.cc", + "lower_xla_to_scf.cc", + "merge_pointers_to_same_slice.cc", + "peel_loops.cc", + "propagate_slice_indices.cc", + "simplify_affine.cc", + "simplify_arith.cc", + "unswitch_loops.cc", ], hdrs = ["passes.h"], deps = [ @@ -62,20 +74,29 @@ cc_library( "//xla/backends/cpu/codegen/emitters/ir:xla_cpu", "//xla/backends/gpu/codegen/emitters/ir:xla_gpu", "//xla/codegen:device_spec", + "//xla/codegen/emitters:elemental_hlo_to_mlir", + "//xla/codegen/emitters/ir:xla", "//xla/hlo/analysis:indexing_analysis", "//xla/mlir_hlo", "//xla/mlir_hlo:map_mhlo_to_scalar_op", + "//xla/service/gpu:ir_emission_utils", "//xla/stream_executor:device_description", "//xla/tsl/platform:logging", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", "@llvm-project//llvm:Support", + "@llvm-project//mlir:AffineDialect", "@llvm-project//mlir:AffineToStandard", + "@llvm-project//mlir:AffineUtils", "@llvm-project//mlir:ArithDialect", "@llvm-project//mlir:ArithToLLVM", "@llvm-project//mlir:ArithTransforms", + "@llvm-project//mlir:CallOpInterfaces", "@llvm-project//mlir:ComplexDialect", "@llvm-project//mlir:ComplexToLLVM", "@llvm-project//mlir:ControlFlowToLLVM", diff --git a/xla/backends/gpu/codegen/emitters/transforms/convert_xla_gpu_pure_call_ops.cc b/xla/codegen/emitters/transforms/convert_pure_call_ops.cc similarity index 91% rename from xla/backends/gpu/codegen/emitters/transforms/convert_xla_gpu_pure_call_ops.cc rename to xla/codegen/emitters/transforms/convert_pure_call_ops.cc index acf398e5a4149..4990728395e3b 100644 --- a/xla/backends/gpu/codegen/emitters/transforms/convert_xla_gpu_pure_call_ops.cc +++ b/xla/codegen/emitters/transforms/convert_pure_call_ops.cc @@ -17,14 +17,14 @@ limitations under the License. 
#include "mlir/Pass/Pass.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "xla/backends/gpu/codegen/emitters/ir/xla_gpu_ops.h" +#include "xla/codegen/emitters/ir/xla_ops.h" namespace xla { -namespace gpu { +namespace emitters { namespace { #define GEN_PASS_DEF_CONVERTPURECALLOPSPASS -#include "xla/backends/gpu/codegen/emitters/transforms/passes.h.inc" +#include "xla/codegen/emitters/transforms/passes.h.inc" struct RewriteCall : mlir::OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -57,5 +57,5 @@ std::unique_ptr<::mlir::Pass> CreateConvertPureCallOpsPass() { return std::make_unique(); } -} // namespace gpu +} // namespace emitters } // namespace xla diff --git a/xla/backends/gpu/codegen/emitters/transforms/erase_dead_functions.cc b/xla/codegen/emitters/transforms/erase_dead_functions.cc similarity index 92% rename from xla/backends/gpu/codegen/emitters/transforms/erase_dead_functions.cc rename to xla/codegen/emitters/transforms/erase_dead_functions.cc index 3b20ec28a7b82..9170fdf6b4a5d 100644 --- a/xla/backends/gpu/codegen/emitters/transforms/erase_dead_functions.cc +++ b/xla/codegen/emitters/transforms/erase_dead_functions.cc @@ -21,18 +21,18 @@ limitations under the License. #include "mlir/Interfaces/CallInterfaces.h" #include "mlir/Pass/Pass.h" #include "mlir/Support/LLVM.h" -#include "xla/backends/gpu/codegen/emitters/ir/xla_gpu_ops.h" +#include "xla/codegen/emitters/ir/xla_ops.h" namespace xla { -namespace gpu { +namespace emitters { #define GEN_PASS_DEF_ERASEDEADFUNCTIONSPASS -#include "xla/backends/gpu/codegen/emitters/transforms/passes.h.inc" +#include "xla/codegen/emitters/transforms/passes.h.inc" namespace { struct CallInfo { - PureCallOp call; + xla::PureCallOp call; int count; }; @@ -82,5 +82,5 @@ CreateEraseDeadFunctionsPass() { return std::make_unique(); } -} // namespace gpu +} // namespace emitters } // namespace xla diff --git a/xla/backends/gpu/codegen/emitters/transforms/lower_xla_gpu_to_scf.cc b/xla/codegen/emitters/transforms/lower_xla_to_scf.cc similarity index 89% rename from xla/backends/gpu/codegen/emitters/transforms/lower_xla_gpu_to_scf.cc rename to xla/codegen/emitters/transforms/lower_xla_to_scf.cc index 5cd33de519a97..06d98d3d82b9b 100644 --- a/xla/backends/gpu/codegen/emitters/transforms/lower_xla_gpu_to_scf.cc +++ b/xla/codegen/emitters/transforms/lower_xla_to_scf.cc @@ -43,19 +43,20 @@ limitations under the License. 
#include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "xla/backends/gpu/codegen/emitters/ir/xla_gpu_ops.h" -#include "xla/backends/gpu/codegen/emitters/transforms/passes.h" #include "xla/codegen/emitters/elemental_hlo_to_mlir.h" +#include "xla/codegen/emitters/ir/xla_ops.h" +#include "xla/codegen/emitters/transforms/passes.h" #include "xla/hlo/analysis/indexing_map.h" #include "xla/service/gpu/ir_emission_utils.h" #include "xla/util.h" namespace xla { -namespace gpu { +namespace emitters { namespace { -#define GEN_PASS_DEF_LOWERXLAGPUTOSCFPASS -#define GEN_PASS_DEF_LOWERXLAGPULOOPSTOSCFPASS -#include "xla/backends/gpu/codegen/emitters/transforms/passes.h.inc" +#define GEN_PASS_DEF_LOWERXLATOSCFPASS +#define GEN_PASS_DEF_LOWERXLALOOPSTOSCFPASS +#include "xla/codegen/emitters/transforms/passes.h.inc" using mlir::ImplicitLocOpBuilder; using mlir::Location; @@ -68,7 +69,7 @@ using mlir::scf::IfOp; struct RewritePredicatedInsert : mlir::OpRewritePattern { RewritePredicatedInsert(mlir::MLIRContext* context, - const LowerXlaGpuToScfPassOptions& options) + const LowerXlaToScfPassOptions& options) : OpRewritePattern(context) {} mlir::LogicalResult matchAndRewrite( @@ -90,7 +91,7 @@ struct RewritePredicatedInsert : mlir::OpRewritePattern { struct RewritePredicatedExtract : mlir::OpRewritePattern { RewritePredicatedExtract(mlir::MLIRContext* context, - const LowerXlaGpuToScfPassOptions& options) + const LowerXlaToScfPassOptions& options) : OpRewritePattern(context) {} mlir::LogicalResult matchAndRewrite( @@ -110,15 +111,15 @@ struct RewritePredicatedExtract : mlir::OpRewritePattern { } }; -struct RewriteShuffleReduce : mlir::OpRewritePattern { +struct RewriteShuffleReduce : mlir::OpRewritePattern { const int64_t warp_size; RewriteShuffleReduce(mlir::MLIRContext* context, - const LowerXlaGpuToScfPassOptions& options) + const LowerXlaToScfPassOptions& options) : OpRewritePattern(context), warp_size(options.warp_size) {} mlir::LogicalResult matchAndRewrite( - ShuffleReduceOp op, mlir::PatternRewriter& rewriter) const override { + gpu::ShuffleReduceOp op, mlir::PatternRewriter& rewriter) const override { int max_distance = mlir::cast(op->getAttr("max_distance")).getInt(); // TODO(jreiffers): Do this in a verifier. 
@@ -203,7 +204,7 @@ struct RewriteShuffleReduce : mlir::OpRewritePattern { } }; -struct RewriteXlaGpuLoop : mlir::OpRewritePattern { +struct RewriteXlaLoop : mlir::OpRewritePattern { using OpRewritePattern::OpRewritePattern; mlir::LogicalResult matchAndRewrite( @@ -251,7 +252,8 @@ struct RewriteXlaGpuLoop : mlir::OpRewritePattern { } }; -mlir::VectorType getThreadLevelVectorType(IndexedVectorType indexed_vector) { +mlir::VectorType getThreadLevelVectorType( + gpu::IndexedVectorType indexed_vector) { auto data_type = indexed_vector.getElementType(); SmallVector vector_dims; if (auto complex = mlir::dyn_cast(data_type)) { @@ -265,13 +267,13 @@ mlir::VectorType getThreadLevelVectorType(IndexedVectorType indexed_vector) { return mlir::VectorType::get(vector_dims, data_type); } -struct RewriteMaterialize : mlir::OpRewritePattern { +struct RewriteMaterialize : mlir::OpRewritePattern { RewriteMaterialize(mlir::MLIRContext* context, - const LowerXlaGpuToScfPassOptions& options) + const LowerXlaToScfPassOptions& options) : OpRewritePattern(context) {} mlir::LogicalResult matchAndRewrite( - MaterializeOp op, mlir::PatternRewriter& rewriter) const override { + gpu::MaterializeOp op, mlir::PatternRewriter& rewriter) const override { ImplicitLocOpBuilder b(op.getLoc(), rewriter); auto i0 = b.create(0); auto i1 = b.create(1); @@ -324,13 +326,13 @@ struct RewriteMaterialize : mlir::OpRewritePattern { } }; -struct RewriteInsert : mlir::OpRewritePattern { +struct RewriteInsert : mlir::OpRewritePattern { RewriteInsert(mlir::MLIRContext* context, - const LowerXlaGpuToScfPassOptions& options) + const LowerXlaToScfPassOptions& options) : OpRewritePattern(context) {} mlir::LogicalResult matchAndRewrite( - InsertOp op, mlir::PatternRewriter& rewriter) const override { + gpu::InsertOp op, mlir::PatternRewriter& rewriter) const override { ImplicitLocOpBuilder b(op.getLoc(), rewriter); auto i0 = b.create(0); auto i1 = b.create(1); @@ -376,10 +378,10 @@ struct RewriteInsert : mlir::OpRewritePattern { } }; -class LowerXlaGpuToScfPass - : public impl::LowerXlaGpuToScfPassBase { +class LowerXlaToScfPass + : public impl::LowerXlaToScfPassBase { public: - explicit LowerXlaGpuToScfPass(const LowerXlaGpuToScfPassOptions& options) + explicit LowerXlaToScfPass(const LowerXlaToScfPassOptions& options) : options_(options) {} void runOnOperation() override { @@ -395,16 +397,16 @@ class LowerXlaGpuToScfPass } private: - const LowerXlaGpuToScfPassOptions options_; + const LowerXlaToScfPassOptions options_; }; -class LowerXlaGpuLoopsToScfPass - : public impl::LowerXlaGpuLoopsToScfPassBase { +class LowerXlaLoopsToScfPass + : public impl::LowerXlaLoopsToScfPassBase { public: void runOnOperation() override { auto* ctx = &getContext(); mlir::RewritePatternSet patterns(ctx); - patterns.add(ctx); + patterns.add(ctx); if (mlir::failed( mlir::applyPatternsGreedily(getOperation(), std::move(patterns)))) { signalPassFailure(); @@ -414,16 +416,15 @@ class LowerXlaGpuLoopsToScfPass } // namespace -std::unique_ptr<::mlir::Pass> CreateLowerXlaGpuToScfPass( - const int64_t warp_size) { - LowerXlaGpuToScfPassOptions options; +std::unique_ptr<::mlir::Pass> CreateLowerXlaToScfPass(const int64_t warp_size) { + LowerXlaToScfPassOptions options; options.warp_size = warp_size; - return std::make_unique(options); + return std::make_unique(options); } -std::unique_ptr<::mlir::Pass> CreateLowerXlaGpuLoopsToScfPass() { - return std::make_unique(); +std::unique_ptr<::mlir::Pass> CreateLowerXlaLoopsToScfPass() { + return std::make_unique(); } -} // namespace 
gpu
+}  // namespace emitters
 }  // namespace xla
diff --git a/xla/backends/gpu/codegen/emitters/transforms/merge_pointers_to_same_slice.cc b/xla/codegen/emitters/transforms/merge_pointers_to_same_slice.cc
similarity index 97%
rename from xla/backends/gpu/codegen/emitters/transforms/merge_pointers_to_same_slice.cc
rename to xla/codegen/emitters/transforms/merge_pointers_to_same_slice.cc
index d4eace906ca14..aa7375ad239c9 100644
--- a/xla/backends/gpu/codegen/emitters/transforms/merge_pointers_to_same_slice.cc
+++ b/xla/codegen/emitters/transforms/merge_pointers_to_same_slice.cc
@@ -27,10 +27,10 @@ limitations under the License.
 #include "mlir/Support/LLVM.h"
 
 namespace xla {
-namespace gpu {
+namespace emitters {
 
 #define GEN_PASS_DEF_MERGEPOINTERSTOSAMESLICEPASS
-#include "xla/backends/gpu/codegen/emitters/transforms/passes.h.inc"
+#include "xla/codegen/emitters/transforms/passes.h.inc"
 
 namespace {
 
@@ -113,5 +113,5 @@ CreateMergePointersToSameSlicePass() {
   return std::make_unique<MergePointersToSameSlicePass>();
 }
 
-}  // namespace gpu
+}  // namespace emitters
 }  // namespace xla
diff --git a/xla/codegen/emitters/transforms/passes.h b/xla/codegen/emitters/transforms/passes.h
index 51c469041048c..e8ff284419403 100644
--- a/xla/codegen/emitters/transforms/passes.h
+++ b/xla/codegen/emitters/transforms/passes.h
@@ -15,6 +15,7 @@ limitations under the License.
 #ifndef XLA_CODEGEN_EMITTERS_TRANSFORMS_PASSES_H_
 #define XLA_CODEGEN_EMITTERS_TRANSFORMS_PASSES_H_
 
+#include <cstdint>
 #include <memory>
 #include <string>
 
@@ -27,6 +28,8 @@ namespace emitters {
 #define GEN_PASS_DECL
 #include "xla/codegen/emitters/transforms/passes.h.inc"
 
+std::unique_ptr<mlir::Pass> CreateConvertPureCallOpsPass();
+std::unique_ptr<mlir::Pass> CreateEraseDeadFunctionsPass();
 std::unique_ptr<mlir::Pass> CreateExpandFloatOpsPass();
 std::unique_ptr<mlir::Pass> CreateFlattenTensorsPass();
 std::unique_ptr<mlir::Pass> CreateLowerTensorsPass(
@@ -39,6 +42,14 @@ std::unique_ptr<mlir::Pass> CreateLowerToLLVMPass(
     const std::string& gpu_device_info = "");
 std::unique_ptr<mlir::Pass> CreateLowerToLLVMPass(
     const stream_executor::DeviceDescription& device_description);
+std::unique_ptr<mlir::Pass> CreateLowerXlaToScfPass(int64_t warp_size = 32);
+std::unique_ptr<mlir::Pass> CreateLowerXlaLoopsToScfPass();
+std::unique_ptr<mlir::Pass> CreateMergePointersToSameSlicePass();
+std::unique_ptr<mlir::Pass> CreatePeelLoopsPass();
+std::unique_ptr<mlir::Pass> CreatePropagateSliceIndicesPass();
+std::unique_ptr<mlir::Pass> CreateSimplifyAffinePass();
+std::unique_ptr<mlir::Pass> CreateSimplifyArithPass();
+std::unique_ptr<mlir::Pass> CreateUnswitchLoopsPass();
 
 #define GEN_PASS_REGISTRATION
 #include "xla/codegen/emitters/transforms/passes.h.inc"
diff --git a/xla/codegen/emitters/transforms/passes.td b/xla/codegen/emitters/transforms/passes.td
index 71d5e627ee3c8..488b3f457e8ef 100644
--- a/xla/codegen/emitters/transforms/passes.td
+++ b/xla/codegen/emitters/transforms/passes.td
@@ -18,6 +18,36 @@ limitations under the License.
 
 include "mlir/Pass/PassBase.td"
 
+def ConvertPureCallOpsPass
+    : Pass<"xla-convert-pure-call-ops", "mlir::func::FuncOp"> {
+  let summary = "Converts xla.pure_call to func.call";
+  let description = [{
+    We use xla.pure_call ops for calls to enable CSE and other
+    transformations (e.g. LICM). This pass rewrites our custom ops to standard
+    ops.
+  }];
+  let dependentDialects = [
+    "mlir::func::FuncDialect",
+    "xla::XlaDialect"
+  ];
+  let constructor = "CreateConvertPureCallOpsPass()";
+}
+
+def EraseDeadFunctionsPass : Pass<"xla-erase-dead-functions", "mlir::ModuleOp"> {
+  let summary = "Deletes unused functions";
+
+  let description = [{
+    Deletes functions that are not called.
+ }]; + + let dependentDialects = [ + "mlir::func::FuncDialect", + "xla::XlaDialect", + ]; + + let constructor = "CreateEraseDeadFunctionsPass()"; +} + def FlattenTensorsPass : Pass<"xla-flatten-tensors", "mlir::ModuleOp"> { let summary = "Flatten tensors."; @@ -34,6 +64,24 @@ def FlattenTensorsPass : Pass<"xla-flatten-tensors", "mlir::ModuleOp"> { let constructor = "CreateFlattenTensorsPass()"; } +def ExpandFloatOpsPass : Pass<"xla-expand-float-ops", "mlir::ModuleOp"> { + let summary = "Expands float ops that are not natively supported."; + + let description = [{ + Not all float ops are natively supported, either because they don't exist + in hardware or they are too inaccurate. + + This pass replaces these ops with alternative implementations. + }]; + + let dependentDialects = [ + "mlir::arith::ArithDialect", "mlir::math::MathDialect", + "mlir::mhlo::MhloDialect" + ]; + + let constructor = "CreateExpandFloatOpsPass()"; +} + def LowerTensorsPass : Pass<"xla-lower-tensors", "mlir::ModuleOp"> { let summary = "Lowers tensors to llvm pointers and loads/stores."; @@ -61,24 +109,6 @@ def LowerTensorsPass : Pass<"xla-lower-tensors", "mlir::ModuleOp"> { let constructor = "CreateLowerTensorsPass()"; } -def ExpandFloatOpsPass : Pass<"xla-expand-float-ops", "mlir::ModuleOp"> { - let summary = "Expands float ops that are not natively supported."; - - let description = [{ - Not all float ops are natively supported, either because they don't exist - in hardware or they are too inaccurate. - - This pass replaces these ops with alternative implementations. - }]; - - let dependentDialects = [ - "mlir::arith::ArithDialect", "mlir::math::MathDialect", - "mlir::mhlo::MhloDialect" - ]; - - let constructor = "CreateExpandFloatOpsPass()"; -} - def LowerToLLVMPass : Pass<"xla-lower-to-llvm", "mlir::ModuleOp"> { let summary = "Lowers to LLVM."; @@ -102,4 +132,166 @@ def LowerToLLVMPass : let constructor = "CreateLowerToLLVMPass()"; } +def LowerXlaToScfPass : + Pass<"xla-lower-xla-to-scf", "mlir::func::FuncOp"> { + let summary = "Lowers xla to SCF."; + + let dependentDialects = [ + "mlir::gpu::GPUDialect", "mlir::LLVM::LLVMDialect", "mlir::scf::SCFDialect", + "mlir::tensor::TensorDialect", "xla::gpu::XlaGpuDialect", + "xla::XlaDialect", "mlir::vector::VectorDialect", + ]; + + let options = [ + Option<"warp_size", "warp_size", "int64_t", /*default=*/"32", "Warp size.">, + ]; + let constructor = "CreateLowerXlaToScfPass()"; +} + +def LowerXlaLoopsToScfPass : Pass< + "xla-lower-xla-loops-to-scf", "mlir::func::FuncOp"> { + let summary = "Lowers xla.loop to SCF."; + + let description = [{ + This pass is separate from xla-lower-xla-to-scf because + xla-lower-xla-to-scf, inliner, peeling and xla-lower-xla-loops-to-scf + have to run in that order. + }]; + + let dependentDialects = [ + "mlir::scf::SCFDialect", + "mlir::tensor::TensorDialect", + "xla::gpu::XlaGpuDialect", + "xla::XlaDialect", + ]; + + let constructor = "CreateLowerXlaLoopsToScfPass()"; +} + +def MergePointersToSameSlicePass : + Pass<"xla-merge-pointers", "mlir::ModuleOp"> { + let summary = "Merges pointers that share slices."; + + let description = [{ + When a function has multiple pointer arguments with the same slice index, + merges them. 
+  }];
+
+  let dependentDialects = [
+    "mlir::func::FuncDialect"
+  ];
+
+  let constructor = "CreateMergePointersToSameSlicePass()";
+}
+
+def PeelLoopsPass : Pass<"xla-peel-loops", "mlir::func::FuncOp"> {
+  let summary = "Peels xla.loop.";
+  let description = [{
+    Attempts to split each loop dimension [0, NUM_ITERATIONS)
+    as [0, NUM_ITERATIONS - 1) and [NUM_ITERATIONS - 1, NUM_ITERATIONS)
+    if it removes a constraint.
+  }];
+  let dependentDialects = ["xla::XlaDialect"];
+  let constructor = "CreatePeelLoopsPass()";
+}
+
+def PropagateSliceIndicesPass :
+   Pass<"xla-propagate-slice-indices", "mlir::ModuleOp"> {
+  let summary = "Propagates slice indices from the entry function to all callees.";
+
+  let description = [{
+    Propagates xla.slice_index attributes from the function with the xla.entry
+    attribute to all other functions.
+  }];
+
+  let dependentDialects = [
+    "mlir::func::FuncDialect"
+  ];
+
+  let constructor = "CreatePropagateSliceIndicesPass()";
+}
+
+def SimplifyAffinePass : Pass<"xla-simplify-affine", "mlir::ModuleOp"> {
+  let summary = "Simplifies affine.apply using XLA's range-aware simplifier.";
+
+  let description = [{
+    The standard affine canonicalizer cannot simplify all expressions, since
+    it is unaware of range information. This pass uses `xla.range` attributes
+    on arguments and ops for simplification. It also lowers floordiv and mod
+    to simpler expressions than lower-affine. This pass only works for
+    expressions for which we can prove the LHS of mod and div is nonnegative.
+  }];
+
+  let dependentDialects = [
+    "mlir::affine::AffineDialect", "mlir::func::FuncDialect",
+    "mlir::scf::SCFDialect",
+  ];
+
+  let constructor = "CreateSimplifyAffinePass()";
+}
+
+def SimplifyArithPass : Pass<"xla-simplify-arith", "mlir::func::FuncOp"> {
+  let summary = "Simplifies arith using XLA's range-aware simplifier.";
+
+  let description = [{
+    We often emit bounds checks that are statically known to be satisfied.
+    This pass removes them.
+  }];
+
+  let dependentDialects = [
+    "mlir::arith::ArithDialect",
+    "mlir::func::FuncDialect",
+  ];
+
+  let constructor = "CreateSimplifyArithPass()";
+}
+
+def UnswitchLoopsPass :
+   Pass<"xla-unswitch-loops", "mlir::func::FuncOp"> {
+  let summary = "Swaps scf.if and scf.for.";
+
+  let description = [{
+    Extracts `scf.if` ops with conditions that are independent of the loop
+    variable from `scf.for` by doing the following rewrite:
+
+    Before:
+
+    %cond = some_cond() : i1
+    %results = scf.for {
+      %some_val = scf.if %cond {
+      } else {
+      }
+      scf.yield %some_val
+    }
+
+    After:
+
+    %cond = some_cond() : i1
+    %results = scf.if %cond {
+      %results = scf.for {
+        %some_val = scf.if %true {
+        } else {
+        }
+      }
+      yield %results
+    } else {
+      %results = scf.for {
+        %some_val = scf.if %false {
+        } else {
+        }
+      }
+      yield %results
+    }
+
+    This only triggers if there is a single `scf.if` op in the loop body (and
+    nothing else).
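+
+    A note on the intent (an assumption about the surrounding pipeline, not
+    something this pass enforces): after unswitching, each cloned loop
+    guards its body with a constant condition (%true or %false), so the
+    canonicalizer can fold the scf.if away and leave two branch-free loops,
+    at the cost of duplicating the loop body.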
+  }];
+
+  let dependentDialects = [
+    "mlir::func::FuncDialect", "mlir::scf::SCFDialect"
+  ];
+
+  let constructor = "CreateUnswitchLoopsPass()";
+}
+
 #endif  // XLA_CODEGEN_EMITTERS_TRANSFORMS_PASSES_TD_
diff --git a/xla/backends/gpu/codegen/emitters/transforms/peel_loops.cc b/xla/codegen/emitters/transforms/peel_loops.cc
similarity index 96%
rename from xla/backends/gpu/codegen/emitters/transforms/peel_loops.cc
rename to xla/codegen/emitters/transforms/peel_loops.cc
index de8810182b787..a7cecd6a6cd23 100644
--- a/xla/backends/gpu/codegen/emitters/transforms/peel_loops.cc
+++ b/xla/codegen/emitters/transforms/peel_loops.cc
@@ -32,16 +32,16 @@ limitations under the License.
 #include "mlir/Support/LLVM.h"
 #include "mlir/Support/LogicalResult.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-#include "xla/backends/gpu/codegen/emitters/ir/xla_gpu_ops.h"
+#include "xla/codegen/emitters/ir/xla_ops.h"
 #include "xla/hlo/analysis/indexing_map.h"
 #include "xla/hlo/analysis/indexing_map_serialization.h"
 
 namespace xla {
-namespace gpu {
+namespace emitters {
 namespace {
 
 #define GEN_PASS_DEF_PEELLOOPSPASS
-#include "xla/backends/gpu/codegen/emitters/transforms/passes.h.inc"
+#include "xla/codegen/emitters/transforms/passes.h.inc"
 
 using mlir::Location;
 using mlir::OpBuilder;
@@ -63,7 +63,7 @@ struct PeelLoop : public OpRewritePattern<LoopOp> {
     auto indexing_map = loop_op.getIndexingMap();
     // TODO(b/358274367): Remove the simplify call once we have `is_simplified`
    // field and a canonicalization pattern to simplify indexing map in
-    // xla_gpu.loop.
+    // xla.loop.
     indexing_map.Simplify();
     SmallVector<IndexingMap> indexing_maps{indexing_map};
     for (int sym_index = indexing_map.GetSymbolCount() - 1;
@@ -146,5 +146,5 @@ std::unique_ptr<mlir::Pass> CreatePeelLoopsPass() {
   return std::make_unique<PeelLoopsPass>();
 }
 
-}  // namespace gpu
+}  // namespace emitters
 }  // namespace xla
diff --git a/xla/backends/gpu/codegen/emitters/transforms/propagate_slice_indices.cc b/xla/codegen/emitters/transforms/propagate_slice_indices.cc
similarity index 92%
rename from xla/backends/gpu/codegen/emitters/transforms/propagate_slice_indices.cc
rename to xla/codegen/emitters/transforms/propagate_slice_indices.cc
index dde5b092b812f..f742dbea9acc9 100644
--- a/xla/backends/gpu/codegen/emitters/transforms/propagate_slice_indices.cc
+++ b/xla/codegen/emitters/transforms/propagate_slice_indices.cc
@@ -19,13 +19,13 @@ limitations under the License.
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Support/LLVM.h"
-#include "xla/backends/gpu/codegen/emitters/transforms/passes.h"
+#include "xla/codegen/emitters/transforms/passes.h"
 
 namespace xla {
-namespace gpu {
+namespace emitters {
 
 #define GEN_PASS_DEF_PROPAGATESLICEINDICESPASS
-#include "xla/backends/gpu/codegen/emitters/transforms/passes.h.inc"
+#include "xla/codegen/emitters/transforms/passes.h.inc"
 
 namespace {
@@ -76,5 +76,5 @@ std::unique_ptr<mlir::Pass> CreatePropagateSliceIndicesPass() {
   return std::make_unique<PropagateSliceIndicesPass>();
 }
 
-}  // namespace gpu
+}  // namespace emitters
 }  // namespace xla
diff --git a/xla/backends/gpu/codegen/emitters/transforms/simplify_affine.cc b/xla/codegen/emitters/transforms/simplify_affine.cc
similarity index 97%
rename from xla/backends/gpu/codegen/emitters/transforms/simplify_affine.cc
rename to xla/codegen/emitters/transforms/simplify_affine.cc
index ad54c074dd8c3..cd6b19cc4621f 100644
--- a/xla/backends/gpu/codegen/emitters/transforms/simplify_affine.cc
+++ b/xla/codegen/emitters/transforms/simplify_affine.cc
@@ -46,13 +46,12 @@ limitations under the License.
#include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "xla/backends/gpu/codegen/emitters/ir/xla_gpu_ops.h" -#include "xla/backends/gpu/codegen/emitters/transforms/passes.h" #include "xla/codegen/emitters/ir/xla_ops.h" +#include "xla/codegen/emitters/transforms/passes.h" #include "xla/hlo/analysis/indexing_map.h" namespace xla { -namespace gpu { +namespace emitters { namespace { using mlir::AffineBinaryOpExpr; @@ -78,7 +77,7 @@ using mlir::affine::AffineApplyOp; namespace arith = mlir::arith; #define GEN_PASS_DEF_SIMPLIFYAFFINEPASS -#include "xla/backends/gpu/codegen/emitters/transforms/passes.h.inc" +#include "xla/codegen/emitters/transforms/passes.h.inc" int Distance(ImplicitLocOpBuilder& builder, Value a) { auto* block = builder.getInsertionBlock(); @@ -326,5 +325,5 @@ std::unique_ptr CreateSimplifyAffinePass() { return std::make_unique(); } -} // namespace gpu +} // namespace emitters } // namespace xla diff --git a/xla/backends/gpu/codegen/emitters/transforms/simplify_arith.cc b/xla/codegen/emitters/transforms/simplify_arith.cc similarity index 98% rename from xla/backends/gpu/codegen/emitters/transforms/simplify_arith.cc rename to xla/codegen/emitters/transforms/simplify_arith.cc index 094b091d33315..9e2880072b2b9 100644 --- a/xla/backends/gpu/codegen/emitters/transforms/simplify_arith.cc +++ b/xla/codegen/emitters/transforms/simplify_arith.cc @@ -32,16 +32,16 @@ limitations under the License. #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "xla/backends/gpu/codegen/emitters/ir/xla_gpu_ops.h" // IWYU pragma: keep -#include "xla/backends/gpu/codegen/emitters/transforms/passes.h" +#include "xla/codegen/emitters/ir/xla_ops.h" // IWYU pragma: keep +#include "xla/codegen/emitters/transforms/passes.h" #include "xla/hlo/analysis/indexing_map.h" namespace xla { -namespace gpu { +namespace emitters { namespace { #define GEN_PASS_DEF_SIMPLIFYARITHPASS -#include "xla/backends/gpu/codegen/emitters/transforms/passes.h.inc" +#include "xla/codegen/emitters/transforms/passes.h.inc" using mlir::LogicalResult; using mlir::OpRewritePattern; @@ -383,5 +383,5 @@ std::unique_ptr CreateSimplifyArithPass() { return std::make_unique(); } -} // namespace gpu +} // namespace emitters } // namespace xla diff --git a/xla/backends/gpu/codegen/emitters/transforms/tests/convert_xla_gpu_pure_calls.mlir b/xla/codegen/emitters/transforms/tests/convert_pure_calls_ops.mlir similarity index 90% rename from xla/backends/gpu/codegen/emitters/transforms/tests/convert_xla_gpu_pure_calls.mlir rename to xla/codegen/emitters/transforms/tests/convert_pure_calls_ops.mlir index 356f0fab167ce..31cf39d2f75a7 100644 --- a/xla/backends/gpu/codegen/emitters/transforms/tests/convert_xla_gpu_pure_calls.mlir +++ b/xla/codegen/emitters/transforms/tests/convert_pure_calls_ops.mlir @@ -1,5 +1,5 @@ -// RUN: emitters_opt %s -xla-gpu-convert-pure-call-ops | FileCheck %s -// RUN: emitters_opt %s -cse -xla-gpu-convert-pure-call-ops \ +// RUN: emitters_opt %s -xla-convert-pure-call-ops | FileCheck %s +// RUN: emitters_opt %s -cse -xla-convert-pure-call-ops \ // RUN: | FileCheck %s -check-prefixes=CHECK-CSE func.func private @callee() -> f32 { diff --git a/xla/backends/gpu/codegen/emitters/transforms/tests/lower_xla_gpu_loops_to_scf.mlir b/xla/codegen/emitters/transforms/tests/lower_xla_loops_to_scf.mlir similarity index 98% rename from 
rename to xla/codegen/emitters/transforms/tests/lower_xla_loops_to_scf.mlir
index 8829e74a675f3..798e72fc1e301 100644
--- a/xla/backends/gpu/codegen/emitters/transforms/tests/lower_xla_gpu_loops_to_scf.mlir
+++ b/xla/codegen/emitters/transforms/tests/lower_xla_loops_to_scf.mlir
@@ -1,4 +1,4 @@
-// RUN: emitters_opt %s -xla-gpu-lower-xla-gpu-loops-to-scf \
+// RUN: emitters_opt %s -xla-lower-xla-loops-to-scf \
 // RUN:   --split-input-file | FileCheck %s
 
 #map = #xla.indexing_map<"(d0)[s0, s1] -> (s0 + 1, s1 - 1),"
diff --git a/xla/backends/gpu/codegen/emitters/transforms/tests/lower_xla_gpu_to_scf.mlir b/xla/codegen/emitters/transforms/tests/lower_xla_to_scf.mlir
similarity index 99%
rename from xla/backends/gpu/codegen/emitters/transforms/tests/lower_xla_gpu_to_scf.mlir
rename to xla/codegen/emitters/transforms/tests/lower_xla_to_scf.mlir
index 12e395ff58ae4..b67c99d3b60b5 100644
--- a/xla/backends/gpu/codegen/emitters/transforms/tests/lower_xla_gpu_to_scf.mlir
+++ b/xla/codegen/emitters/transforms/tests/lower_xla_to_scf.mlir
@@ -1,4 +1,4 @@
-// RUN: emitters_opt %s -xla-gpu-lower-xla-gpu-to-scf --split-input-file \
+// RUN: emitters_opt %s -xla-lower-xla-to-scf --split-input-file \
 // RUN:   | FileCheck %s
 
 func.func @combiner(%a: f32, %b: i32, %c: f32, %d: i32) -> (f32, i32) {
diff --git a/xla/backends/gpu/codegen/emitters/transforms/tests/merge_pointers_to_same_slice.mlir b/xla/codegen/emitters/transforms/tests/merge_pointers_to_same_slice.mlir
similarity index 98%
rename from xla/backends/gpu/codegen/emitters/transforms/tests/merge_pointers_to_same_slice.mlir
rename to xla/codegen/emitters/transforms/tests/merge_pointers_to_same_slice.mlir
index b9f77437d8cef..2c04922a9c151 100644
--- a/xla/backends/gpu/codegen/emitters/transforms/tests/merge_pointers_to_same_slice.mlir
+++ b/xla/codegen/emitters/transforms/tests/merge_pointers_to_same_slice.mlir
@@ -1,4 +1,4 @@
-// RUN: emitters_opt %s -split-input-file -xla-lower-tensors -xla-gpu-merge-pointers | FileCheck %s
+// RUN: emitters_opt %s -split-input-file -xla-lower-tensors -xla-merge-pointers | FileCheck %s
 
 module {
   func.func private @tensorargs(%arg0: tensor<43xf32> {xla.slice_index = 0},
diff --git a/xla/backends/gpu/codegen/emitters/transforms/tests/peel_loops.mlir b/xla/codegen/emitters/transforms/tests/peel_loops.mlir
similarity index 97%
rename from xla/backends/gpu/codegen/emitters/transforms/tests/peel_loops.mlir
rename to xla/codegen/emitters/transforms/tests/peel_loops.mlir
index c799f409c7a47..9eeb57409aeab 100644
--- a/xla/backends/gpu/codegen/emitters/transforms/tests/peel_loops.mlir
+++ b/xla/codegen/emitters/transforms/tests/peel_loops.mlir
@@ -1,4 +1,4 @@
-// RUN: emitters_opt -split-input-file %s -xla-gpu-peel-loops \
+// RUN: emitters_opt -split-input-file %s -xla-peel-loops \
 // RUN:   | FileCheck %s
 
 #map = #xla.indexing_map<"(d0)[s0, s1] -> (s0, s1), domain:"
diff --git a/xla/backends/gpu/codegen/emitters/transforms/tests/propagate_slice_indices.mlir b/xla/codegen/emitters/transforms/tests/propagate_slice_indices.mlir
similarity index 94%
rename from xla/backends/gpu/codegen/emitters/transforms/tests/propagate_slice_indices.mlir
rename to xla/codegen/emitters/transforms/tests/propagate_slice_indices.mlir
index 9bf77fdc2e291..637427c856929 100644
--- a/xla/backends/gpu/codegen/emitters/transforms/tests/propagate_slice_indices.mlir
+++ b/xla/codegen/emitters/transforms/tests/propagate_slice_indices.mlir
@@ -1,4 +1,4 @@
-// RUN: emitters_opt %s -split-input-file -xla-gpu-propagate-slice-indices | FileCheck %s
+// RUN: emitters_opt %s -split-input-file -xla-propagate-slice-indices | FileCheck %s
 
 module {
   func.func private @add(%arg0: f32, %arg1: f32) -> f32 {
diff --git a/xla/backends/gpu/codegen/emitters/transforms/tests/simplify_affine.mlir b/xla/codegen/emitters/transforms/tests/simplify_affine.mlir
similarity index 98%
rename from xla/backends/gpu/codegen/emitters/transforms/tests/simplify_affine.mlir
rename to xla/codegen/emitters/transforms/tests/simplify_affine.mlir
index 363de4dba428d..aa6c54b490190 100644
--- a/xla/backends/gpu/codegen/emitters/transforms/tests/simplify_affine.mlir
+++ b/xla/codegen/emitters/transforms/tests/simplify_affine.mlir
@@ -1,4 +1,4 @@
-// RUN: emitters_opt --allow-unregistered-dialect %s -split-input-file -xla-gpu-simplify-affine | FileCheck %s
+// RUN: emitters_opt --allow-unregistered-dialect %s -split-input-file -xla-simplify-affine | FileCheck %s
 
 func.func @op_and_for_ranges(%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: !llvm.ptr) {
   %c0 = arith.constant 0 : index
@@ -146,4 +146,4 @@ func.func @order_summands(%arg1: index) {
 // CHECK: arith.divui
 // CHECK: arith.addi
 // CHECK: arith.muli %[[ARG3]]
-// CHECK: arith.addi %5, %6 : index
\ No newline at end of file
+// CHECK: arith.addi %5, %6 : index
diff --git a/xla/backends/gpu/codegen/emitters/transforms/tests/simplify_arith.mlir b/xla/codegen/emitters/transforms/tests/simplify_arith.mlir
similarity index 99%
rename from xla/backends/gpu/codegen/emitters/transforms/tests/simplify_arith.mlir
rename to xla/codegen/emitters/transforms/tests/simplify_arith.mlir
index bdecdfea064f7..4420729c61adc 100644
--- a/xla/backends/gpu/codegen/emitters/transforms/tests/simplify_arith.mlir
+++ b/xla/codegen/emitters/transforms/tests/simplify_arith.mlir
@@ -1,4 +1,4 @@
-// RUN: emitters_opt %s -split-input-file -xla-gpu-simplify-arith -cse \
+// RUN: emitters_opt %s -split-input-file -xla-simplify-arith -cse \
 // RUN:   -canonicalize | FileCheck %s
 
 module {
diff --git a/xla/backends/gpu/codegen/emitters/transforms/tests/unswitch_loops.mlir b/xla/codegen/emitters/transforms/tests/unswitch_loops.mlir
similarity index 94%
rename from xla/backends/gpu/codegen/emitters/transforms/tests/unswitch_loops.mlir
rename to xla/codegen/emitters/transforms/tests/unswitch_loops.mlir
index 8220e8ed3d1e6..3f2ff8f70d890 100644
--- a/xla/backends/gpu/codegen/emitters/transforms/tests/unswitch_loops.mlir
+++ b/xla/codegen/emitters/transforms/tests/unswitch_loops.mlir
@@ -1,4 +1,4 @@
-// RUN: emitters_opt %s -split-input-file -xla-gpu-unswitch-loops | FileCheck %s
+// RUN: emitters_opt %s -split-input-file -xla-unswitch-loops | FileCheck %s
 
 module {
   func.func @unswitchable(
diff --git a/xla/backends/gpu/codegen/emitters/transforms/unswitch_loops.cc b/xla/codegen/emitters/transforms/unswitch_loops.cc
similarity index 97%
rename from xla/backends/gpu/codegen/emitters/transforms/unswitch_loops.cc
rename to xla/codegen/emitters/transforms/unswitch_loops.cc
index 40987e119c72f..47182b86f3735 100644
--- a/xla/backends/gpu/codegen/emitters/transforms/unswitch_loops.cc
+++ b/xla/codegen/emitters/transforms/unswitch_loops.cc
@@ -27,10 +27,10 @@ limitations under the License.
#include "mlir/Transforms/GreedyPatternRewriteDriver.h" namespace xla { -namespace gpu { +namespace emitters { #define GEN_PASS_DEF_UNSWITCHLOOPSPASS -#include "xla/backends/gpu/codegen/emitters/transforms/passes.h.inc" +#include "xla/codegen/emitters/transforms/passes.h.inc" namespace { @@ -102,5 +102,5 @@ CreateUnswitchLoopsPass() { return std::make_unique(); } -} // namespace gpu +} // namespace emitters } // namespace xla diff --git a/xla/tsl/concurrency/BUILD b/xla/tsl/concurrency/BUILD index 1ebdf1333bde3..2b505d8e07436 100644 --- a/xla/tsl/concurrency/BUILD +++ b/xla/tsl/concurrency/BUILD @@ -17,6 +17,7 @@ filegroup( "concurrent_vector.h", "ref_count.h", ], + compatible_with = get_compatible_with_portable(), ) filegroup( @@ -25,6 +26,7 @@ filegroup( "async_value.cc", "async_value_ref.cc", ], + compatible_with = get_compatible_with_portable(), ) cc_library(