diff --git a/xla/pjrt/BUILD b/xla/pjrt/BUILD
index dea7e8b3ddb9e..5c7b07b87ccae 100644
--- a/xla/pjrt/BUILD
+++ b/xla/pjrt/BUILD
@@ -868,8 +868,6 @@ xla_cc_test(
         "@com_google_googletest//:gtest_main",
         "@llvm-project//mlir:IR",
         "@stablehlo//:version",
-        "@tsl//tsl/platform:statusor",
-        "@tsl//tsl/platform:test",
     ],
 )
 
diff --git a/xla/pjrt/c/pjrt_c_api.h b/xla/pjrt/c/pjrt_c_api.h
index 669ba14094588..947c9b4044455 100644
--- a/xla/pjrt/c/pjrt_c_api.h
+++ b/xla/pjrt/c/pjrt_c_api.h
@@ -674,7 +674,6 @@ struct PJRT_AsyncHostToDeviceTransferManager_TransferData_Args {
   const void* data;
   int64_t offset;
   int64_t transfer_size;
-  bool is_last_transfer;
   PJRT_Event* done_with_h2d_transfer;  // out
 };
 PJRT_DEFINE_STRUCT_TRAITS(
diff --git a/xla/pjrt/c/pjrt_c_api_wrapper_impl.cc b/xla/pjrt/c/pjrt_c_api_wrapper_impl.cc
index 4c6b6f0925652..7f1c6fd8a230f 100644
--- a/xla/pjrt/c/pjrt_c_api_wrapper_impl.cc
+++ b/xla/pjrt/c/pjrt_c_api_wrapper_impl.cc
@@ -662,7 +662,7 @@ PJRT_Error* PJRT_AsyncHostToDeviceTransferManager_TransferData(
   PJRT_RETURN_IF_ERROR(
       args->transfer_manager->transfer_manager->TransferRawDataToSubBuffer(
           args->buffer_index, args->data, args->offset, args->transfer_size,
-          args->is_last_transfer, std::move(on_done_with_d2h_transfer)));
+          std::move(on_done_with_d2h_transfer)));
   args->done_with_h2d_transfer =
       new PJRT_Event{xla::PjRtFuture<>(std::move(promise))};
   return nullptr;
diff --git a/xla/pjrt/cpu/abstract_tfrt_cpu_buffer.cc b/xla/pjrt/cpu/abstract_tfrt_cpu_buffer.cc
index 5a468134f3142..78184511a4f9a 100644
--- a/xla/pjrt/cpu/abstract_tfrt_cpu_buffer.cc
+++ b/xla/pjrt/cpu/abstract_tfrt_cpu_buffer.cc
@@ -959,16 +959,15 @@ AbstractAsyncHostToHostMemoryTransferManager::TransferLiteralToBuffer(
       [literal](void* b, int64_t size) {
         PackOrCopy(literal.shape().element_type(), literal, b, size);
       },
-      /*is_last_transfer=*/true, std::move(on_done));
+      std::move(on_done));
 }
 
 absl::Status
 AbstractAsyncHostToHostMemoryTransferManager::TransferRawDataToBuffer(
     int buffer_index, absl::string_view data,
     absl::AnyInvocable<void() &&> on_done) {
-  return TransferRawDataToSubBuffer(
-      buffer_index, data.data(), /*offset=*/0, data.size(),
-      /*is_last_transfer=*/true, std::move(on_done));
+  return TransferRawDataToSubBuffer(buffer_index, data.data(), /*offset=*/0,
+                                    data.size(), std::move(on_done));
 }
 
 // The definition events of `device_buffers_` must be ready before calling this
@@ -976,20 +975,20 @@ AbstractAsyncHostToHostMemoryTransferManager::TransferRawDataToBuffer(
 absl::Status
 AbstractAsyncHostToHostMemoryTransferManager::TransferRawDataToSubBuffer(
     int buffer_index, const void* data, int64_t offset, int64_t transfer_size,
-    bool is_last_transfer, absl::AnyInvocable<void() &&> on_done) {
+    absl::AnyInvocable<void() &&> on_done) {
   return FillRawDataToSubBuffer(
       buffer_index,
       [offset, data, transfer_size](void* b, int64_t size) {
         std::memcpy(reinterpret_cast<char*>(b) + offset, data, transfer_size);
       },
-      is_last_transfer, std::move(on_done));
+      std::move(on_done));
 }
 
 absl::Status
 AbstractAsyncHostToHostMemoryTransferManager::FillRawDataToSubBuffer(
     int buffer_index,
     absl::AnyInvocable<void(void* data, int64_t size)> fill_fn,
-    bool is_last_transfer, absl::AnyInvocable<void() &&> on_done) {
+    absl::AnyInvocable<void() &&> on_done) {
   {
     // We release the lock when out of scope because
     // `async_work_runner_->Schedule` might sometimes run the closure in this
@@ -1005,7 +1004,7 @@ AbstractAsyncHostToHostMemoryTransferManager::FillRawDataToSubBuffer(
 
   CHECK(async_work_runner_ != nullptr);
   async_work_runner_->Schedule([this, fill_fn = std::move(fill_fn),
-                                is_last_transfer, on_done = std::move(on_done),
+                                on_done = std::move(on_done),
                                 buffer_index]() mutable -> void {
     tsl::RCReference<tsl::AsyncValue> event;
     {
@@ -1013,9 +1012,6 @@ AbstractAsyncHostToHostMemoryTransferManager::FillRawDataToSubBuffer(
       const auto& b = device_buffers_[buffer_index]->Buffers()[0];
       CHECK(b.IsConcrete());
       fill_fn(reinterpret_cast<char*>(b->data()), b->size());
-      if (is_last_transfer) {
-        last_transfer_finished_[buffer_index] = true;
-      }
       --buffer_transfers_in_flight_[buffer_index];
       --transfers_in_flight_;
       if (buffer_transfers_in_flight_[buffer_index] == 0 &&
@@ -1033,6 +1029,12 @@
   return absl::OkStatus();
 }
 
+void AbstractAsyncHostToHostMemoryTransferManager::MarkBufferCompletion(
+    int buffer_index) {
+  absl::MutexLock l(&mu_);
+  last_transfer_finished_[buffer_index] = true;
+}
+
 void AbstractAsyncHostToHostMemoryTransferManager::SetBufferError(
     int buffer_index, absl::Status error) {
   absl::MutexLock l(&mu_);
diff --git a/xla/pjrt/cpu/abstract_tfrt_cpu_buffer.h b/xla/pjrt/cpu/abstract_tfrt_cpu_buffer.h
index 3a8ddcd4729e2..74243d7dd2490 100644
--- a/xla/pjrt/cpu/abstract_tfrt_cpu_buffer.h
+++ b/xla/pjrt/cpu/abstract_tfrt_cpu_buffer.h
@@ -342,7 +342,9 @@ class AbstractAsyncHostToHostMemoryTransferManager
 
   absl::Status TransferRawDataToSubBuffer(
       int buffer_index, const void* data, int64_t offset, int64_t transfer_size,
-      bool is_last_transfer, absl::AnyInvocable<void() &&> on_done) override;
+      absl::AnyInvocable<void() &&> on_done) override;
+
+  void MarkBufferCompletion(int buffer_index) override;
 
   void SetBufferError(int buffer_index, absl::Status error) override;
 
@@ -373,7 +375,7 @@ class AbstractAsyncHostToHostMemoryTransferManager
   absl::Status FillRawDataToSubBuffer(
       int buffer_index,
       absl::AnyInvocable<void(void* data, int64_t size)> fill_fn,
-      bool is_last_transfer, absl::AnyInvocable<void() &&> on_done);
+      absl::AnyInvocable<void() &&> on_done);
 
   mutable absl::Mutex mu_;
   // The number of transfers that are currently in flight.
diff --git a/xla/pjrt/gpu/se_gpu_pjrt_client.cc b/xla/pjrt/gpu/se_gpu_pjrt_client.cc
index 7375673e2f3df..25356d831b33b 100644
--- a/xla/pjrt/gpu/se_gpu_pjrt_client.cc
+++ b/xla/pjrt/gpu/se_gpu_pjrt_client.cc
@@ -304,8 +304,7 @@ class AsyncHostToDeviceTransferManager
       // Call cleanup once the transfer has finished on the stream.
       auto cleanup = [this, buffer_index, stream, on_done = std::move(on_done),
                       event = std::move(event).value()]() mutable {
-        CleanUp(buffer_index, std::move(event), stream,
-                /*is_last_transfer=*/true, std::move(on_done));
+        CleanUp(buffer_index, std::move(event), stream, std::move(on_done));
       };
       auto status = stream->DoHostCallback(std::move(cleanup));
       if (!status.ok()) {
@@ -325,13 +324,12 @@ class AsyncHostToDeviceTransferManager
       int buffer_index, absl::string_view data,
       absl::AnyInvocable<void() &&> on_done) override {
     return TransferRawDataToSubBuffer(buffer_index, data.data(), /*offset=*/0,
                                       data.size(),
-                                      /*is_last_transfer=*/true,
                                       std::move(on_done));
   }
 
   absl::Status TransferRawDataToSubBuffer(
       int buffer_index, const void* data, int64_t offset, int64_t transfer_size,
-      bool is_last_transfer, absl::AnyInvocable<void() &&> on_done) override {
+      absl::AnyInvocable<void() &&> on_done) override {
     auto* stream = device_->local_device_state()->host_to_device_stream();
     auto* client =
@@ -363,9 +361,6 @@ class AsyncHostToDeviceTransferManager
           "already been fully transferred",
           buffer_index);
     }
-    if (is_last_transfer) {
-      last_transfer_started_[buffer_index] = true;
-    }
     DCHECK(buffer_ptrs_[buffer_index]);
     if (buffer_ptrs_[buffer_index]->device_memory().empty()) {
       return InvalidArgument(
@@ -418,14 +413,18 @@ class AsyncHostToDeviceTransferManager
                                                                 event.value());
     auto cleanup = [this, buffer_index, event = std::move(event).value(),
-                    stream, is_last_transfer, on_done = std::move(on_done),
+                    stream, on_done = std::move(on_done),
                     staging_buffer = std::move(staging_buffer)]() mutable {
-      CleanUp(buffer_index, std::move(event), stream, is_last_transfer,
-              std::move(on_done));
+      CleanUp(buffer_index, std::move(event), stream, std::move(on_done));
     };
     return stream->DoHostCallback(std::move(cleanup));
   }
 
+  void MarkBufferCompletion(int buffer_index) override {
+    absl::MutexLock l(&mu_);
+    last_transfer_started_[buffer_index] = true;
+  }
+
   void SetBufferError(int buffer_index, absl::Status error) override {
     {
       absl::MutexLock l(&mu_);
@@ -471,24 +470,11 @@ class AsyncHostToDeviceTransferManager
   PjRtStreamExecutorDevice* device_;  // not owned.
 
   void CleanUp(int buffer_index, EventPool::Handle event, se::Stream* stream,
-               bool is_last_transfer, absl::AnyInvocable<void() &&> on_done) {
+               absl::AnyInvocable<void() &&> on_done) {
     {
       absl::MutexLock l(&mu_);
-      CHECK_GT(transfers_in_flight_, 0);
       --transfers_in_flight_;
-      if (is_last_transfer) {
-        // Drop our reference to the TrackedDeviceBuffer for this buffer.
-        CHECK(buffer_ptrs_[buffer_index]);
-        buffer_ptrs_[buffer_index] = nullptr;
-        CHECK_GT(remaining_buffer_count_, 0);
-        --remaining_buffer_count_;
-        definition_events_[buffer_index]->SetSequencingEvent(std::move(event),
-                                                             stream);
-        if (remaining_buffer_count_ == 0) {
-          VLOG(1) << "TransferLiteralToBuffer for all buffers is done.";
-        }
-      }
     }
     // Call on_done after finishing all housekeeping and releasing the lock.
diff --git a/xla/pjrt/pjrt_c_api_client.cc b/xla/pjrt/pjrt_c_api_client.cc
index be05c632655bf..f3e938b9090ce 100644
--- a/xla/pjrt/pjrt_c_api_client.cc
+++ b/xla/pjrt/pjrt_c_api_client.cc
@@ -752,13 +752,12 @@ class PjRtCApiAsyncHostToDeviceTransferManager
       int buffer_index, absl::string_view data,
       absl::AnyInvocable<void() &&> on_done) override {
     return TransferRawDataToSubBuffer(buffer_index, data.data(), 0, data.size(),
-                                      /*is_last_transfer=*/true,
                                       std::move(on_done));
   }
 
   absl::Status TransferRawDataToSubBuffer(
       int buffer_index, const void* data, int64_t offset, int64_t transfer_size,
-      bool is_last_transfer, absl::AnyInvocable<void() &&> on_done) override {
+      absl::AnyInvocable<void() &&> on_done) override {
     PJRT_AsyncHostToDeviceTransferManager_TransferData_Args args;
     args.struct_size =
         PJRT_AsyncHostToDeviceTransferManager_TransferData_Args_STRUCT_SIZE;
@@ -768,7 +767,6 @@ class PjRtCApiAsyncHostToDeviceTransferManager
     args.data = data;
     args.offset = offset;
    args.transfer_size = transfer_size;
-    args.is_last_transfer = is_last_transfer;
     const PJRT_Api* api = c_client_->pjrt_c_api();
     RETURN_STATUS_IF_PJRT_ERROR(
         api->PJRT_AsyncHostToDeviceTransferManager_TransferData(&args), api);
@@ -780,6 +778,8 @@ class PjRtCApiAsyncHostToDeviceTransferManager
     return absl::OkStatus();
   }
 
+  void MarkBufferCompletion(int buffer_index) override {}
+
   void SetBufferError(int buffer_index, absl::Status error) override {
     PJRT_AsyncHostToDeviceTransferManager_SetBufferError_Args args;
     args.struct_size =
diff --git a/xla/pjrt/pjrt_client.h b/xla/pjrt/pjrt_client.h
index 659c403e25c52..a4e2b02766560 100644
--- a/xla/pjrt/pjrt_client.h
+++ b/xla/pjrt/pjrt_client.h
@@ -732,17 +732,17 @@ class PjRtClient {
   // Transfers 'data' into a sub-buffer of buffer_index starting at offset, of
   // length transfer_size. 'data' must be already laid out in the correct
   // on-device format, for example returned by a call to
-  // buffer->CopyRawToHost. If is_last_transfer is false then the buffer
-  // remains unavailable to consumers after the transfer completes. If
-  // is_last_transfer is true then the buffer becomes available to consumers
-  // after the transfer completes, and no transfer calls (or SetBufferError
-  // calls) into buffer_index can be made after this call. on_done is called
-  // when the transfer is complete but before the buffers are made available
-  // to their consumers. 'data' must remain in scope until on_done is called.
+  // buffer->CopyRawToHost. on_done is called when the transfer is complete
+  // but before the buffers are made available to their consumers.
+  // 'data' must remain in scope until on_done is called.
   virtual absl::Status TransferRawDataToSubBuffer(
       int buffer_index, const void* data, int64_t offset,
-      int64_t transfer_size, bool is_last_transfer,
-      absl::AnyInvocable<void() &&> on_done) = 0;
+      int64_t transfer_size, absl::AnyInvocable<void() &&> on_done) = 0;
+
+  // Indicates that data transfer for the buffer `buffer_index` is complete,
+  // thus the buffer becomes available to consumers. No transfer calls (or
+  // SetBufferError calls) into `buffer_index` can be made after this call.
+  virtual void MarkBufferCompletion(int buffer_index) = 0;
 
   // Indicates that a specific buffer should result in an error status. No
   // transfer calls (or further SetBufferError calls) into buffer_index can
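
Reviewer note: taken together, these changes move the "last transfer" signal out of TransferRawDataToSubBuffer and into an explicit MarkBufferCompletion call on the transfer manager. A minimal sketch of the new calling pattern follows; it is illustrative only: CopyChunksToBuffer, manager, and chunks are hypothetical names introduced here, and the sketch assumes the manager is reachable as xla::PjRtClient::AsyncHostToDeviceTransferManager as declared in pjrt_client.h.

#include <cstdint>

#include "absl/status/status.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "xla/pjrt/pjrt_client.h"

// Hypothetical caller: copies pre-laid-out host chunks into device buffer
// `buffer_index`, then marks the buffer complete so consumers can use it.
absl::Status CopyChunksToBuffer(
    xla::PjRtClient::AsyncHostToDeviceTransferManager* manager,
    int buffer_index, absl::Span<const absl::string_view> chunks) {
  int64_t offset = 0;
  for (absl::string_view chunk : chunks) {
    // Each call enqueues one sub-buffer copy; `chunk` must stay alive until
    // the per-transfer on_done callback runs.
    absl::Status status = manager->TransferRawDataToSubBuffer(
        buffer_index, chunk.data(), offset, chunk.size(), /*on_done=*/[] {});
    if (!status.ok()) return status;
    offset += chunk.size();
  }
  // Previously the final call carried /*is_last_transfer=*/true; completion
  // is now signaled explicitly, and no further transfer or SetBufferError
  // calls may target this buffer afterwards.
  manager->MarkBufferCompletion(buffer_index);
  return absl::OkStatus();
}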