Remove the is_last_transfer parameter
PiperOrigin-RevId: 722035474
Google-ML-Automation committed Feb 6, 2025
1 parent 19dd926 commit 317a3c9
Showing 8 changed files with 40 additions and 53 deletions.
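The caller-visible effect of this change is sketched below: previously the final TransferRawDataToSubBuffer call carried /*is_last_transfer=*/true, while after this commit callers signal completion with the new MarkBufferCompletion method. This is a minimal, illustrative sketch only — Chunk, TransferChunks, the empty on_done callback, and the assumption that the interface is the nested xla::PjRtClient::AsyncHostToDeviceTransferManager shown in the xla/pjrt/pjrt_client.h hunk below are not part of this commit.

#include <cstdint>
#include <vector>

#include "absl/status/status.h"
#include "xla/pjrt/pjrt_client.h"

// Illustrative chunk description; not part of the PJRT API.
struct Chunk {
  const void* data;  // host data, already in on-device layout
  int64_t offset;    // destination offset within the device buffer
  int64_t size;      // number of bytes to transfer
};

absl::Status TransferChunks(
    xla::PjRtClient::AsyncHostToDeviceTransferManager* manager,
    int buffer_index, const std::vector<Chunk>& chunks) {
  for (const Chunk& c : chunks) {
    // Before this commit the last iteration passed /*is_last_transfer=*/true.
    // The caller must keep c.data alive until its on_done runs (empty here
    // for brevity).
    absl::Status status = manager->TransferRawDataToSubBuffer(
        buffer_index, c.data, c.offset, c.size, /*on_done=*/[] {});
    if (!status.ok()) return status;
  }
  // After this commit, completion is signalled explicitly once every
  // transfer for the buffer has been issued.
  manager->MarkBufferCompletion(buffer_index);
  return absl::OkStatus();
}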
2 changes: 0 additions & 2 deletions xla/pjrt/BUILD
@@ -868,8 +868,6 @@ xla_cc_test(
"@com_google_googletest//:gtest_main",
"@llvm-project//mlir:IR",
"@stablehlo//:version",
"@tsl//tsl/platform:statusor",
"@tsl//tsl/platform:test",
],
)

1 change: 0 additions & 1 deletion xla/pjrt/c/pjrt_c_api.h
@@ -674,7 +674,6 @@ struct PJRT_AsyncHostToDeviceTransferManager_TransferData_Args {
const void* data;
int64_t offset;
int64_t transfer_size;
bool is_last_transfer;
PJRT_Event* done_with_h2d_transfer; // out
};
PJRT_DEFINE_STRUCT_TRAITS(
2 changes: 1 addition & 1 deletion xla/pjrt/c/pjrt_c_api_wrapper_impl.cc
@@ -662,7 +662,7 @@ PJRT_Error* PJRT_AsyncHostToDeviceTransferManager_TransferData(
PJRT_RETURN_IF_ERROR(
args->transfer_manager->transfer_manager->TransferRawDataToSubBuffer(
args->buffer_index, args->data, args->offset, args->transfer_size,
args->is_last_transfer, std::move(on_done_with_d2h_transfer)));
std::move(on_done_with_d2h_transfer)));
args->done_with_h2d_transfer =
new PJRT_Event{xla::PjRtFuture<>(std::move(promise))};
return nullptr;
24 changes: 13 additions & 11 deletions xla/pjrt/cpu/abstract_tfrt_cpu_buffer.cc
@@ -959,37 +959,36 @@ AbstractAsyncHostToHostMemoryTransferManager::TransferLiteralToBuffer(
[literal](void* b, int64_t size) {
PackOrCopy(literal.shape().element_type(), literal, b, size);
},
/*is_last_transfer=*/true, std::move(on_done));
std::move(on_done));
}

absl::Status
AbstractAsyncHostToHostMemoryTransferManager::TransferRawDataToBuffer(
int buffer_index, absl::string_view data,
absl::AnyInvocable<void() &&> on_done) {
return TransferRawDataToSubBuffer(
buffer_index, data.data(), /*offset=*/0, data.size(),
/*is_last_transfer=*/true, std::move(on_done));
return TransferRawDataToSubBuffer(buffer_index, data.data(), /*offset=*/0,
data.size(), std::move(on_done));
}

// The definition events of `device_buffers_` must be ready before calling this
// function.
absl::Status
AbstractAsyncHostToHostMemoryTransferManager::TransferRawDataToSubBuffer(
int buffer_index, const void* data, int64_t offset, int64_t transfer_size,
bool is_last_transfer, absl::AnyInvocable<void() &&> on_done) {
absl::AnyInvocable<void() &&> on_done) {
return FillRawDataToSubBuffer(
buffer_index,
[offset, data, transfer_size](void* b, int64_t size) {
std::memcpy(reinterpret_cast<char*>(b) + offset, data, transfer_size);
},
is_last_transfer, std::move(on_done));
std::move(on_done));
}

absl::Status
AbstractAsyncHostToHostMemoryTransferManager::FillRawDataToSubBuffer(
int buffer_index,
absl::AnyInvocable<void(void* data, int64_t size)> fill_fn,
bool is_last_transfer, absl::AnyInvocable<void() &&> on_done) {
absl::AnyInvocable<void() &&> on_done) {
{
// We release the lock when out of scope because
// `async_work_runner_->Schedule` might sometimes run the closure in this
@@ -1005,17 +1004,14 @@ AbstractAsyncHostToHostMemoryTransferManager::FillRawDataToSubBuffer(

CHECK(async_work_runner_ != nullptr);
async_work_runner_->Schedule([this, fill_fn = std::move(fill_fn),
is_last_transfer, on_done = std::move(on_done),
on_done = std::move(on_done),
buffer_index]() mutable -> void {
tsl::RCReference<tsl::AsyncValue> event;
{
absl::MutexLock l(&mu_);
const auto& b = device_buffers_[buffer_index]->Buffers()[0];
CHECK(b.IsConcrete());
fill_fn(reinterpret_cast<char*>(b->data()), b->size());
if (is_last_transfer) {
last_transfer_finished_[buffer_index] = true;
}
--buffer_transfers_in_flight_[buffer_index];
--transfers_in_flight_;
if (buffer_transfers_in_flight_[buffer_index] == 0 &&
@@ -1033,6 +1029,12 @@ AbstractAsyncHostToHostMemoryTransferManager::FillRawDataToSubBuffer(
return absl::OkStatus();
}

void AbstractAsyncHostToHostMemoryTransferManager::MarkBufferCompletion(
int buffer_index) {
absl::MutexLock l(&mu_);
last_transfer_finished_[buffer_index] = true;
}

void AbstractAsyncHostToHostMemoryTransferManager::SetBufferError(
int buffer_index, absl::Status error) {
absl::MutexLock l(&mu_);
6 changes: 4 additions & 2 deletions xla/pjrt/cpu/abstract_tfrt_cpu_buffer.h
@@ -342,7 +342,9 @@ class AbstractAsyncHostToHostMemoryTransferManager

absl::Status TransferRawDataToSubBuffer(
int buffer_index, const void* data, int64_t offset, int64_t transfer_size,
bool is_last_transfer, absl::AnyInvocable<void() &&> on_done) override;
absl::AnyInvocable<void() &&> on_done) override;

void MarkBufferCompletion(int buffer_index) override;

void SetBufferError(int buffer_index, absl::Status error) override;

@@ -373,7 +375,7 @@ class AbstractAsyncHostToHostMemoryTransferManager
absl::Status FillRawDataToSubBuffer(
int buffer_index,
absl::AnyInvocable<void(void* data, int64_t size)> fill_fn,
bool is_last_transfer, absl::AnyInvocable<void() &&> on_done);
absl::AnyInvocable<void() &&> on_done);

mutable absl::Mutex mu_;
// The number of transfers that are currently in flight.
34 changes: 10 additions & 24 deletions xla/pjrt/gpu/se_gpu_pjrt_client.cc
@@ -304,8 +304,7 @@ class AsyncHostToDeviceTransferManager
// Call cleanup once the transfer has finished on the stream.
auto cleanup = [this, buffer_index, stream, on_done = std::move(on_done),
event = std::move(event).value()]() mutable {
CleanUp(buffer_index, std::move(event), stream,
/*is_last_transfer=*/true, std::move(on_done));
CleanUp(buffer_index, std::move(event), stream, std::move(on_done));
};
auto status = stream->DoHostCallback(std::move(cleanup));
if (!status.ok()) {
@@ -325,13 +324,12 @@ class AsyncHostToDeviceTransferManager
absl::AnyInvocable<void() &&> on_done) override {
return TransferRawDataToSubBuffer(buffer_index, data.data(),
/*offset=*/0, data.size(),
/*is_last_transfer=*/true,
std::move(on_done));
}

absl::Status TransferRawDataToSubBuffer(
int buffer_index, const void* data, int64_t offset, int64_t transfer_size,
bool is_last_transfer, absl::AnyInvocable<void() &&> on_done) override {
absl::AnyInvocable<void() &&> on_done) override {
auto* stream = device_->local_device_state()->host_to_device_stream();

auto* client =
@@ -363,9 +361,6 @@ class AsyncHostToDeviceTransferManager
"already been fully transferred",
buffer_index);
}
if (is_last_transfer) {
last_transfer_started_[buffer_index] = true;
}
DCHECK(buffer_ptrs_[buffer_index]);
if (buffer_ptrs_[buffer_index]->device_memory().empty()) {
return InvalidArgument(
@@ -418,14 +413,18 @@ class AsyncHostToDeviceTransferManager
event.value());

auto cleanup = [this, buffer_index, event = std::move(event).value(),
stream, is_last_transfer, on_done = std::move(on_done),
stream, on_done = std::move(on_done),
staging_buffer = std::move(staging_buffer)]() mutable {
CleanUp(buffer_index, std::move(event), stream, is_last_transfer,
std::move(on_done));
CleanUp(buffer_index, std::move(event), stream, std::move(on_done));
};
return stream->DoHostCallback(std::move(cleanup));
}

void MarkBufferCompletion(int buffer_index) override {
absl::MutexLock l(&mu_);
last_transfer_started_[buffer_index] = true;
}

void SetBufferError(int buffer_index, absl::Status error) override {
{
absl::MutexLock l(&mu_);
@@ -471,24 +470,11 @@ class AsyncHostToDeviceTransferManager
PjRtStreamExecutorDevice* device_; // not owned.

void CleanUp(int buffer_index, EventPool::Handle event, se::Stream* stream,
bool is_last_transfer, absl::AnyInvocable<void() &&> on_done) {
absl::AnyInvocable<void() &&> on_done) {
{
absl::MutexLock l(&mu_);

CHECK_GT(transfers_in_flight_, 0);
--transfers_in_flight_;
if (is_last_transfer) {
// Drop our reference to the TrackedDeviceBuffer for this buffer.
CHECK(buffer_ptrs_[buffer_index]);
buffer_ptrs_[buffer_index] = nullptr;
CHECK_GT(remaining_buffer_count_, 0);
--remaining_buffer_count_;
definition_events_[buffer_index]->SetSequencingEvent(std::move(event),
stream);
if (remaining_buffer_count_ == 0) {
VLOG(1) << "TransferLiteralToBuffer for all buffers is done.";
}
}
}

// Call on_done after finishing all housekeeping and releasing the lock.
6 changes: 3 additions & 3 deletions xla/pjrt/pjrt_c_api_client.cc
@@ -752,13 +752,12 @@ class PjRtCApiAsyncHostToDeviceTransferManager
int buffer_index, absl::string_view data,
absl::AnyInvocable<void() &&> on_done) override {
return TransferRawDataToSubBuffer(buffer_index, data.data(), 0, data.size(),
/*is_last_transfer=*/true,
std::move(on_done));
}

absl::Status TransferRawDataToSubBuffer(
int buffer_index, const void* data, int64_t offset, int64_t transfer_size,
bool is_last_transfer, absl::AnyInvocable<void() &&> on_done) override {
absl::AnyInvocable<void() &&> on_done) override {
PJRT_AsyncHostToDeviceTransferManager_TransferData_Args args;
args.struct_size =
PJRT_AsyncHostToDeviceTransferManager_TransferData_Args_STRUCT_SIZE;
@@ -768,7 +767,6 @@ class PjRtCApiAsyncHostToDeviceTransferManager
args.data = data;
args.offset = offset;
args.transfer_size = transfer_size;
args.is_last_transfer = is_last_transfer;
const PJRT_Api* api = c_client_->pjrt_c_api();
RETURN_STATUS_IF_PJRT_ERROR(
api->PJRT_AsyncHostToDeviceTransferManager_TransferData(&args), api);
@@ -780,6 +778,8 @@ class PjRtCApiAsyncHostToDeviceTransferManager
return absl::OkStatus();
}

void MarkBufferCompletion(int buffer_index) override {}

void SetBufferError(int buffer_index, absl::Status error) override {
PJRT_AsyncHostToDeviceTransferManager_SetBufferError_Args args;
args.struct_size =
18 changes: 9 additions & 9 deletions xla/pjrt/pjrt_client.h
@@ -732,17 +732,17 @@ class PjRtClient {
// Transfers 'data' into a sub-buffer of buffer_index starting at offset, of
// length transfer_size. 'data' must be already laid out in the correct
// on-device format, for example returned by a call to
// buffer->CopyRawToHost. If is_last_transfer is false then the buffer
// remains unavailable to consumers after the transfer completes. If
// is_last_transfer is true then the buffer becomes available to consumers
// after the transfer completes, and no transfer calls (or SetBufferError
// calls) into buffer_index can be made after this call. on_done is called
// when the transfer is complete but before the buffers are made available
// to their consumers. 'data' must remain in scope until on_done is called.
// buffer->CopyRawToHost. on_done is called when the transfer is complete
// but before the buffers are made available to their consumers.
// 'data' must remain in scope until on_done is called.
virtual absl::Status TransferRawDataToSubBuffer(
int buffer_index, const void* data, int64_t offset,
int64_t transfer_size, bool is_last_transfer,
absl::AnyInvocable<void() &&> on_done) = 0;
int64_t transfer_size, absl::AnyInvocable<void() &&> on_done) = 0;

// Indicates that data transfer for the buffer `buffer_index` is complete,
// thus the buffer becomes available to consumers. No transfer calls (or
// SetBufferError calls) into `buffer_index` can be made after this call.
virtual void MarkBufferCompletion(int buffer_index) = 0;

// Indicates that a specific buffer should result in an error status. No
// transfer calls (or further SetBufferError calls) into buffer_index can
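A second sketch, illustrating the lifetime rule stated in the comment above ('data' must remain in scope until on_done is called) together with the new completion call. TransferOwned, the shared_ptr-owned host source, and the use of the nested PjRtClient::AsyncHostToDeviceTransferManager interface are assumptions for illustration; only the TransferRawDataToSubBuffer and MarkBufferCompletion declarations come from this header.

#include <cstdint>
#include <memory>
#include <utility>
#include <vector>

#include "absl/status/status.h"
#include "xla/pjrt/pjrt_client.h"

absl::Status TransferOwned(
    xla::PjRtClient::AsyncHostToDeviceTransferManager* manager,
    int buffer_index, std::shared_ptr<std::vector<uint8_t>> host_data,
    int64_t offset) {
  const void* data = host_data->data();
  const auto size = static_cast<int64_t>(host_data->size());
  // Capturing host_data in on_done keeps the host source alive until the
  // transfer manager reports that it no longer needs it.
  absl::Status status = manager->TransferRawDataToSubBuffer(
      buffer_index, data, offset, size,
      /*on_done=*/[host_data = std::move(host_data)]() mutable {
        host_data.reset();  // safe to release the host copy now
      });
  if (!status.ok()) return status;
  // No further transfers into buffer_index may follow this call.
  manager->MarkBufferCompletion(buffer_index);
  return absl::OkStatus();
}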
