Skip to content

Commit

Permalink
Remove the is_last_transfer parameter
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 722035474
  • Loading branch information
Google-ML-Automation committed Feb 1, 2025
1 parent cbfada0 commit e990769
Show file tree
Hide file tree
Showing 7 changed files with 30 additions and 29 deletions.
2 changes: 0 additions & 2 deletions xla/pjrt/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -890,8 +890,6 @@ xla_cc_test(
"@com_google_googletest//:gtest_main",
"@llvm-project//mlir:IR",
"@stablehlo//:version",
"@tsl//tsl/platform:statusor",
"@tsl//tsl/platform:test",
],
)

Expand Down
1 change: 0 additions & 1 deletion xla/pjrt/c/pjrt_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -674,7 +674,6 @@ struct PJRT_AsyncHostToDeviceTransferManager_TransferData_Args {
const void* data;
int64_t offset;
int64_t transfer_size;
bool is_last_transfer;
PJRT_Event* done_with_h2d_transfer; // out
};
PJRT_DEFINE_STRUCT_TRAITS(
Expand Down
2 changes: 1 addition & 1 deletion xla/pjrt/c/pjrt_c_api_wrapper_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -662,7 +662,7 @@ PJRT_Error* PJRT_AsyncHostToDeviceTransferManager_TransferData(
PJRT_RETURN_IF_ERROR(
args->transfer_manager->transfer_manager->TransferRawDataToSubBuffer(
args->buffer_index, args->data, args->offset, args->transfer_size,
args->is_last_transfer, std::move(on_done_with_d2h_transfer)));
std::move(on_done_with_d2h_transfer)));
args->done_with_h2d_transfer =
new PJRT_Event{xla::PjRtFuture<>(std::move(promise))};
return nullptr;
Expand Down
24 changes: 13 additions & 11 deletions xla/pjrt/cpu/abstract_tfrt_cpu_buffer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -959,37 +959,36 @@ AbstractAsyncHostToHostMemoryTransferManager::TransferLiteralToBuffer(
[literal](void* b, int64_t size) {
PackOrCopy(literal.shape().element_type(), literal, b, size);
},
/*is_last_transfer=*/true, std::move(on_done));
std::move(on_done));
}

absl::Status
AbstractAsyncHostToHostMemoryTransferManager::TransferRawDataToBuffer(
int buffer_index, absl::string_view data,
absl::AnyInvocable<void() &&> on_done) {
return TransferRawDataToSubBuffer(
buffer_index, data.data(), /*offset=*/0, data.size(),
/*is_last_transfer=*/true, std::move(on_done));
return TransferRawDataToSubBuffer(buffer_index, data.data(), /*offset=*/0,
data.size(), std::move(on_done));
}

// The definition events of `device_buffers_` must be ready before calling this
// function.
absl::Status
AbstractAsyncHostToHostMemoryTransferManager::TransferRawDataToSubBuffer(
int buffer_index, const void* data, int64_t offset, int64_t transfer_size,
bool is_last_transfer, absl::AnyInvocable<void() &&> on_done) {
absl::AnyInvocable<void() &&> on_done) {
return FillRawDataToSubBuffer(
buffer_index,
[offset, data, transfer_size](void* b, int64_t size) {
std::memcpy(reinterpret_cast<char*>(b) + offset, data, transfer_size);
},
is_last_transfer, std::move(on_done));
std::move(on_done));
}

absl::Status
AbstractAsyncHostToHostMemoryTransferManager::FillRawDataToSubBuffer(
int buffer_index,
absl::AnyInvocable<void(void* data, int64_t size)> fill_fn,
bool is_last_transfer, absl::AnyInvocable<void() &&> on_done) {
absl::AnyInvocable<void() &&> on_done) {
{
// We release the lock when out of scope because
// `async_work_runner_->Schedule` might sometimes run the closure in this
Expand All @@ -1005,17 +1004,14 @@ AbstractAsyncHostToHostMemoryTransferManager::FillRawDataToSubBuffer(

CHECK(async_work_runner_ != nullptr);
async_work_runner_->Schedule([this, fill_fn = std::move(fill_fn),
is_last_transfer, on_done = std::move(on_done),
on_done = std::move(on_done),
buffer_index]() mutable -> void {
tsl::RCReference<tsl::AsyncValue> event;
{
absl::MutexLock l(&mu_);
const auto& b = device_buffers_[buffer_index]->Buffers()[0];
CHECK(b.IsConcrete());
fill_fn(reinterpret_cast<char*>(b->data()), b->size());
if (is_last_transfer) {
last_transfer_finished_[buffer_index] = true;
}
--buffer_transfers_in_flight_[buffer_index];
--transfers_in_flight_;
if (buffer_transfers_in_flight_[buffer_index] == 0 &&
Expand All @@ -1033,6 +1029,12 @@ AbstractAsyncHostToHostMemoryTransferManager::FillRawDataToSubBuffer(
return absl::OkStatus();
}

// Marks data transfer for the buffer at `buffer_index` as complete by setting
// its entry in `last_transfer_finished_`. The update is made while holding
// `mu_`.
//
// NOTE(review): this function only sets the flag. The closure scheduled by
// FillRawDataToSubBuffer is what decrements and tests
// `buffer_transfers_in_flight_[buffer_index]`. If this can be called after the
// final in-flight transfer for the buffer has already run, confirm that
// something still makes the buffer available to its consumers.
void AbstractAsyncHostToHostMemoryTransferManager::MarkBufferCompletion(
int buffer_index) {
absl::MutexLock l(&mu_);
last_transfer_finished_[buffer_index] = true;
}

void AbstractAsyncHostToHostMemoryTransferManager::SetBufferError(
int buffer_index, absl::Status error) {
absl::MutexLock l(&mu_);
Expand Down
6 changes: 4 additions & 2 deletions xla/pjrt/cpu/abstract_tfrt_cpu_buffer.h
Original file line number Diff line number Diff line change
Expand Up @@ -342,7 +342,9 @@ class AbstractAsyncHostToHostMemoryTransferManager

absl::Status TransferRawDataToSubBuffer(
int buffer_index, const void* data, int64_t offset, int64_t transfer_size,
bool is_last_transfer, absl::AnyInvocable<void() &&> on_done) override;
absl::AnyInvocable<void() &&> on_done) override;

void MarkBufferCompletion(int buffer_index) override;

void SetBufferError(int buffer_index, absl::Status error) override;

Expand Down Expand Up @@ -373,7 +375,7 @@ class AbstractAsyncHostToHostMemoryTransferManager
absl::Status FillRawDataToSubBuffer(
int buffer_index,
absl::AnyInvocable<void(void* data, int64_t size)> fill_fn,
bool is_last_transfer, absl::AnyInvocable<void() &&> on_done);
absl::AnyInvocable<void() &&> on_done);

mutable absl::Mutex mu_;
// The number of transfers that are currently in flight.
Expand Down
6 changes: 3 additions & 3 deletions xla/pjrt/pjrt_c_api_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -776,13 +776,12 @@ class PjRtCApiAsyncHostToDeviceTransferManager
int buffer_index, absl::string_view data,
absl::AnyInvocable<void() &&> on_done) override {
return TransferRawDataToSubBuffer(buffer_index, data.data(), 0, data.size(),
/*is_last_transfer=*/true,
std::move(on_done));
}

absl::Status TransferRawDataToSubBuffer(
int buffer_index, const void* data, int64_t offset, int64_t transfer_size,
bool is_last_transfer, absl::AnyInvocable<void() &&> on_done) override {
absl::AnyInvocable<void() &&> on_done) override {
PJRT_AsyncHostToDeviceTransferManager_TransferData_Args args;
args.struct_size =
PJRT_AsyncHostToDeviceTransferManager_TransferData_Args_STRUCT_SIZE;
Expand All @@ -792,7 +791,6 @@ class PjRtCApiAsyncHostToDeviceTransferManager
args.data = data;
args.offset = offset;
args.transfer_size = transfer_size;
args.is_last_transfer = is_last_transfer;
const PJRT_Api* api = c_client_->pjrt_c_api();
RETURN_STATUS_IF_PJRT_ERROR(
api->PJRT_AsyncHostToDeviceTransferManager_TransferData(&args), api);
Expand All @@ -804,6 +802,8 @@ class PjRtCApiAsyncHostToDeviceTransferManager
return absl::OkStatus();
}

// No-op: the body is empty, so `buffer_index` is unused and no PJRT C API
// function is invoked.
// NOTE(review): TransferRawDataToSubBuffer in this class does not pass any
// "last transfer" indication through the C API args either. Confirm how the
// plugin side is expected to learn that transfers into a buffer are complete.
void MarkBufferCompletion(int buffer_index) override {}

void SetBufferError(int buffer_index, absl::Status error) override {
PJRT_AsyncHostToDeviceTransferManager_SetBufferError_Args args;
args.struct_size =
Expand Down
18 changes: 9 additions & 9 deletions xla/pjrt/pjrt_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -732,17 +732,17 @@ class PjRtClient {
// Transfers 'data' into a sub-buffer of buffer_index starting at offset, of
// length transfer_size. 'data' must be already laid out in the correct
// on-device format, for example returned by a call to
// buffer->CopyRawToHost. If is_last_transfer is false then the buffer
// remains unavailable to consumers after the transfer completes. If
// is_last_transfer is true then the buffer becomes available to consumers
// after the transfer completes, and no transfer calls (or SetBufferError
// calls) into buffer_index can be made after this call. on_done is called
// when the transfer is complete but before the buffers are made available
// to their consumers. 'data' must remain in scope until on_done is called.
// buffer->CopyRawToHost. on_done is called when the transfer is complete
// but before the buffers are made available to their consumers.
// 'data' must remain in scope until on_done is called.
virtual absl::Status TransferRawDataToSubBuffer(
int buffer_index, const void* data, int64_t offset,
int64_t transfer_size, bool is_last_transfer,
absl::AnyInvocable<void() &&> on_done) = 0;
int64_t transfer_size, absl::AnyInvocable<void() &&> on_done) = 0;

// Indicates that data transfer into the buffer `buffer_index` is complete, so
// the buffer becomes available to consumers. After this call, no further
// transfer calls (or SetBufferError calls) may be made for `buffer_index`.
virtual void MarkBufferCompletion(int buffer_index) = 0;

// Indicates that a specific buffer should result in an error status. No
// transfer calls (or further SetBufferError calls) into buffer_index can
Expand Down

0 comments on commit e990769

Please sign in to comment.