diff --git a/xla/pjrt/BUILD b/xla/pjrt/BUILD index 5502da86e87ab5..e9efdf003e731c 100644 --- a/xla/pjrt/BUILD +++ b/xla/pjrt/BUILD @@ -890,8 +890,6 @@ xla_cc_test( "@com_google_googletest//:gtest_main", "@llvm-project//mlir:IR", "@stablehlo//:version", - "@tsl//tsl/platform:statusor", - "@tsl//tsl/platform:test", ], ) diff --git a/xla/pjrt/c/pjrt_c_api.h b/xla/pjrt/c/pjrt_c_api.h index 669ba140945886..947c9b40444559 100644 --- a/xla/pjrt/c/pjrt_c_api.h +++ b/xla/pjrt/c/pjrt_c_api.h @@ -674,7 +674,6 @@ struct PJRT_AsyncHostToDeviceTransferManager_TransferData_Args { const void* data; int64_t offset; int64_t transfer_size; - bool is_last_transfer; PJRT_Event* done_with_h2d_transfer; // out }; PJRT_DEFINE_STRUCT_TRAITS( diff --git a/xla/pjrt/c/pjrt_c_api_wrapper_impl.cc b/xla/pjrt/c/pjrt_c_api_wrapper_impl.cc index eafb33be48a8a1..d270a53d8a4649 100644 --- a/xla/pjrt/c/pjrt_c_api_wrapper_impl.cc +++ b/xla/pjrt/c/pjrt_c_api_wrapper_impl.cc @@ -662,7 +662,7 @@ PJRT_Error* PJRT_AsyncHostToDeviceTransferManager_TransferData( PJRT_RETURN_IF_ERROR( args->transfer_manager->transfer_manager->TransferRawDataToSubBuffer( args->buffer_index, args->data, args->offset, args->transfer_size, - args->is_last_transfer, std::move(on_done_with_d2h_transfer))); + std::move(on_done_with_d2h_transfer))); args->done_with_h2d_transfer = new PJRT_Event{xla::PjRtFuture<>(std::move(promise))}; return nullptr; diff --git a/xla/pjrt/cpu/abstract_tfrt_cpu_buffer.cc b/xla/pjrt/cpu/abstract_tfrt_cpu_buffer.cc index 5a468134f31421..78184511a4f9ab 100644 --- a/xla/pjrt/cpu/abstract_tfrt_cpu_buffer.cc +++ b/xla/pjrt/cpu/abstract_tfrt_cpu_buffer.cc @@ -959,16 +959,15 @@ AbstractAsyncHostToHostMemoryTransferManager::TransferLiteralToBuffer( [literal](void* b, int64_t size) { PackOrCopy(literal.shape().element_type(), literal, b, size); }, - /*is_last_transfer=*/true, std::move(on_done)); + std::move(on_done)); } absl::Status AbstractAsyncHostToHostMemoryTransferManager::TransferRawDataToBuffer( int buffer_index, absl::string_view data, absl::AnyInvocable on_done) { - return TransferRawDataToSubBuffer( - buffer_index, data.data(), /*offset=*/0, data.size(), - /*is_last_transfer=*/true, std::move(on_done)); + return TransferRawDataToSubBuffer(buffer_index, data.data(), /*offset=*/0, + data.size(), std::move(on_done)); } // The definition events of `device_buffers_` must be ready before calling this @@ -976,20 +975,20 @@ AbstractAsyncHostToHostMemoryTransferManager::TransferRawDataToBuffer( absl::Status AbstractAsyncHostToHostMemoryTransferManager::TransferRawDataToSubBuffer( int buffer_index, const void* data, int64_t offset, int64_t transfer_size, - bool is_last_transfer, absl::AnyInvocable on_done) { + absl::AnyInvocable on_done) { return FillRawDataToSubBuffer( buffer_index, [offset, data, transfer_size](void* b, int64_t size) { std::memcpy(reinterpret_cast(b) + offset, data, transfer_size); }, - is_last_transfer, std::move(on_done)); + std::move(on_done)); } absl::Status AbstractAsyncHostToHostMemoryTransferManager::FillRawDataToSubBuffer( int buffer_index, absl::AnyInvocable fill_fn, - bool is_last_transfer, absl::AnyInvocable on_done) { + absl::AnyInvocable on_done) { { // We release the lock when out of scope because // `async_work_runner_->Schedule` might sometimes run the closure in this @@ -1005,7 +1004,7 @@ AbstractAsyncHostToHostMemoryTransferManager::FillRawDataToSubBuffer( CHECK(async_work_runner_ != nullptr); async_work_runner_->Schedule([this, fill_fn = std::move(fill_fn), - is_last_transfer, on_done = std::move(on_done), + on_done = std::move(on_done), buffer_index]() mutable -> void { tsl::RCReference event; { @@ -1013,9 +1012,6 @@ AbstractAsyncHostToHostMemoryTransferManager::FillRawDataToSubBuffer( const auto& b = device_buffers_[buffer_index]->Buffers()[0]; CHECK(b.IsConcrete()); fill_fn(reinterpret_cast(b->data()), b->size()); - if (is_last_transfer) { - last_transfer_finished_[buffer_index] = true; - } --buffer_transfers_in_flight_[buffer_index]; --transfers_in_flight_; if (buffer_transfers_in_flight_[buffer_index] == 0 && @@ -1033,6 +1029,12 @@ AbstractAsyncHostToHostMemoryTransferManager::FillRawDataToSubBuffer( return absl::OkStatus(); } +void AbstractAsyncHostToHostMemoryTransferManager::MarkBufferCompletion( + int buffer_index) { + absl::MutexLock l(&mu_); + last_transfer_finished_[buffer_index] = true; +} + void AbstractAsyncHostToHostMemoryTransferManager::SetBufferError( int buffer_index, absl::Status error) { absl::MutexLock l(&mu_); diff --git a/xla/pjrt/cpu/abstract_tfrt_cpu_buffer.h b/xla/pjrt/cpu/abstract_tfrt_cpu_buffer.h index 3a8ddcd4729e20..74243d7dd24905 100644 --- a/xla/pjrt/cpu/abstract_tfrt_cpu_buffer.h +++ b/xla/pjrt/cpu/abstract_tfrt_cpu_buffer.h @@ -342,7 +342,9 @@ class AbstractAsyncHostToHostMemoryTransferManager absl::Status TransferRawDataToSubBuffer( int buffer_index, const void* data, int64_t offset, int64_t transfer_size, - bool is_last_transfer, absl::AnyInvocable on_done) override; + absl::AnyInvocable on_done) override; + + void MarkBufferCompletion(int buffer_index) override; void SetBufferError(int buffer_index, absl::Status error) override; @@ -373,7 +375,7 @@ class AbstractAsyncHostToHostMemoryTransferManager absl::Status FillRawDataToSubBuffer( int buffer_index, absl::AnyInvocable fill_fn, - bool is_last_transfer, absl::AnyInvocable on_done); + absl::AnyInvocable on_done); mutable absl::Mutex mu_; // The number of transfers that are currently in flight. diff --git a/xla/pjrt/pjrt_c_api_client.cc b/xla/pjrt/pjrt_c_api_client.cc index 03a5cb6eb2a49a..59fbaa72b879f5 100644 --- a/xla/pjrt/pjrt_c_api_client.cc +++ b/xla/pjrt/pjrt_c_api_client.cc @@ -776,13 +776,12 @@ class PjRtCApiAsyncHostToDeviceTransferManager int buffer_index, absl::string_view data, absl::AnyInvocable on_done) override { return TransferRawDataToSubBuffer(buffer_index, data.data(), 0, data.size(), - /*is_last_transfer=*/true, std::move(on_done)); } absl::Status TransferRawDataToSubBuffer( int buffer_index, const void* data, int64_t offset, int64_t transfer_size, - bool is_last_transfer, absl::AnyInvocable on_done) override { + absl::AnyInvocable on_done) override { PJRT_AsyncHostToDeviceTransferManager_TransferData_Args args; args.struct_size = PJRT_AsyncHostToDeviceTransferManager_TransferData_Args_STRUCT_SIZE; @@ -792,7 +791,6 @@ class PjRtCApiAsyncHostToDeviceTransferManager args.data = data; args.offset = offset; args.transfer_size = transfer_size; - args.is_last_transfer = is_last_transfer; const PJRT_Api* api = c_client_->pjrt_c_api(); RETURN_STATUS_IF_PJRT_ERROR( api->PJRT_AsyncHostToDeviceTransferManager_TransferData(&args), api); @@ -804,6 +802,8 @@ class PjRtCApiAsyncHostToDeviceTransferManager return absl::OkStatus(); } + void MarkBufferCompletion(int buffer_index) override {} + void SetBufferError(int buffer_index, absl::Status error) override { PJRT_AsyncHostToDeviceTransferManager_SetBufferError_Args args; args.struct_size = diff --git a/xla/pjrt/pjrt_client.h b/xla/pjrt/pjrt_client.h index e72c4069c13a96..07ed0aac40f5b5 100644 --- a/xla/pjrt/pjrt_client.h +++ b/xla/pjrt/pjrt_client.h @@ -732,17 +732,17 @@ class PjRtClient { // Transfers 'data' into a sub-buffer of buffer_index starting at offset, of // length transfer_size. 'data' must be already laid out in the correct // on-device format, for example returned by a call to - // buffer->CopyRawToHost. If is_last_transfer is false then the buffer - // remains unavailable to consumers after the transfer completes. If - // is_last_transfer is true then the buffer becomes available to consumers - // after the transfer completes, and no transfer calls (or SetBufferError - // calls) into buffer_index can be made after this call. on_done is called - // when the transfer is complete but before the buffers are made available - // to their consumers. 'data' must remain in scope until on_done is called. + // buffer->CopyRawToHost. on_done is called when the transfer is complete + // but before the buffers are made available to their consumers. + // 'data' must remain in scope until on_done is called. virtual absl::Status TransferRawDataToSubBuffer( int buffer_index, const void* data, int64_t offset, - int64_t transfer_size, bool is_last_transfer, - absl::AnyInvocable on_done) = 0; + int64_t transfer_size, absl::AnyInvocable on_done) = 0; + + // Indicates that data transfer for the buffer `buffer_index` is complete, + // thus the buffer becomes available to consumers. No transfer calls (or + // SetBufferError calls) into `buffer_index` can be made after this call. + virtual void MarkBufferCompletion(int buffer_index) = 0; // Indicates that a specific buffer should result in an error status. No // transfer calls (or further SetBufferError calls) into buffer_index can