Skip to content

Commit

Permalink
[xla:emitters] tag XLA, XLA:CPU and XLA:GPU dialects as non-prod-compatible
Browse files Browse the repository at this point in the history

This paves the way for XLA:CPU fusion emitters.

Note that the XLA:CPU backend is non-prod-compatible, whereas the XLA:GPU
backend is not. The CPU fusion emitters will depend on the XLA, XLA:CPU and
XLA:GPU dialects, and since the emitters' dependents in XLA:CPU are
non-prod-compatible, all three dialects have to be tagged as such as well.

XLA:CPU passes also have to be tagged. Crucially, XLA:GPU passes are not
used by any of the above dialects nor by XLA:CPU passes, so XLA:GPU
remains essentially untouched; we just tag the XLA:GPU dialect.

Some common libraries in xla/codegen/emitters are also tagged.

PiperOrigin-RevId: 721954339
  • Loading branch information
cota authored and Google-ML-Automation committed Feb 1, 2025
1 parent c1ef7cc commit 170e331
Show file tree
Hide file tree
Showing 38 changed files with 360 additions and 619 deletions.
6 changes: 6 additions & 0 deletions xla/backends/cpu/codegen/emitters/ir/BUILD
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library")
load("//tensorflow:tensorflow.google.bzl", "get_compatible_with_portable")
load("//xla/tsl/platform:rules_cc.bzl", "cc_library")

package(
# copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
Expand All @@ -16,6 +18,7 @@ package_group(
td_library(
name = "xla_cpu_td_files",
srcs = glob(["*.td"]),
compatible_with = get_compatible_with_portable(),
includes = ["."],
deps = [
"@llvm-project//mlir:BuiltinDialectTdFiles",
Expand All @@ -25,6 +28,7 @@ td_library(

gentbl_cc_library(
name = "xla_cpu_dialect_inc_gen",
compatible_with = get_compatible_with_portable(),
strip_include_prefix = ".",
tbl_outs = [
(
Expand All @@ -43,6 +47,7 @@ gentbl_cc_library(

gentbl_cc_library(
name = "xla_cpu_types_inc_gen",
compatible_with = get_compatible_with_portable(),
strip_include_prefix = ".",
tbl_outs = [
(
Expand All @@ -67,6 +72,7 @@ gentbl_cc_library(

gentbl_cc_library(
name = "xla_cpu_ops_inc_gen",
compatible_with = get_compatible_with_portable(),
strip_include_prefix = ".",
tbl_outs = [
(
Expand Down
3 changes: 3 additions & 0 deletions xla/backends/cpu/codegen/emitters/transforms/BUILD
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library")
load("//tensorflow:tensorflow.google.bzl", "get_compatible_with_portable")
load("//xla/tsl/platform:rules_cc.bzl", "cc_library")

package(
# copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
Expand All @@ -15,6 +17,7 @@ package_group(

gentbl_cc_library(
name = "passes_inc_gen",
compatible_with = get_compatible_with_portable(),
tbl_outs = [
(
[
Expand Down
22 changes: 11 additions & 11 deletions xla/backends/gpu/codegen/emitters/emitter_base.cc
Original file line number Diff line number Diff line change
Expand Up @@ -574,33 +574,33 @@ absl::Status EmitterBase::RunPassPipeline(
}

void AddXlaGpuOpsOptimizationPasses(mlir::OpPassManager& pm) {
pm.addNestedPass<FuncOp>(CreateSimplifyArithPass());
pm.addNestedPass<FuncOp>(emitters::CreateSimplifyArithPass());
pm.addPass(mlir::createCanonicalizerPass());
pm.addPass(mlir::createCSEPass());
pm.addPass(CreateEraseDeadFunctionsPass());
pm.addPass(emitters::CreateEraseDeadFunctionsPass());
pm.addPass(mlir::createCSEPass());
}

void AddLoopTransformationPasses(mlir::OpPassManager& pm,
const se::DeviceDescription& device) {
pm.addNestedPass<FuncOp>(
CreateLowerXlaGpuToScfPass(device.threads_per_warp()));
emitters::CreateLowerXlaToScfPass(device.threads_per_warp()));
pm.addNestedPass<FuncOp>(CreateFuseLoopsPass());
pm.addPass(mlir::createInlinerPass({}, [&](mlir::OpPassManager& pm) {
// CSE after inlining because inlining can introduce duplicates.
pm.addPass(mlir::createCSEPass());
}));
pm.addPass(mlir::createCanonicalizerPass());
pm.addPass(mlir::createCSEPass());
pm.addNestedPass<FuncOp>(CreatePeelLoopsPass());
pm.addNestedPass<FuncOp>(CreateLowerXlaGpuLoopsToScfPass());
pm.addNestedPass<FuncOp>(emitters::CreatePeelLoopsPass());
pm.addNestedPass<FuncOp>(emitters::CreateLowerXlaLoopsToScfPass());
pm.addPass(mlir::mhlo::createConvertToSignlessPass());
pm.addPass(CreatePropagateSliceIndicesPass());
pm.addPass(emitters::CreatePropagateSliceIndicesPass());
pm.addPass(emitters::CreateFlattenTensorsPass());
// We need LICM before unswitching loops, because our loop unswitcher only
// detects for loops with a single if inside them.
pm.addPass(mlir::createLoopInvariantCodeMotionPass());
pm.addNestedPass<FuncOp>(CreateUnswitchLoopsPass());
pm.addNestedPass<FuncOp>(emitters::CreateUnswitchLoopsPass());
// We need LICM again after unswitching, because that can introduce new
// opportunities for LICM. This would not be necessary if LICM also moved
// instructions over ifs.
Expand All @@ -613,17 +613,17 @@ void AddLoopTransformationPasses(mlir::OpPassManager& pm,

void AddLoweringPasses(mlir::OpPassManager& pm,
const se::DeviceDescription& device) {
pm.addNestedPass<FuncOp>(CreateConvertPureCallOpsPass());
pm.addNestedPass<FuncOp>(emitters::CreateConvertPureCallOpsPass());
pm.addPass(emitters::CreateLowerTensorsPass(device));
pm.addPass(mlir::createConvertComplexToStandardPass());
pm.addPass(CreateMergePointersToSameSlicePass());
pm.addPass(emitters::CreateMergePointersToSameSlicePass());

// LowerTensors creates new affine.apply ops. Fold and CSE them so
// simplify-affine has maximally folded expressions to work with.
pm.addPass(mlir::createCanonicalizerPass());
pm.addPass(mlir::createCSEPass());
pm.addNestedPass<FuncOp>(CreateSimplifyArithPass());
pm.addPass(CreateSimplifyAffinePass());
pm.addNestedPass<FuncOp>(emitters::CreateSimplifyArithPass());
pm.addPass(emitters::CreateSimplifyAffinePass());
pm.addPass(CreateConvertIndexTypePass());
// simplify-affine lowers most affine.apply ops, but if it can't prove a
// division or modulo is unsigned, affine.apply ops will remain.
Expand Down
7 changes: 7 additions & 0 deletions xla/backends/gpu/codegen/emitters/ir/BUILD
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library")
load("//tensorflow:tensorflow.google.bzl", "get_compatible_with_portable")
load("//xla/tsl/platform:rules_cc.bzl", "cc_library")

package(
# copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
Expand All @@ -16,6 +18,7 @@ package_group(
td_library(
name = "xla_gpu_td_files",
srcs = glob(["*.td"]),
compatible_with = get_compatible_with_portable(),
includes = ["."],
deps = [
"//xla/codegen/emitters/ir:xla_td_files",
Expand All @@ -30,6 +33,7 @@ td_library(

gentbl_cc_library(
name = "xla_gpu_dialect_inc_gen",
compatible_with = get_compatible_with_portable(),
strip_include_prefix = ".",
tbl_outs = [
(
Expand All @@ -48,6 +52,7 @@ gentbl_cc_library(

gentbl_cc_library(
name = "xla_gpu_ops_inc_gen",
compatible_with = get_compatible_with_portable(),
strip_include_prefix = ".",
tbl_outs = [
(
Expand All @@ -66,6 +71,7 @@ gentbl_cc_library(

gentbl_cc_library(
name = "xla_gpu_attrs_inc_gen",
compatible_with = get_compatible_with_portable(),
strip_include_prefix = ".",
tbl_outs = [
(
Expand Down Expand Up @@ -98,6 +104,7 @@ gentbl_cc_library(

gentbl_cc_library(
name = "xla_gpu_types_inc_gen",
compatible_with = get_compatible_with_portable(),
strip_include_prefix = ".",
tbl_outs = [
(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// RUN: test_correctness %s --bijection_inputs=reduce:0 --bijection_outputs=reduce
// RUN: fusion_to_mlir %s | emitters_opt -cse -xla-gpu-simplify-arith -canonicalize | FileCheck %s --dump-input=always
// RUN: fusion_to_mlir %s | emitters_opt -cse -xla-simplify-arith -canonicalize | FileCheck %s --dump-input=always

add {
%p0 = f32[] parameter(0)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@

// RUN: fusion_to_mlir %s | emitters_opt -cse -xla-gpu-simplify-arith -canonicalize | FileCheck %s --dump-input=always
// RUN: fusion_to_mlir %s | emitters_opt -cse -xla-simplify-arith -canonicalize | FileCheck %s --dump-input=always

add {
%p0 = f32[] parameter(0)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: fusion_to_mlir %s | emitters_opt -cse -xla-gpu-simplify-arith -canonicalize | FileCheck %s --dump-input=always
// RUN: fusion_to_mlir %s | emitters_opt -cse -xla-simplify-arith -canonicalize | FileCheck %s --dump-input=always
// RUN: test_correctness %s --bijection_inputs=reduce:0 --bijection_outputs=reduce

add {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: fusion_to_mlir %s | emitters_opt -cse -xla-gpu-simplify-arith -canonicalize | FileCheck %s --dump-input=always
// RUN: fusion_to_mlir %s | emitters_opt -cse -xla-simplify-arith -canonicalize | FileCheck %s --dump-input=always

add {
%p0 = f32[] parameter(0)
Expand Down
11 changes: 0 additions & 11 deletions xla/backends/gpu/codegen/emitters/transforms/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -40,17 +40,8 @@ cc_library(
srcs = [
"convert_float_nvidia.cc",
"convert_index_type.cc",
"convert_xla_gpu_pure_call_ops.cc",
"erase_dead_functions.cc",
"fuse_loops.cc",
"lower_xla_gpu_to_scf.cc",
"merge_pointers_to_same_slice.cc",
"optimize_loops.cc",
"peel_loops.cc",
"propagate_slice_indices.cc",
"simplify_affine.cc",
"simplify_arith.cc",
"unswitch_loops.cc",
"vectorize_loads_stores.cc",
],
hdrs = ["passes.h"],
Expand All @@ -61,7 +52,6 @@ cc_library(
"//xla:util",
"//xla:xla_data_proto_cc",
"//xla/backends/gpu/codegen/emitters/ir:xla_gpu",
"//xla/codegen/emitters:elemental_hlo_to_mlir",
"//xla/codegen/emitters/ir:xla",
"//xla/codegen/emitters/transforms:atomic_rmw_utils",
"//xla/hlo/analysis:indexing_analysis",
Expand All @@ -74,7 +64,6 @@ cc_library(
"//xla/stream_executor:semantic_version",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/log",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/strings",
Expand Down
10 changes: 0 additions & 10 deletions xla/backends/gpu/codegen/emitters/transforms/passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,18 +37,8 @@ std::unique_ptr<mlir::Pass> CreateConvertFloatNvidiaPass();
std::optional<std::unique_ptr<mlir::Pass>> MaybeCreateConvertFloatNvidiaPass(
const se::DeviceDescription& device_description);
std::unique_ptr<mlir::Pass> CreateConvertIndexTypePass();
std::unique_ptr<mlir::Pass> CreateConvertPureCallOpsPass();
std::unique_ptr<mlir::Pass> CreateEraseDeadFunctionsPass();
std::unique_ptr<mlir::Pass> CreateLowerXlaGpuToScfPass(int64_t warp_size = 32);
std::unique_ptr<mlir::Pass> CreateLowerXlaGpuLoopsToScfPass();
std::unique_ptr<mlir::Pass> CreateMergePointersToSameSlicePass();
std::unique_ptr<mlir::Pass> CreateOptimizeLoopsPass();
std::unique_ptr<mlir::Pass> CreateFuseLoopsPass();
std::unique_ptr<mlir::Pass> CreatePeelLoopsPass();
std::unique_ptr<mlir::Pass> CreatePropagateSliceIndicesPass();
std::unique_ptr<mlir::Pass> CreateSimplifyAffinePass();
std::unique_ptr<mlir::Pass> CreateSimplifyArithPass();
std::unique_ptr<mlir::Pass> CreateUnswitchLoopsPass();
std::unique_ptr<mlir::Pass> CreateVectorizeLoadsAndStoresPass(
const std::string& gpu_device_info = "");
std::unique_ptr<mlir::Pass> CreateVectorizeLoadsAndStoresPass(
Expand Down
Loading

0 comments on commit 170e331

Please sign in to comment.