Skip to content

Commit

Permalink
Introduce OpaqueExecutable.
Browse files Browse the repository at this point in the history
At the moment the HloRunnerInterface is tightly coupled with the Executable
class. The PjRt runner, however, actually consumes PjRtExecutable instances, so
the idea with OpaqueExecutable is to return a wrapper class that hides the
runner-specific executable type and its implementation details.

The handle class records the owning HloRunnerInterface implementation so that we
can verify that an executable created by one runner is not used with a different
runner. This ensures users cannot pass in externally created executables, which
may not be of the type that a given runner implementation needs.

PiperOrigin-RevId: 721127545
  • Loading branch information
nvgrw authored and Google-ML-Automation committed Jan 31, 2025
1 parent 811a86b commit bca9fe0
Show file tree
Hide file tree
Showing 27 changed files with 452 additions and 189 deletions.
1 change: 1 addition & 0 deletions xla/backends/gpu/codegen/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,7 @@ xla_test(
"//xla/service:executable",
"//xla/service:hlo_module_config",
"//xla/service:hlo_proto_cc",
"//xla/service:hlo_runner_interface",
"//xla/service/gpu:backend_configs_cc",
"//xla/service/gpu:gpu_executable",
"//xla/service/gpu/transforms:dynamic_slice_fusion_rewriter",
Expand Down
8 changes: 6 additions & 2 deletions xla/backends/gpu/codegen/dynamic_slice_fusion_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ limitations under the License.
#include "xla/service/gpu/transforms/dynamic_slice_fusion_rewriter.h"
#include "xla/service/hlo.pb.h"
#include "xla/service/hlo_module_config.h"
#include "xla/service/hlo_runner_interface.h"
#include "xla/shape.h"
#include "xla/shape_util.h"
#include "xla/stream_executor/device_memory.h"
Expand Down Expand Up @@ -3244,10 +3245,13 @@ TEST_F(DynamicSliceFusionTest,

// Check that the offset value in the thunk is an evaluated constant even if
// no simplification passes are executed.
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Executable> exec,
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<OpaqueExecutable> wrapped_executable,
CreateExecutable(/*module=*/module_opt->Clone(),
/*run_hlo_passes=*/false));
GpuExecutable* gpu_exec = dynamic_cast<GpuExecutable*>(exec.get());
TF_ASSERT_OK_AND_ASSIGN(Executable* const exec,
test_runner_as_hlo_runner().ExecutableFromWrapped(
wrapped_executable.get()));
GpuExecutable* gpu_exec = dynamic_cast<GpuExecutable*>(exec);
ASSERT_NE(gpu_exec, nullptr);
const SequentialThunk& thunk = gpu_exec->GetThunk();
auto dynamic_slice_thunk =
Expand Down
6 changes: 3 additions & 3 deletions xla/hlo/testlib/hlo_hardware_independent_test_base.cc
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,7 @@ HloComputation* HloHardwareIndependentTestBase::FindComputation(

/* static */
HloInstruction* HloHardwareIndependentTestBase::FindInstruction(
HloModule* module, absl::string_view name) {
const HloModule* module, absl::string_view name) {
for (const HloComputation* computation : module->computations()) {
if (HloInstruction* instruction =
hlo_query::FindInstruction(computation, name)) {
Expand All @@ -334,7 +334,7 @@ HloInstruction* HloHardwareIndependentTestBase::FindInstruction(

/* static */
HloInstruction* HloHardwareIndependentTestBase::FindInstruction(
HloModule* module, HloOpcode opcode) {
const HloModule* module, HloOpcode opcode) {
for (const HloComputation* computation : module->computations()) {
if (HloInstruction* instruction =
hlo_query::FindInstruction(computation, opcode)) {
Expand All @@ -346,7 +346,7 @@ HloInstruction* HloHardwareIndependentTestBase::FindInstruction(

/* static */
std::vector<HloInstruction*> HloHardwareIndependentTestBase::FindInstructions(
HloModule* module, HloOpcode opcode) {
const HloModule* module, HloOpcode opcode) {
std::vector<HloInstruction*> instructions;
for (const HloComputation* c : module->computations()) {
absl::c_copy_if(c->instructions(), std::back_inserter(instructions),
Expand Down
7 changes: 4 additions & 3 deletions xla/hlo/testlib/hlo_hardware_independent_test_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,12 +64,13 @@ class HloHardwareIndependentTestBase : public ::testing::Test {
// inspect a particular computation or instruction.
static HloComputation* FindComputation(HloModule* module,
absl::string_view name);
static HloInstruction* FindInstruction(HloModule* module,
static HloInstruction* FindInstruction(const HloModule* module,
absl::string_view name);
// Gets the instruction from the given module with the given opcode.
static HloInstruction* FindInstruction(HloModule* module, HloOpcode opcode);
static HloInstruction* FindInstruction(const HloModule* module,
HloOpcode opcode);
// Gets all the instructions from the given module with the given opcode.
static std::vector<HloInstruction*> FindInstructions(HloModule* module,
static std::vector<HloInstruction*> FindInstructions(const HloModule* module,
HloOpcode opcode);

protected:
Expand Down
7 changes: 6 additions & 1 deletion xla/service/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -4634,7 +4634,6 @@ cc_library(
hdrs = ["hlo_runner_interface.h"],
deps = [
":computation_placer",
":executable",
":hlo_module_config",
"//xla:literal",
"//xla:shape_util",
Expand All @@ -4645,7 +4644,10 @@ cc_library(
"//xla/tsl/platform:env",
"//xla/tsl/platform:errors",
"//xla/tsl/platform:statusor",
"@com_google_absl//absl/base:nullability",
"@com_google_absl//absl/log",
"@com_google_absl//absl/log:die_if_null",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:string_view",
"@com_google_absl//absl/types:span",
Expand Down Expand Up @@ -4677,6 +4679,9 @@ cc_library(
"//xla/stream_executor:device_memory_allocator",
"//xla/stream_executor:stream_executor_h",
"//xla/stream_executor:stream_executor_memory_allocator",
"//xla/tsl/platform:statusor",
"@com_google_absl//absl/base:nullability",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/types:span",
"@eigen_archive//:eigen3",
Expand Down
1 change: 1 addition & 0 deletions xla/service/gpu/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1974,6 +1974,7 @@ xla_cc_test(
"//xla/service:compiler",
"//xla/service:executable",
"//xla/service:gpu_plugin",
"//xla/service:hlo_runner_interface",
"//xla/service:platform_util",
"//xla/stream_executor:platform",
"//xla/stream_executor:platform_manager",
Expand Down
9 changes: 6 additions & 3 deletions xla/service/gpu/gpu_aot_compilation_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ limitations under the License.
#include "xla/literal_util.h"
#include "xla/service/compiler.h"
#include "xla/service/executable.h"
#include "xla/service/hlo_runner_interface.h"
#include "xla/service/platform_util.h"
#include "xla/stream_executor/platform.h"
#include "xla/stream_executor/platform_manager.h"
Expand Down Expand Up @@ -236,14 +237,16 @@ TEST_F(GpuAotCompilationTest, ExportAndLoadExecutableWithTriton) {
// Load Executable from AOT compilation result.
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Executable> executable,
aot_result->LoadExecutable(compiler, stream_exec));
std::unique_ptr<OpaqueExecutable> wrapped_executable =
test_runner_as_hlo_runner().WrapExecutable(std::move(executable));

const xla::Literal literal_1 = xla::LiteralUtil::CreateR0<float>(1.0f);
const xla::Literal literal_2 = xla::LiteralUtil::CreateR0<float>(2.0f);
const xla::Literal literal_3 = xla::LiteralUtil::CreateR0<float>(3.0f);

TF_ASSERT_OK_AND_ASSIGN(Literal result,
GetHloRunner().value()->ExecuteWithExecutable(
executable.get(), {&literal_1, &literal_3}));
TF_ASSERT_OK_AND_ASSIGN(
Literal result, test_runner_as_hlo_runner().ExecuteWithExecutable(
wrapped_executable.get(), {&literal_1, &literal_3}));

EXPECT_TRUE(LiteralTestUtil::Equal(
LiteralUtil::MakeTuple({&literal_2, &literal_3}), result));
Expand Down
6 changes: 4 additions & 2 deletions xla/service/gpu/gpu_compiler_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1037,15 +1037,17 @@ ENTRY e {
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<Executable> executable,
aot_result->LoadExecutable(compiler, aot_options.executor()));
std::unique_ptr<OpaqueExecutable> wrapped_executable =
test_runner_as_hlo_runner().WrapExecutable(std::move(executable));

const xla::Literal literal_input =
xla::LiteralUtil::CreateR0<int32_t>(input);
const xla::Literal literal_expected_result =
xla::LiteralUtil::CreateR0<int32_t>(expected_result);

TF_ASSERT_OK_AND_ASSIGN(Literal result,
GetHloRunner().value()->ExecuteWithExecutable(
executable.get(), {&literal_input}));
test_runner_as_hlo_runner().ExecuteWithExecutable(
wrapped_executable.get(), {&literal_input}));

EXPECT_TRUE(LiteralTestUtil::Equal(result, literal_expected_result));
};
Expand Down
1 change: 1 addition & 0 deletions xla/service/gpu/model/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -861,6 +861,7 @@ cc_library(
"//xla/service:gpu_plugin",
"//xla/service:hlo_module_config",
"//xla/service:hlo_runner",
"//xla/service:hlo_runner_interface",
"//xla/service:hlo_verifier",
"//xla/service:interpreter_plugin",
"//xla/stream_executor:device_description",
Expand Down
3 changes: 2 additions & 1 deletion xla/service/gpu/model/hlo_op_profiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ limitations under the License.
#include "xla/service/gpu/model/hlo_op_profile.pb.h"
#include "xla/service/hlo_module_config.h"
#include "xla/service/hlo_runner.h"
#include "xla/service/hlo_runner_interface.h"
#include "xla/service/hlo_verifier.h"
#include "xla/shape.h"
#include "xla/shape_util.h"
Expand Down Expand Up @@ -179,7 +180,7 @@ absl::StatusOr<absl::Duration> HloOpProfiler::MeasureOpChainDuration(
/*use_large_range=*/true)
.value();
const absl::Time t_compile_start = absl::Now();
TF_ASSIGN_OR_RETURN(std::unique_ptr<Executable> ex,
TF_ASSIGN_OR_RETURN(std::unique_ptr<OpaqueExecutable> ex,
runner_.CreateExecutable(std::move(module),
/*run_hlo_passes=*/false));
if (absl::Now() - t_compile_start > absl::Seconds(10)) {
Expand Down
2 changes: 2 additions & 0 deletions xla/service/gpu/transforms/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -676,6 +676,8 @@ xla_test(
"//xla/hlo/parser:hlo_parser",
"//xla/hlo/testlib:filecheck",
"//xla/hlo/testlib:verified_hlo_module",
"//xla/service:executable",
"//xla/service:hlo_runner_interface",
"//xla/service/gpu:gpu_device_info_for_tests",
"//xla/service/gpu:gpu_executable",
"//xla/stream_executor:device_description",
Expand Down
14 changes: 12 additions & 2 deletions xla/service/gpu/transforms/command_buffer_scheduling_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,10 @@ limitations under the License.
#include "xla/hlo/parser/hlo_parser.h"
#include "xla/hlo/testlib/filecheck.h"
#include "xla/hlo/testlib/verified_hlo_module.h"
#include "xla/service/executable.h"
#include "xla/service/gpu/gpu_device_info_for_tests.h"
#include "xla/service/gpu/gpu_executable.h"
#include "xla/service/hlo_runner_interface.h"
#include "xla/stream_executor/device_description.h"
#include "xla/tests/hlo_test_base.h"
#include "xla/tsl/lib/core/status_test_util.h"
Expand Down Expand Up @@ -1158,7 +1160,11 @@ TEST_F(CommandBufferSchedulingTest, DynamicSliceFusionDynamicSlicing) {
HloModuleConfig config(m_clone->config());
config.set_debug_options(options);
m_clone->set_config(config);
TF_ASSIGN_OR_RETURN(auto exec, CreateExecutable(std::move(m_clone), false));
TF_ASSIGN_OR_RETURN(std::unique_ptr<OpaqueExecutable> wrapped_exec,
CreateExecutable(std::move(m_clone), false));
TF_ASSIGN_OR_RETURN(std::unique_ptr<Executable> exec,
test_runner_as_hlo_runner().ExecutableFromWrapped(
std::move(wrapped_exec)));
auto gpu_exec = std::unique_ptr<GpuExecutable>(
static_cast<GpuExecutable*>(exec.release()));
TF_RET_CHECK(llvm::any_of(gpu_exec->GetThunk().thunks(),
Expand Down Expand Up @@ -1226,7 +1232,11 @@ TEST_F(CommandBufferSchedulingTest, DynamicSliceFusionStaticSlicing) {
HloModuleConfig config(m_clone->config());
config.set_debug_options(options);
m_clone->set_config(config);
TF_ASSIGN_OR_RETURN(auto exec, CreateExecutable(std::move(m_clone), false));
TF_ASSIGN_OR_RETURN(std::unique_ptr<OpaqueExecutable> wrapped_exec,
CreateExecutable(std::move(m_clone), false));
TF_ASSIGN_OR_RETURN(std::unique_ptr<Executable> exec,
test_runner_as_hlo_runner().ExecutableFromWrapped(
std::move(wrapped_exec)));
return std::unique_ptr<GpuExecutable>(
static_cast<GpuExecutable*>(exec.release()));
};
Expand Down
Loading

0 comments on commit bca9fe0

Please sign in to comment.