Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce OpaqueExecutable and conversion functions. #22057

Merged
merged 1 commit into from
Feb 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion xla/service/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -4640,7 +4640,6 @@ cc_library(
hdrs = ["hlo_runner_interface.h"],
deps = [
":computation_placer",
":executable",
":hlo_module_config",
"//xla:literal",
"//xla:shape_util",
Expand All @@ -4653,6 +4652,8 @@ cc_library(
"//xla/tsl/platform:statusor",
"@com_google_absl//absl/base:nullability",
"@com_google_absl//absl/log",
"@com_google_absl//absl/log:die_if_null",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:string_view",
"@com_google_absl//absl/types:span",
Expand Down
103 changes: 80 additions & 23 deletions xla/service/hlo_runner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,33 @@ limitations under the License.

namespace xla {

namespace {
class HloRunnerExecutable : public OpaqueExecutable {
public:
HloRunnerExecutable(absl::Nonnull<const HloRunner*> creator,
std::unique_ptr<Executable> executable)
: OpaqueExecutable(creator), executable_(std::move(executable)) {}

Executable* executable() const { return executable_.get(); }
std::unique_ptr<Executable> MoveExecutable() {
return std::move(executable_);
}

static absl::StatusOr<HloRunnerExecutable*> TryUnwrap(
const HloRunner& runner, absl::Nonnull<OpaqueExecutable*> const wrapped) {
return OpaqueExecutable::TryUnwrap<HloRunnerExecutable>(runner, wrapped);
}
static absl::StatusOr<const HloRunnerExecutable*> TryUnwrap(
const HloRunner& runner,
absl::Nonnull<const OpaqueExecutable*> const wrapped) {
return OpaqueExecutable::TryUnwrap<HloRunnerExecutable>(runner, wrapped);
}

private:
std::unique_ptr<Executable> executable_;
};
} // namespace

HloRunner::HloRunner(se::Platform* platform, int intra_op_parallelism_threads) {
BackendOptions backend_options;
backend_options.set_platform(platform);
Expand Down Expand Up @@ -216,13 +243,15 @@ absl::StatusOr<Literal> HloRunner::ExecuteWithBufferAssignment(
absl::StatusOr<Literal> HloRunner::ExecuteWithExecutable(
OpaqueExecutable* executable, absl::Span<const Literal* const> arguments,
ExecutionProfile* profile) {
entry_computation_layout_ =
&(executable->module().entry_computation_layout());
TF_ASSIGN_OR_RETURN(HloRunnerExecutable* const hlo_runner_executable,
HloRunnerExecutable::TryUnwrap(*this, executable));
entry_computation_layout_ = &(
hlo_runner_executable->executable()->module().entry_computation_layout());
TF_ASSIGN_OR_RETURN(std::vector<ScopedShapedBuffer> argument_buffers,
TransferLiteralsToDevice(arguments));
TF_ASSIGN_OR_RETURN(ExecutionOutput result,
ExecuteWithDeviceBuffers(
/*executable=*/executable,
/*executable=*/hlo_runner_executable,
/*arguments=*/argument_buffers,
/*profile=*/profile));
return TransferLiteralFromDevice(result.Result());
Expand Down Expand Up @@ -316,19 +345,26 @@ absl::StatusOr<ExecutionOutput> HloRunner::ExecuteWithDeviceBuffers(
ExecutionProfile* profile) {
TF_ASSIGN_OR_RETURN(std::unique_ptr<OpaqueExecutable> executable,
CreateExecutable(std::move(module), run_hlo_passes));
return ExecuteWithDeviceBuffers(executable.get(), arguments, profile);
TF_ASSIGN_OR_RETURN(HloRunnerExecutable* const hlo_runner_executable,
HloRunnerExecutable::TryUnwrap(*this, executable.get()));
return ExecuteWithDeviceBuffers(hlo_runner_executable, arguments, profile);
}

absl::StatusOr<ExecutionOutput> HloRunner::ExecuteWithDeviceBuffers(
OpaqueExecutable* executable,
absl::Span<ScopedShapedBuffer const> arguments, ExecutionProfile* profile) {
TF_ASSIGN_OR_RETURN(HloRunnerExecutable* const hlo_runner_executable,
HloRunnerExecutable::TryUnwrap(*this, executable));
std::vector<ExecutionInput> execution_arguments =
ExecutionInputsFromScopedShapedBuffers(
arguments, executable->module().input_output_alias_config(),
arguments,
hlo_runner_executable->executable()
->module()
.input_output_alias_config(),
backend().default_stream_executor()->device_ordinal(),
GetAllocator());
return ExecuteWithExecutionInputs(executable, std::move(execution_arguments),
profile);
return ExecuteWithExecutionInputs(hlo_runner_executable->executable(),
std::move(execution_arguments), profile);
}

absl::StatusOr<ExecutionOutput> HloRunner::ExecuteWithMovedDeviceBuffers(
Expand All @@ -350,8 +386,10 @@ HloRunner::ExecuteWithMovedDeviceBuffersAndBufferAssignment(
std::unique_ptr<OpaqueExecutable> executable,
CreateExecutableWithBufferAssignment(
std::move(module), buffer_assignment_proto, run_hlo_passes));
return ExecuteWithMovedDeviceBuffers(executable.get(), std::move(arguments),
profile);
TF_ASSIGN_OR_RETURN(HloRunnerExecutable* const hlo_runner_executable,
HloRunnerExecutable::TryUnwrap(*this, executable.get()));
return ExecuteWithMovedDeviceBuffers(hlo_runner_executable->executable(),
std::move(arguments), profile);
}

absl::StatusOr<ExecutionOutput> HloRunner::ExecuteWithMovedDeviceBuffers(
Expand Down Expand Up @@ -559,8 +597,11 @@ absl::StatusOr<std::vector<Literal>> HloRunner::ExecuteReplicatedImpl(
absl::StatusOr<std::vector<Literal>> HloRunner::ExecuteReplicated(
OpaqueExecutable* executable, const ReplicatedExecuteOptions& options,
DeviceAssignment* device_assignment, ExecutionProfile* profile) {
TF_ASSIGN_OR_RETURN(HloRunnerExecutable* const wrapped_executable,
HloRunnerExecutable::TryUnwrap(*this, executable));
return ExecuteReplicatedImpl(
[&](const std::vector<ServiceExecutableRunOptions>& service_run_options,
[&, executable = wrapped_executable->executable()](
const std::vector<ServiceExecutableRunOptions>& service_run_options,
const std::vector<absl::Span<const ShapedBuffer* const>>&
argument_buffer_slices)
-> absl::StatusOr<std::vector<ScopedShapedBuffer>> {
Expand Down Expand Up @@ -637,8 +678,11 @@ absl::StatusOr<std::vector<Literal>> HloRunner::ExecuteReplicated(
for (const auto& arg : argument_buffer_slices[i]) {
TF_RET_CHECK(arg != nullptr);
}
pool.Schedule([&, i] {
auto result = executable_provider(i)->ExecuteOnStream(
TF_ASSIGN_OR_RETURN(
HloRunnerExecutable* const executable,
HloRunnerExecutable::TryUnwrap(*this, executable_provider(i)));
pool.Schedule([&, i, executable] {
auto result = executable->executable()->ExecuteOnStream(
&service_run_options[i], argument_buffer_slices[i]);
absl::MutexLock lock(&mutex);
thread_results[i] = std::move(result);
Expand Down Expand Up @@ -697,11 +741,16 @@ HloRunner::CreateExecutableWithBufferAssignment(
backend().compiler()->Compile(std::move(module_group),
{{backend().default_stream_executor()}},
backend().memory_allocator()));
return std::move(executables[0]);
return std::make_unique<HloRunnerExecutable>(this,
std::move(executables[0]));
}
return backend().compiler()->RunBackendWithBufferAssignment(
std::move(module), buffer_assignment_proto,
backend().default_stream_executor(), backend().memory_allocator());

TF_ASSIGN_OR_RETURN(
std::unique_ptr<Executable> executable,
backend().compiler()->RunBackendWithBufferAssignment(
std::move(module), buffer_assignment_proto,
backend().default_stream_executor(), backend().memory_allocator()));
return std::make_unique<HloRunnerExecutable>(this, std::move(executable));
}

ServiceExecutableRunOptions HloRunner::GetServiceRunOptionsForDevice(
Expand Down Expand Up @@ -754,30 +803,38 @@ bool HloRunner::HasProperty(const HloRunnerPropertyTag::Type tag) const {

absl::StatusOr<Executable*> HloRunner::ExecutableFromWrapped(
const OpaqueExecutable* wrapped) const {
return const_cast<Executable*>(wrapped);
TF_ASSIGN_OR_RETURN(const HloRunnerExecutable* const hlo_runner_executable,
HloRunnerExecutable::TryUnwrap(*this, wrapped));
return hlo_runner_executable->executable();
}

absl::StatusOr<std::unique_ptr<Executable>> HloRunner::ExecutableFromWrapped(
std::unique_ptr<OpaqueExecutable> wrapped) const {
return std::unique_ptr<Executable>(wrapped.release());
TF_ASSIGN_OR_RETURN(HloRunnerExecutable* const hlo_runner_executable,
HloRunnerExecutable::TryUnwrap(*this, wrapped.get()));
return hlo_runner_executable->MoveExecutable();
}

std::unique_ptr<OpaqueExecutable> HloRunner::WrapExecutable(
std::unique_ptr<Executable> executable) const {
return std::unique_ptr<OpaqueExecutable>(executable.release());
return std::make_unique<HloRunnerExecutable>(this, std::move(executable));
}

absl::StatusOr<absl::Nonnull<const HloModule*>> HloRunner::HloModuleFromWrapped(
const OpaqueExecutable* wrapped) const {
if (wrapped->has_module()) {
return &wrapped->module();
TF_ASSIGN_OR_RETURN(const HloRunnerExecutable* const hlo_runner_executable,
HloRunnerExecutable::TryUnwrap(*this, wrapped));
if (!hlo_runner_executable->executable()->has_module()) {
return absl::NotFoundError("Executable has no module.");
}
return absl::NotFoundError("OpaqueExecutable does not contain an HloModule.");
return &hlo_runner_executable->executable()->module();
}

absl::StatusOr<absl::Nonnull<const HloProto*>> HloRunner::HloProtoFromWrapped(
const OpaqueExecutable* wrapped) const {
return wrapped->hlo_proto();
TF_ASSIGN_OR_RETURN(const HloRunnerExecutable* const hlo_runner_executable,
HloRunnerExecutable::TryUnwrap(*this, wrapped));
return hlo_runner_executable->executable()->hlo_proto();
}

} // namespace xla
71 changes: 69 additions & 2 deletions xla/service/hlo_runner_interface.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@ limitations under the License.
#include <vector>

#include "absl/base/nullability.h"
#include "absl/log/die_if_null.h"
#include "absl/log/log.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
Expand All @@ -33,7 +35,6 @@ limitations under the License.
#include "xla/hlo/parser/hlo_parser.h"
#include "xla/literal.h"
#include "xla/service/computation_placer.h"
#include "xla/service/executable.h"
#include "xla/shape.h"
#include "xla/util.h"
#include "xla/xla_data.pb.h"
Expand Down Expand Up @@ -88,7 +89,73 @@ class HloRunnerPropertyTag final {
// created by the same runner. We use the this class to represent these
// executables when they leave the runner without exposing any details of the
// underlying implementation. See go/xla-opaque-executable for more details.
using OpaqueExecutable = Executable;
class OpaqueExecutable {
public:
virtual ~OpaqueExecutable() = default;

// !!! STOP !!!
// Before adding any methods to this class, please consider if they could be
// added to the HloRunnerInterface instead.
//
// Adding methods to this class imposes a burden on runners as they must
// implement and support any/all types used in the signature. The runner
// itself should serve as the only means of accessing information about the
// executable, since only the runner is capable of unwrapping the executable.
//
// E.g. you might be inclined to add a method to this class that returns a
// HloModule. DON'T. Not all executables may have a HloModule, while some may
// even have multiple. The runner interface has a HloModuleFromWrapped method
// that has the semantics of returning the first HloModule in the executable
// if there are multiple, or the sole HloModule if there is only one.
// !!! STOP !!!

protected:
explicit OpaqueExecutable(absl::Nonnull<const HloRunnerInterface*> creator)
: creator_(ABSL_DIE_IF_NULL(creator)) {}
// Cannot be moved or copied.
OpaqueExecutable(const OpaqueExecutable&) = default;
OpaqueExecutable& operator=(const OpaqueExecutable&) = default;

template <typename T>
static absl::StatusOr<absl::Nonnull<T*>> TryUnwrap(
const HloRunnerInterface& runner,
absl::Nonnull<OpaqueExecutable*> const wrapped) {
static_assert(
std::is_base_of_v<OpaqueExecutable, T>,
"TryUnwrap must be used with a subclass of OpaqueExecutable.");
if (wrapped->creator_ != &runner) {
return absl::InvalidArgumentError(
"Executable was not created by this runner.");
}

if (T* const executable = tensorflow::down_cast<T*>(wrapped);
executable != nullptr) {
return executable;
}
return absl::InvalidArgumentError("Invalid opaque executable.");
}

template <typename T>
static absl::StatusOr<absl::Nonnull<const T*>> TryUnwrap(
const HloRunnerInterface& runner,
absl::Nonnull<const OpaqueExecutable*> const wrapped) {
static_assert(
std::is_base_of_v<OpaqueExecutable, T>,
"TryUnwrap must be used with a subclass of OpaqueExecutable.");
if (wrapped->creator_ != &runner) {
return absl::InvalidArgumentError(
"Executable was not created by this runner.");
}

if (const T* const executable = tensorflow::down_cast<const T*>(wrapped);
executable != nullptr) {
return executable;
}
return absl::InvalidArgumentError("Invalid opaque executable.");
}

const HloRunnerInterface* const creator_;
};

// A base class for running an HloModule. This executes the given HloModule on a
// certain backend directly without using the client interface. HloModule can be
Expand Down
Loading
Loading