[xla:cpu] kernel_api_ir_builder: expose helpers to get KernelParams #22194

Status: Open · wants to merge 1 commit into base: main
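What this exposes: `GetKernelArgumentsParameters` and `GetKernelResultsParameters` were file-local helpers in kernel_api_ir_builder.cc; this change makes them public static members of `KernelApiIrBuilder`, each returning one `KernelParameter` (leaf shape plus buffer slice) per leaf shape of the operands and of the result, respectively. A minimal caller sketch follows; the function name, includes, and error handling are hypothetical, and only the two `KernelApiIrBuilder` methods come from this diff.

// Hypothetical usage sketch, not part of this PR. Assumes an instruction and
// a completed BufferAssignment, mirroring the unit test added below.
#include <vector>

#include "absl/status/status.h"
#include "xla/backends/cpu/codegen/kernel_api_ir_builder.h"
#include "xla/tsl/platform/statusor.h"  // TF_ASSIGN_OR_RETURN

absl::Status CollectKernelParams(const xla::HloInstruction* instr,
                                 const xla::BufferAssignment* assignment) {
  using Builder = xla::cpu::KernelApiIrBuilder;
  // One KernelParameter {shape, slice} per leaf shape of every operand.
  TF_ASSIGN_OR_RETURN(
      std::vector<Builder::KernelParameter> arguments,
      Builder::GetKernelArgumentsParameters(instr, assignment));
  // One KernelParameter per leaf shape of the instruction's own result.
  TF_ASSIGN_OR_RETURN(
      std::vector<Builder::KernelParameter> results,
      Builder::GetKernelResultsParameters(instr, assignment));
  // A kernel prototype would be emitted from `arguments` and `results` here.
  return absl::OkStatus();
}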
xla/backends/cpu/codegen/emitters/ir/BUILD (6 additions, 0 deletions)

@@ -1,4 +1,6 @@
 load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library")
+load("//tensorflow:tensorflow.google.bzl", "get_compatible_with_portable")
+load("//xla/tsl/platform:rules_cc.bzl", "cc_library")

 package(
     # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
@@ -16,6 +18,7 @@ package_group(
 td_library(
     name = "xla_cpu_td_files",
     srcs = glob(["*.td"]),
+    compatible_with = get_compatible_with_portable(),
     includes = ["."],
     deps = [
         "@llvm-project//mlir:BuiltinDialectTdFiles",
@@ -25,6 +28,7 @@ td_library(

 gentbl_cc_library(
     name = "xla_cpu_dialect_inc_gen",
+    compatible_with = get_compatible_with_portable(),
     strip_include_prefix = ".",
     tbl_outs = [
         (
@@ -43,6 +47,7 @@ gentbl_cc_library(

 gentbl_cc_library(
     name = "xla_cpu_types_inc_gen",
+    compatible_with = get_compatible_with_portable(),
     strip_include_prefix = ".",
     tbl_outs = [
         (
@@ -67,6 +72,7 @@ gentbl_cc_library(

 gentbl_cc_library(
     name = "xla_cpu_ops_inc_gen",
+    compatible_with = get_compatible_with_portable(),
     strip_include_prefix = ".",
     tbl_outs = [
         (
xla/backends/cpu/codegen/emitters/transforms/BUILD (3 additions, 0 deletions)

@@ -1,4 +1,6 @@
 load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library")
+load("//tensorflow:tensorflow.google.bzl", "get_compatible_with_portable")
+load("//xla/tsl/platform:rules_cc.bzl", "cc_library")

 package(
     # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
@@ -15,6 +17,7 @@ package_group(

 gentbl_cc_library(
     name = "passes_inc_gen",
+    compatible_with = get_compatible_with_portable(),
     tbl_outs = [
         (
             [
xla/backends/cpu/codegen/kernel_api_ir_builder.cc (12 additions, 12 deletions)

@@ -243,39 +243,39 @@ absl::StatusOr<BufferAllocation::Slice> GetUniqueSlice(
   return buffer_assignment->GetUniqueSlice(instruction, index);
 }

+} // namespace
+
 absl::StatusOr<std::vector<KernelApiIrBuilder::KernelParameter>>
-GetKernelArgumentsParameters(const HloInstruction* instruction,
-                             const BufferAssignment* buffer_assignment) {
-  std::vector<KernelApiIrBuilder::KernelParameter> arguments;
+KernelApiIrBuilder::GetKernelArgumentsParameters(
+    const HloInstruction* instruction,
+    const BufferAssignment* buffer_assignment) {
+  std::vector<KernelParameter> arguments;

   for (HloInstruction* operand : instruction->operands()) {
     for (auto& indexed : ShapeUtil::GetLeafShapes(operand->shape())) {
       TF_ASSIGN_OR_RETURN(
           BufferAllocation::Slice slice,
           GetUniqueSlice(buffer_assignment, operand, indexed.index));
-      arguments.push_back(
-          KernelApiIrBuilder::KernelParameter{indexed.shape, slice});
+      arguments.push_back(KernelParameter{indexed.shape, slice});
     }
   }
   return arguments;
 }

 absl::StatusOr<std::vector<KernelApiIrBuilder::KernelParameter>>
-GetKernelResultsParameters(const HloInstruction* instruction,
-                           const BufferAssignment* buffer_assignment) {
-  std::vector<KernelApiIrBuilder::KernelParameter> results;
+KernelApiIrBuilder::GetKernelResultsParameters(
+    const HloInstruction* instruction,
+    const BufferAssignment* buffer_assignment) {
+  std::vector<KernelParameter> results;
   for (auto& indexed : ShapeUtil::GetLeafShapes(instruction->shape())) {
     TF_ASSIGN_OR_RETURN(
         BufferAllocation::Slice slice,
         GetUniqueSlice(buffer_assignment, instruction, indexed.index));
-    results.push_back(
-        KernelApiIrBuilder::KernelParameter{indexed.shape, slice});
+    results.push_back(KernelParameter{indexed.shape, slice});
   }
   return results;
 }

-} // namespace
-
 auto KernelApiIrBuilder::Options::FromHloModuleConfig(
     const HloModuleConfig& config) -> Options {
   return KernelApiIrBuilder::Options{
xla/backends/cpu/codegen/kernel_api_ir_builder.h (8 additions, 0 deletions)

@@ -113,6 +113,14 @@ class KernelApiIrBuilder {
   static std::unique_ptr<llvm::Module> CreateModule(absl::string_view name,
                                                     llvm::LLVMContext& context);

+  static absl::StatusOr<std::vector<KernelParameter>>
+  GetKernelArgumentsParameters(const HloInstruction* instruction,
+                               const BufferAssignment* buffer_assignment);
+
+  static absl::StatusOr<std::vector<KernelParameter>>
+  GetKernelResultsParameters(const HloInstruction* instruction,
+                             const BufferAssignment* buffer_assignment);
+
  private:
   ThreadDims EmitKernelThreadDims(llvm::IRBuilderBase& builder,
                                   llvm::Value* call_frame);
xla/backends/cpu/codegen/kernel_api_ir_builder_test.cc (28 additions, 0 deletions)

@@ -19,6 +19,7 @@ limitations under the License.
 #include <string>
 #include <vector>

+#include <gmock/gmock.h>
 #include <gtest/gtest.h>
 #include "absl/status/statusor.h"
 #include "absl/strings/string_view.h"
@@ -294,5 +295,32 @@ TEST_F(KernelApiIrBuilderTest, MixedBuffers) {
   EXPECT_TRUE(prototype.invariant_arguments.contains(0));
 }

+TEST_F(KernelApiIrBuilderTest, GetKernelParams) {
+  llvm::LLVMContext context;
+  auto module = std::make_unique<llvm::Module>("test", context);
+  constexpr absl::string_view hlo_text = R"(
+    HloModule m
+    ENTRY main {
+      p0 = f32[2,2] parameter(0)
+      p1 = f32[2,2] parameter(1)
+      ROOT add.0 = f32[2,2] add(p0, p1)
+    })";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto hlo, ParseAndReturnUnverifiedModule(hlo_text));
+  TF_ASSERT_OK_AND_ASSIGN(auto buffer_assignment, RunBufferAssignment(*hlo));
+  const auto* root = hlo->entry_computation()->root_instruction();
+  TF_ASSERT_OK_AND_ASSIGN(auto args,
+                          KernelApiIrBuilder::GetKernelArgumentsParameters(
+                              root, buffer_assignment.get()));
+  EXPECT_EQ(args.size(), 2);
+  EXPECT_THAT(args[0].shape.dimensions(), ::testing::ElementsAre(2, 2));
+  EXPECT_THAT(args[1].shape.dimensions(), ::testing::ElementsAre(2, 2));
+  TF_ASSERT_OK_AND_ASSIGN(auto results,
+                          KernelApiIrBuilder::GetKernelResultsParameters(
+                              root, buffer_assignment.get()));
+  EXPECT_EQ(results.size(), 1);
+  EXPECT_THAT(results[0].shape.dimensions(), ::testing::ElementsAre(2, 2));
+}
+
 } // namespace
 } // namespace xla::cpu
xla/backends/gpu/codegen/emitters/emitter_base.cc (11 additions, 11 deletions)

@@ -574,33 +574,33 @@ absl::Status EmitterBase::RunPassPipeline(
 }

 void AddXlaGpuOpsOptimizationPasses(mlir::OpPassManager& pm) {
-  pm.addNestedPass<FuncOp>(CreateSimplifyArithPass());
+  pm.addNestedPass<FuncOp>(emitters::CreateSimplifyArithPass());
   pm.addPass(mlir::createCanonicalizerPass());
   pm.addPass(mlir::createCSEPass());
-  pm.addPass(CreateEraseDeadFunctionsPass());
+  pm.addPass(emitters::CreateEraseDeadFunctionsPass());
   pm.addPass(mlir::createCSEPass());
 }

 void AddLoopTransformationPasses(mlir::OpPassManager& pm,
                                  const se::DeviceDescription& device) {
   pm.addNestedPass<FuncOp>(
-      CreateLowerXlaGpuToScfPass(device.threads_per_warp()));
+      emitters::CreateLowerXlaToScfPass(device.threads_per_warp()));
   pm.addNestedPass<FuncOp>(CreateFuseLoopsPass());
   pm.addPass(mlir::createInlinerPass({}, [&](mlir::OpPassManager& pm) {
     // CSE after inlining because inlining can introduce duplicates.
     pm.addPass(mlir::createCSEPass());
   }));
   pm.addPass(mlir::createCanonicalizerPass());
   pm.addPass(mlir::createCSEPass());
-  pm.addNestedPass<FuncOp>(CreatePeelLoopsPass());
-  pm.addNestedPass<FuncOp>(CreateLowerXlaGpuLoopsToScfPass());
+  pm.addNestedPass<FuncOp>(emitters::CreatePeelLoopsPass());
+  pm.addNestedPass<FuncOp>(emitters::CreateLowerXlaLoopsToScfPass());
   pm.addPass(mlir::mhlo::createConvertToSignlessPass());
-  pm.addPass(CreatePropagateSliceIndicesPass());
+  pm.addPass(emitters::CreatePropagateSliceIndicesPass());
   pm.addPass(emitters::CreateFlattenTensorsPass());
   // We need LICM before unswitching loops, because our loop unswitcher only
   // detects for loops with a single if inside them.
   pm.addPass(mlir::createLoopInvariantCodeMotionPass());
-  pm.addNestedPass<FuncOp>(CreateUnswitchLoopsPass());
+  pm.addNestedPass<FuncOp>(emitters::CreateUnswitchLoopsPass());
   // We need LICM again after unswitching, because that can introduce new
   // opportunities for LICM. This would not be necessary if LICM also moved
   // instructions over ifs.
@@ -613,17 +613,17 @@

 void AddLoweringPasses(mlir::OpPassManager& pm,
                        const se::DeviceDescription& device) {
-  pm.addNestedPass<FuncOp>(CreateConvertPureCallOpsPass());
+  pm.addNestedPass<FuncOp>(emitters::CreateConvertPureCallOpsPass());
   pm.addPass(emitters::CreateLowerTensorsPass(device));
   pm.addPass(mlir::createConvertComplexToStandardPass());
-  pm.addPass(CreateMergePointersToSameSlicePass());
+  pm.addPass(emitters::CreateMergePointersToSameSlicePass());

   // LowerTensors creates new affine.apply ops. Fold and CSE them so
   // simplify-affine has maximally folded expressions to work with.
   pm.addPass(mlir::createCanonicalizerPass());
   pm.addPass(mlir::createCSEPass());
-  pm.addNestedPass<FuncOp>(CreateSimplifyArithPass());
-  pm.addPass(CreateSimplifyAffinePass());
+  pm.addNestedPass<FuncOp>(emitters::CreateSimplifyArithPass());
+  pm.addPass(emitters::CreateSimplifyAffinePass());
   pm.addPass(CreateConvertIndexTypePass());
   // simplify-affine lowers most affine.apply ops, but if it can't prove a
   // division or modulo is unsigned, affine.apply ops will remain.
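The two comments in AddLoopTransformationPasses carry the ordering rationale: LICM must run before loop unswitching because the unswitcher only detects for-loops with a single `if` inside, and again afterwards because unswitching exposes new hoisting opportunities. A condensed sketch of just that bracket, reusing the names from this diff; the wrapper function is hypothetical, and `FuncOp` is the same alias emitter_base.cc already uses.

// Condensed sketch of the LICM / unswitch / LICM bracket shown above.
// Hypothetical wrapper; the passes and the FuncOp alias are from this diff.
void AddUnswitchBracket(mlir::OpPassManager& pm) {
  // Hoist invariant conditions so loops contain a single recognizable `if`.
  pm.addPass(mlir::createLoopInvariantCodeMotionPass());
  pm.addNestedPass<FuncOp>(emitters::CreateUnswitchLoopsPass());
  // Unswitching creates fresh invariant code; hoist it too.
  pm.addPass(mlir::createLoopInvariantCodeMotionPass());
}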
xla/backends/gpu/codegen/emitters/ir/BUILD (7 additions, 0 deletions)

@@ -1,4 +1,6 @@
 load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library")
+load("//tensorflow:tensorflow.google.bzl", "get_compatible_with_portable")
+load("//xla/tsl/platform:rules_cc.bzl", "cc_library")

 package(
     # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
@@ -16,6 +18,7 @@ package_group(
 td_library(
     name = "xla_gpu_td_files",
     srcs = glob(["*.td"]),
+    compatible_with = get_compatible_with_portable(),
     includes = ["."],
     deps = [
         "//xla/codegen/emitters/ir:xla_td_files",
@@ -30,6 +33,7 @@ td_library(

 gentbl_cc_library(
     name = "xla_gpu_dialect_inc_gen",
+    compatible_with = get_compatible_with_portable(),
     strip_include_prefix = ".",
     tbl_outs = [
         (
@@ -48,6 +52,7 @@ gentbl_cc_library(

 gentbl_cc_library(
     name = "xla_gpu_ops_inc_gen",
+    compatible_with = get_compatible_with_portable(),
     strip_include_prefix = ".",
     tbl_outs = [
         (
@@ -66,6 +71,7 @@ gentbl_cc_library(

 gentbl_cc_library(
     name = "xla_gpu_attrs_inc_gen",
+    compatible_with = get_compatible_with_portable(),
     strip_include_prefix = ".",
     tbl_outs = [
         (
@@ -98,6 +104,7 @@ gentbl_cc_library(

 gentbl_cc_library(
     name = "xla_gpu_types_inc_gen",
+    compatible_with = get_compatible_with_portable(),
     strip_include_prefix = ".",
     tbl_outs = [
         (
(file path not shown in this view)

@@ -1,5 +1,5 @@
 // RUN: test_correctness %s --bijection_inputs=reduce:0 --bijection_outputs=reduce
-// RUN: fusion_to_mlir %s | emitters_opt -cse -xla-gpu-simplify-arith -canonicalize | FileCheck %s --dump-input=always
+// RUN: fusion_to_mlir %s | emitters_opt -cse -xla-simplify-arith -canonicalize | FileCheck %s --dump-input=always

 add {
   %p0 = f32[] parameter(0)
(file path not shown in this view)

@@ -1,5 +1,5 @@

-// RUN: fusion_to_mlir %s | emitters_opt -cse -xla-gpu-simplify-arith -canonicalize | FileCheck %s --dump-input=always
+// RUN: fusion_to_mlir %s | emitters_opt -cse -xla-simplify-arith -canonicalize | FileCheck %s --dump-input=always

 add {
   %p0 = f32[] parameter(0)
(file path not shown in this view)

@@ -1,4 +1,4 @@
-// RUN: fusion_to_mlir %s | emitters_opt -cse -xla-gpu-simplify-arith -canonicalize | FileCheck %s --dump-input=always
+// RUN: fusion_to_mlir %s | emitters_opt -cse -xla-simplify-arith -canonicalize | FileCheck %s --dump-input=always
 // RUN: test_correctness %s --bijection_inputs=reduce:0 --bijection_outputs=reduce

 add {
(file path not shown in this view)

@@ -1,4 +1,4 @@
-// RUN: fusion_to_mlir %s | emitters_opt -cse -xla-gpu-simplify-arith -canonicalize | FileCheck %s --dump-input=always
+// RUN: fusion_to_mlir %s | emitters_opt -cse -xla-simplify-arith -canonicalize | FileCheck %s --dump-input=always

 add {
   %p0 = f32[] parameter(0)
xla/backends/gpu/codegen/emitters/transforms/BUILD (0 additions, 11 deletions)

@@ -40,17 +40,8 @@ cc_library(
     srcs = [
         "convert_float_nvidia.cc",
         "convert_index_type.cc",
-        "convert_xla_gpu_pure_call_ops.cc",
-        "erase_dead_functions.cc",
         "fuse_loops.cc",
-        "lower_xla_gpu_to_scf.cc",
-        "merge_pointers_to_same_slice.cc",
         "optimize_loops.cc",
-        "peel_loops.cc",
-        "propagate_slice_indices.cc",
-        "simplify_affine.cc",
-        "simplify_arith.cc",
-        "unswitch_loops.cc",
         "vectorize_loads_stores.cc",
     ],
     hdrs = ["passes.h"],
@@ -61,7 +52,6 @@ cc_library(
         "//xla:util",
         "//xla:xla_data_proto_cc",
         "//xla/backends/gpu/codegen/emitters/ir:xla_gpu",
-        "//xla/codegen/emitters:elemental_hlo_to_mlir",
         "//xla/codegen/emitters/ir:xla",
        "//xla/codegen/emitters/transforms:atomic_rmw_utils",
        "//xla/hlo/analysis:indexing_analysis",
@@ -74,7+64,6 @@ cc_library(
         "//xla/stream_executor:semantic_version",
         "@com_google_absl//absl/algorithm:container",
         "@com_google_absl//absl/base:core_headers",
-        "@com_google_absl//absl/container:flat_hash_map",
         "@com_google_absl//absl/log",
         "@com_google_absl//absl/log:check",
         "@com_google_absl//absl/strings",
xla/backends/gpu/codegen/emitters/transforms/passes.h (0 additions, 10 deletions)

@@ -37,18 +37,8 @@ std::unique_ptr<mlir::Pass> CreateConvertFloatNvidiaPass();
 std::optional<std::unique_ptr<mlir::Pass>> MaybeCreateConvertFloatNvidiaPass(
     const se::DeviceDescription& device_description);
 std::unique_ptr<mlir::Pass> CreateConvertIndexTypePass();
-std::unique_ptr<mlir::Pass> CreateConvertPureCallOpsPass();
-std::unique_ptr<mlir::Pass> CreateEraseDeadFunctionsPass();
-std::unique_ptr<mlir::Pass> CreateLowerXlaGpuToScfPass(int64_t warp_size = 32);
-std::unique_ptr<mlir::Pass> CreateLowerXlaGpuLoopsToScfPass();
-std::unique_ptr<mlir::Pass> CreateMergePointersToSameSlicePass();
 std::unique_ptr<mlir::Pass> CreateOptimizeLoopsPass();
 std::unique_ptr<mlir::Pass> CreateFuseLoopsPass();
-std::unique_ptr<mlir::Pass> CreatePeelLoopsPass();
-std::unique_ptr<mlir::Pass> CreatePropagateSliceIndicesPass();
-std::unique_ptr<mlir::Pass> CreateSimplifyAffinePass();
-std::unique_ptr<mlir::Pass> CreateSimplifyArithPass();
-std::unique_ptr<mlir::Pass> CreateUnswitchLoopsPass();
 std::unique_ptr<mlir::Pass> CreateVectorizeLoadsAndStoresPass(
     const std::string& gpu_device_info = "");
 std::unique_ptr<mlir::Pass> CreateVectorizeLoadsAndStoresPass(