Skip to content

Commit

Permalink
Introduce HloRunnerExecutableHandle.
Browse files Browse the repository at this point in the history
At the moment the HloRunnerInterface is tightly coupled with the Executable
class, even though the PjRt runner actually consumes PjRtExecutable instances.
HloRunnerExecutableHandle removes this coupling: runners now return an opaque
wrapper class that hides the underlying executable type from callers.

The handle class records the owning HloRunnerInterface implementation so that we
can check that executables created with one runner cannot be used by another
runner. By doing this we can ensure users cannot pass in externally-created
executables which may not be of the type that we need for a given runner
implementation.

PiperOrigin-RevId: 721127545
  • Loading branch information
nvgrw authored and Google-ML-Automation committed Jan 30, 2025
1 parent 5631fbd commit 356965b
Show file tree
Hide file tree
Showing 10 changed files with 202 additions and 94 deletions.
5 changes: 4 additions & 1 deletion xla/service/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -4633,7 +4633,6 @@ cc_library(
hdrs = ["hlo_runner_interface.h"],
deps = [
":computation_placer",
":executable",
":hlo_module_config",
"//xla:literal",
"//xla:shape_util",
Expand All @@ -4644,7 +4643,10 @@ cc_library(
"//xla/tsl/platform:env",
"//xla/tsl/platform:errors",
"//xla/tsl/platform:statusor",
"@com_google_absl//absl/base:nullability",
"@com_google_absl//absl/log",
"@com_google_absl//absl/log:die_if_null",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:string_view",
"@com_google_absl//absl/types:span",
Expand Down Expand Up @@ -4676,6 +4678,7 @@ cc_library(
"//xla/stream_executor:device_memory_allocator",
"//xla/stream_executor:stream_executor_h",
"//xla/stream_executor:stream_executor_memory_allocator",
"@com_google_absl//absl/base:nullability",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/types:span",
"@eigen_archive//:eigen3",
Expand Down
90 changes: 67 additions & 23 deletions xla/service/hlo_runner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ limitations under the License.
#include <variant>
#include <vector>

#include "absl/base/nullability.h"
#include "absl/status/statusor.h"
#include "unsupported/Eigen/CXX11/Tensor"
#include "xla/hlo/ir/hlo_module_group.h"
#include "xla/service/executable.h"
Expand All @@ -36,6 +38,28 @@ limitations under the License.

namespace xla {

namespace {
class HloRunnerExecutable : public HloRunnerExecutableHandle {
public:
HloRunnerExecutable(absl::Nonnull<const HloRunner*> creator,
std::unique_ptr<Executable> executable)
: HloRunnerExecutableHandle(creator),
executable_(std::move(executable)) {}

Executable* executable() const { return executable_.get(); }

static absl::StatusOr<absl::Nonnull<HloRunnerExecutable*>> FromHandle(
const HloRunnerInterface& runner,
absl::Nonnull<HloRunnerExecutableHandle*> const handle) {
return HloRunnerExecutableHandle::FromHandle<HloRunnerExecutable>(runner,
handle);
}

private:
std::unique_ptr<Executable> executable_;
};
} // namespace

HloRunner::HloRunner(se::Platform* platform, int intra_op_parallelism_threads) {
BackendOptions backend_options;
backend_options.set_platform(platform);
Expand Down Expand Up @@ -186,15 +210,17 @@ absl::StatusOr<Literal> HloRunner::ExecuteWithBufferAssignment(
}

absl::StatusOr<Literal> HloRunner::ExecuteWithExecutable(
Executable* executable, absl::Span<const Literal* const> arguments,
ExecutionProfile* profile) {
entry_computation_layout_ =
&(executable->module().entry_computation_layout());
HloRunnerExecutableHandle* executable,
absl::Span<const Literal* const> arguments, ExecutionProfile* profile) {
TF_ASSIGN_OR_RETURN(HloRunnerExecutable* const hlo_runner_executable,
HloRunnerExecutable::FromHandle(*this, executable));
entry_computation_layout_ = &(
hlo_runner_executable->executable()->module().entry_computation_layout());
TF_ASSIGN_OR_RETURN(std::vector<ScopedShapedBuffer> argument_buffers,
TransferLiteralsToDevice(arguments));
TF_ASSIGN_OR_RETURN(ExecutionOutput result,
ExecuteWithDeviceBuffers(
/*executable=*/executable,
/*executable=*/hlo_runner_executable->executable(),
/*arguments=*/argument_buffers,
/*profile=*/profile));
return TransferLiteralFromDevice(result.Result());
Expand Down Expand Up @@ -286,9 +312,12 @@ absl::StatusOr<ExecutionOutput> HloRunner::ExecuteWithDeviceBuffers(
std::unique_ptr<HloModule> module,
absl::Span<ScopedShapedBuffer const> arguments, bool run_hlo_passes,
ExecutionProfile* profile) {
TF_ASSIGN_OR_RETURN(std::unique_ptr<Executable> executable,
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloRunnerExecutableHandle> executable,
CreateExecutable(std::move(module), run_hlo_passes));
return ExecuteWithDeviceBuffers(executable.get(), arguments, profile);
TF_ASSIGN_OR_RETURN(HloRunnerExecutable* const hlo_runner_executable,
HloRunnerExecutable::FromHandle(*this, executable.get()));
return ExecuteWithDeviceBuffers(hlo_runner_executable->executable(),
arguments, profile);
}

absl::StatusOr<ExecutionOutput> HloRunner::ExecuteWithDeviceBuffers(
Expand Down Expand Up @@ -319,11 +348,13 @@ HloRunner::ExecuteWithMovedDeviceBuffersAndBufferAssignment(
std::vector<ScopedShapedBuffer> arguments, bool run_hlo_passes,
ExecutionProfile* profile) {
TF_ASSIGN_OR_RETURN(
std::unique_ptr<Executable> executable,
std::unique_ptr<HloRunnerExecutableHandle> executable,
CreateExecutableWithBufferAssignment(
std::move(module), buffer_assignment_proto, run_hlo_passes));
return ExecuteWithMovedDeviceBuffers(executable.get(), std::move(arguments),
profile);
TF_ASSIGN_OR_RETURN(HloRunnerExecutable* const hlo_runner_executable,
HloRunnerExecutable::FromHandle(*this, executable.get()));
return ExecuteWithMovedDeviceBuffers(hlo_runner_executable->executable(),
std::move(arguments), profile);
}

absl::StatusOr<ExecutionOutput> HloRunner::ExecuteWithMovedDeviceBuffers(
Expand Down Expand Up @@ -384,7 +415,7 @@ absl::StatusOr<std::vector<Literal>> HloRunner::ExecuteReplicated(
std::unique_ptr<HloModule> module, const ReplicatedExecuteOptions& options,
DeviceAssignment* device_assignment) {
TF_ASSIGN_OR_RETURN(
std::unique_ptr<Executable> executable,
std::unique_ptr<HloRunnerExecutableHandle> executable,
CreateExecutable(std::move(module), options.run_hlo_passes));
return ExecuteReplicated(executable.get(), options, device_assignment);
}
Expand Down Expand Up @@ -529,10 +560,14 @@ absl::StatusOr<std::vector<Literal>> HloRunner::ExecuteReplicatedImpl(
}

absl::StatusOr<std::vector<Literal>> HloRunner::ExecuteReplicated(
Executable* executable, const ReplicatedExecuteOptions& options,
HloRunnerExecutableHandle* executable,
const ReplicatedExecuteOptions& options,
DeviceAssignment* device_assignment, ExecutionProfile* profile) {
TF_ASSIGN_OR_RETURN(HloRunnerExecutable * wrapped_executable,
HloRunnerExecutable::FromHandle(*this, executable));
return ExecuteReplicatedImpl(
[&](const std::vector<ServiceExecutableRunOptions>& service_run_options,
[&, executable = wrapped_executable->executable()](
const std::vector<ServiceExecutableRunOptions>& service_run_options,
const std::vector<absl::Span<const ShapedBuffer* const>>&
argument_buffer_slices)
-> absl::StatusOr<std::vector<ScopedShapedBuffer>> {
Expand Down Expand Up @@ -577,7 +612,7 @@ absl::StatusOr<std::vector<Literal>> HloRunner::ExecuteReplicated(
}

absl::StatusOr<std::vector<Literal>> HloRunner::ExecuteReplicated(
std::function<Executable*(int64_t)> executable_provider,
std::function<HloRunnerExecutableHandle*(int64_t)> executable_provider,
std::function<int64_t(int64_t)> argument_count_provider,
std::function<const Literal*(int64_t, int64_t)> argument_provider,
const ReplicatedExecuteOptions& options,
Expand Down Expand Up @@ -609,8 +644,11 @@ absl::StatusOr<std::vector<Literal>> HloRunner::ExecuteReplicated(
for (const auto& arg : argument_buffer_slices[i]) {
TF_RET_CHECK(arg != nullptr);
}
TF_ASSIGN_OR_RETURN(
HloRunnerExecutable* const executable,
HloRunnerExecutable::FromHandle(*this, executable_provider(i)));
pool.Schedule([&, i] {
auto result = executable_provider(i)->ExecuteOnStream(
auto result = executable->executable()->ExecuteOnStream(
&service_run_options[i], argument_buffer_slices[i]);
absl::MutexLock lock(&mutex);
thread_results[i] = std::move(result);
Expand Down Expand Up @@ -640,14 +678,15 @@ absl::StatusOr<std::vector<Literal>> HloRunner::ExecuteReplicated(
return ExecuteReplicated(std::move(module), options, &device_assignment);
}

absl::StatusOr<std::unique_ptr<Executable>> HloRunner::CreateExecutable(
std::unique_ptr<HloModule> module, bool run_hlo_passes) {
absl::StatusOr<std::unique_ptr<HloRunnerExecutableHandle>>
HloRunner::CreateExecutable(std::unique_ptr<HloModule> module,
bool run_hlo_passes) {
return CreateExecutableWithBufferAssignment(
std::move(module),
/*buffer_assignment_proto=*/nullptr, run_hlo_passes);
}

absl::StatusOr<std::unique_ptr<Executable>>
absl::StatusOr<std::unique_ptr<HloRunnerExecutableHandle>>
HloRunner::CreateExecutableWithBufferAssignment(
std::unique_ptr<HloModule> module,
const BufferAssignmentProto* buffer_assignment_proto, bool run_hlo_passes) {
Expand All @@ -665,15 +704,20 @@ HloRunner::CreateExecutableWithBufferAssignment(
}
auto module_group = std::make_unique<HloModuleGroup>(std::move(module));
TF_ASSIGN_OR_RETURN(
auto executables,
std::vector<std::unique_ptr<Executable>> executables,
backend().compiler()->Compile(std::move(module_group),
{{backend().default_stream_executor()}},
backend().memory_allocator()));
return std::move(executables[0]);
return std::make_unique<HloRunnerExecutable>(this,
std::move(executables[0]));
}
return backend().compiler()->RunBackendWithBufferAssignment(
std::move(module), buffer_assignment_proto,
backend().default_stream_executor(), backend().memory_allocator());

TF_ASSIGN_OR_RETURN(
std::unique_ptr<Executable> executable,
backend().compiler()->RunBackendWithBufferAssignment(
std::move(module), buffer_assignment_proto,
backend().default_stream_executor(), backend().memory_allocator()));
return std::make_unique<HloRunnerExecutable>(this, std::move(executable));
}

ServiceExecutableRunOptions HloRunner::GetServiceRunOptionsForDevice(
Expand Down
12 changes: 7 additions & 5 deletions xla/service/hlo_runner.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,8 @@ class HloRunner : public HloRunnerInterface {
using HloRunnerInterface::ExecuteWithExecutable;

absl::StatusOr<Literal> ExecuteWithExecutable(
Executable* executable, absl::Span<const Literal* const> arguments,
HloRunnerExecutableHandle* executable,
absl::Span<const Literal* const> arguments,
ExecutionProfile* profile) override;

// As Execute(), but accepts and returns device buffers instead of host
Expand Down Expand Up @@ -134,10 +135,10 @@ class HloRunner : public HloRunnerInterface {

// Creates an executable object given an HLO module. If run_hlo_passes is
// true, the HLO passes will be run as part of compilation.
absl::StatusOr<std::unique_ptr<Executable>> CreateExecutable(
absl::StatusOr<std::unique_ptr<HloRunnerExecutableHandle>> CreateExecutable(
std::unique_ptr<HloModule> module, bool run_hlo_passes) override;

absl::StatusOr<std::unique_ptr<Executable>>
absl::StatusOr<std::unique_ptr<HloRunnerExecutableHandle>>
CreateExecutableWithBufferAssignment(
std::unique_ptr<HloModule> module,
const BufferAssignmentProto* /*buffer_assignment_proto*/,
Expand All @@ -162,7 +163,8 @@ class HloRunner : public HloRunnerInterface {
// Note that this call ignores ReplicatedExecutionOptions::run_hlo_passes,
// since we've already compiled the Executable.
absl::StatusOr<std::vector<Literal>> ExecuteReplicated(
Executable* executable, const ReplicatedExecuteOptions& options,
HloRunnerExecutableHandle* executable,
const ReplicatedExecuteOptions& options,
DeviceAssignment* device_assignment, ExecutionProfile* profile = nullptr);

// Same as above, but with different reusable Executables. This may update the
Expand All @@ -171,7 +173,7 @@ class HloRunner : public HloRunnerInterface {
// Note that this call ignores ReplicatedExecutionOptions::run_hlo_passes,
// since we've already compiled the Executable.
absl::StatusOr<std::vector<Literal>> ExecuteReplicated(
std::function<Executable*(int64_t)> executable_provider,
std::function<HloRunnerExecutableHandle*(int64_t)> executable_provider,
std::function<int64_t(int64_t)> argument_count_provider,
std::function<const Literal*(int64_t, int64_t)> argument_provider,
const ReplicatedExecuteOptions& options,
Expand Down
3 changes: 1 addition & 2 deletions xla/service/hlo_runner_interface.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ limitations under the License.
#include "xla/hlo/ir/hlo_module.h"
#include "xla/hlo/parser/hlo_parser.h"
#include "xla/literal.h"
#include "xla/service/executable.h"
#include "xla/service/hlo_module_config.h"
#include "xla/tsl/platform/env.h"
#include "xla/tsl/platform/errors.h"
Expand Down Expand Up @@ -130,7 +129,7 @@ absl::StatusOr<Literal> HloRunnerInterface::ExecuteWithBufferAssignment(
}

absl::StatusOr<Literal> HloRunnerInterface::ExecuteWithExecutable(
Executable* executable, absl::Span<const Literal> arguments,
HloRunnerExecutableHandle* executable, absl::Span<const Literal> arguments,
ExecutionProfile* profile) {
// Construct a vector of plain pointers for the arguments.
auto argument_pointers = MakePointerVector<const Literal>(arguments);
Expand Down
67 changes: 56 additions & 11 deletions xla/service/hlo_runner_interface.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,10 @@ limitations under the License.
#include <utility>
#include <vector>

#include "absl/base/nullability.h"
#include "absl/log/die_if_null.h"
#include "absl/log/log.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
Expand All @@ -32,13 +35,13 @@ limitations under the License.
#include "xla/hlo/parser/hlo_parser.h"
#include "xla/literal.h"
#include "xla/service/computation_placer.h"
#include "xla/service/executable.h"
#include "xla/shape.h"
#include "xla/util.h"
#include "xla/xla_data.pb.h"

namespace xla {

class HloRunnerInterface;
class BufferAssignmentProto;

// Tags to identify particular properties of a HloRunnerInterface
Expand Down Expand Up @@ -82,6 +85,45 @@ class HloRunnerPropertyTag final {
HloRunnerPropertyTag() = default;
};

// Runner implementations only support the execution of executables that were
// created by the runner. We use this class to represent these executables
// when they leave the runner without exposing any details of the underlying
// implementation.
//
// Each handle records the runner that created it; FromHandle() uses that
// record to reject handles that were created by a different runner.
class HloRunnerExecutableHandle {
 public:
  virtual ~HloRunnerExecutableHandle() = default;

 protected:
  // `creator` is the runner that produced this executable. It must not be
  // null (enforced with ABSL_DIE_IF_NULL).
  explicit HloRunnerExecutableHandle(
      absl::Nonnull<const HloRunnerInterface*> creator)
      : creator_(ABSL_DIE_IF_NULL(creator)) {}
  // Cannot be moved or copied. Deleting the copy operations also suppresses
  // the implicitly-declared move operations, so a handle can never be
  // duplicated or sliced away from its concrete subclass.
  HloRunnerExecutableHandle(const HloRunnerExecutableHandle&) = delete;
  HloRunnerExecutableHandle& operator=(const HloRunnerExecutableHandle&) =
      delete;

  // Converts `handle` to the concrete handle type `T` used by `runner`.
  //
  // Returns InvalidArgumentError if `handle` is null or if it was not created
  // by `runner`; otherwise returns the downcast pointer. `T` must derive from
  // HloRunnerExecutableHandle.
  template <typename T>
  static absl::StatusOr<absl::Nonnull<T*>> FromHandle(
      const HloRunnerInterface& runner,
      absl::Nonnull<HloRunnerExecutableHandle*> const handle) {
    static_assert(std::is_base_of_v<HloRunnerExecutableHandle, T>,
                  "FromHandle must be instantiated with a subclass of "
                  "HloRunnerExecutableHandle.");
    // Reject null before touching `handle->creator_`; callers may obtain the
    // handle from arbitrary code, so fail with a status instead of
    // dereferencing a null pointer.
    if (handle == nullptr) {
      return absl::InvalidArgumentError("Invalid executable handle.");
    }
    if (handle->creator_ != &runner) {
      return absl::InvalidArgumentError(
          "Executable was not created by this runner.");
    }

    T* const executable = tensorflow::down_cast<T*>(handle);
    if (executable == nullptr) {
      return absl::InvalidArgumentError("Invalid executable handle.");
    }
    return executable;
  }

  // The runner that created this handle; compared by address in FromHandle().
  const HloRunnerInterface* const creator_;
};

// A base class for running an HloModule. This executes the given HloModule on a
// certain backend directly without using the client interface. HloModule can be
// explicitly built, or loaded from a serialization file (e.g., hlo proto
Expand Down Expand Up @@ -157,16 +199,17 @@ class HloRunnerInterface {
const std::string& filename, const DebugOptions& debug_options,
const HloParserOptions& options = HloParserOptions());

// Creates an executable object given an HLO module. If run_hlo_passes is
// true, the HLO passes will be run as part of compilation.
virtual absl::StatusOr<std::unique_ptr<Executable>> CreateExecutable(
std::unique_ptr<HloModule> module, bool run_hlo_passes) = 0;
// Creates a runner-internal executable object given an HLO module and returns
// its handle. If run_hlo_passes is true, the HLO passes will be run as part
// of compilation.
virtual absl::StatusOr<std::unique_ptr<HloRunnerExecutableHandle>>
CreateExecutable(std::unique_ptr<HloModule> module, bool run_hlo_passes) = 0;

// Same as above, except it takes buffer assignment as input.
// Note: The default implementation of the API here does not utilize the given
// buffer assignment. A derived runner interface is expected to override the
// following method to achieve this functionality.
virtual absl::StatusOr<std::unique_ptr<Executable>>
virtual absl::StatusOr<std::unique_ptr<HloRunnerExecutableHandle>>
CreateExecutableWithBufferAssignment(
std::unique_ptr<HloModule> module,
const BufferAssignmentProto* /*buffer_assignment_proto*/,
Expand Down Expand Up @@ -226,16 +269,18 @@ class HloRunnerInterface {

// Same as 3 Execute methods above, but with an executable handle (as returned
// by CreateExecutable) as input.
absl::StatusOr<Literal> ExecuteWithExecutable(
Executable* executable, absl::Span<const Literal> arguments,
ExecutionProfile* profile = nullptr);
HloRunnerExecutableHandle* executable,
absl::Span<const Literal> arguments, ExecutionProfile* profile = nullptr);

absl::StatusOr<Literal> ExecuteWithExecutable(
Executable* executable, absl::Span<const Literal* const> arguments) {
HloRunnerExecutableHandle* executable,
absl::Span<const Literal* const> arguments) {
return ExecuteWithExecutable(executable, arguments, nullptr);
}

virtual absl::StatusOr<Literal> ExecuteWithExecutable(
Executable* executable, absl::Span<const Literal* const> arguments,
HloRunnerExecutableHandle* executable,
absl::Span<const Literal* const> arguments,
ExecutionProfile* profile) = 0;

// Executes a given HLO module into a set of replicas, and returns a map
Expand All @@ -253,7 +298,7 @@ class HloRunnerInterface {
DeviceAssignment* device_assignment) = 0;

virtual absl::StatusOr<std::vector<Literal>> ExecuteReplicated(
std::function<Executable*(int64_t)> executable_provider,
std::function<HloRunnerExecutableHandle*(int64_t)> executable_provider,
std::function<int64_t(int64_t)> argument_count_provider,
std::function<const Literal*(int64_t, int64_t)> argument_provider,
const ReplicatedExecuteOptions& options,
Expand Down
Loading

0 comments on commit 356965b

Please sign in to comment.