Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove the is_last_transfer parameter #22197

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions xla/pjrt/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -868,8 +868,6 @@ xla_cc_test(
"@com_google_googletest//:gtest_main",
"@llvm-project//mlir:IR",
"@stablehlo//:version",
"@tsl//tsl/platform:statusor",
"@tsl//tsl/platform:test",
],
)

Expand Down
1 change: 0 additions & 1 deletion xla/pjrt/c/pjrt_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -674,7 +674,6 @@ struct PJRT_AsyncHostToDeviceTransferManager_TransferData_Args {
const void* data;
int64_t offset;
int64_t transfer_size;
bool is_last_transfer;
PJRT_Event* done_with_h2d_transfer; // out
};
PJRT_DEFINE_STRUCT_TRAITS(
Expand Down
2 changes: 1 addition & 1 deletion xla/pjrt/c/pjrt_c_api_wrapper_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -662,7 +662,7 @@ PJRT_Error* PJRT_AsyncHostToDeviceTransferManager_TransferData(
PJRT_RETURN_IF_ERROR(
args->transfer_manager->transfer_manager->TransferRawDataToSubBuffer(
args->buffer_index, args->data, args->offset, args->transfer_size,
args->is_last_transfer, std::move(on_done_with_d2h_transfer)));
std::move(on_done_with_d2h_transfer)));
args->done_with_h2d_transfer =
new PJRT_Event{xla::PjRtFuture<>(std::move(promise))};
return nullptr;
Expand Down
24 changes: 13 additions & 11 deletions xla/pjrt/cpu/abstract_tfrt_cpu_buffer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -959,37 +959,36 @@ AbstractAsyncHostToHostMemoryTransferManager::TransferLiteralToBuffer(
[literal](void* b, int64_t size) {
PackOrCopy(literal.shape().element_type(), literal, b, size);
},
/*is_last_transfer=*/true, std::move(on_done));
std::move(on_done));
}

absl::Status
AbstractAsyncHostToHostMemoryTransferManager::TransferRawDataToBuffer(
int buffer_index, absl::string_view data,
absl::AnyInvocable<void() &&> on_done) {
return TransferRawDataToSubBuffer(
buffer_index, data.data(), /*offset=*/0, data.size(),
/*is_last_transfer=*/true, std::move(on_done));
return TransferRawDataToSubBuffer(buffer_index, data.data(), /*offset=*/0,
data.size(), std::move(on_done));
}

// The definition events of `device_buffers_` must be ready before calling this
// function.
absl::Status
AbstractAsyncHostToHostMemoryTransferManager::TransferRawDataToSubBuffer(
int buffer_index, const void* data, int64_t offset, int64_t transfer_size,
bool is_last_transfer, absl::AnyInvocable<void() &&> on_done) {
absl::AnyInvocable<void() &&> on_done) {
return FillRawDataToSubBuffer(
buffer_index,
[offset, data, transfer_size](void* b, int64_t size) {
std::memcpy(reinterpret_cast<char*>(b) + offset, data, transfer_size);
},
is_last_transfer, std::move(on_done));
std::move(on_done));
}

absl::Status
AbstractAsyncHostToHostMemoryTransferManager::FillRawDataToSubBuffer(
int buffer_index,
absl::AnyInvocable<void(void* data, int64_t size)> fill_fn,
bool is_last_transfer, absl::AnyInvocable<void() &&> on_done) {
absl::AnyInvocable<void() &&> on_done) {
{
// We release the lock when out of scope because
// `async_work_runner_->Schedule` might sometimes run the closure in this
Expand All @@ -1005,17 +1004,14 @@ AbstractAsyncHostToHostMemoryTransferManager::FillRawDataToSubBuffer(

CHECK(async_work_runner_ != nullptr);
async_work_runner_->Schedule([this, fill_fn = std::move(fill_fn),
is_last_transfer, on_done = std::move(on_done),
on_done = std::move(on_done),
buffer_index]() mutable -> void {
tsl::RCReference<tsl::AsyncValue> event;
{
absl::MutexLock l(&mu_);
const auto& b = device_buffers_[buffer_index]->Buffers()[0];
CHECK(b.IsConcrete());
fill_fn(reinterpret_cast<char*>(b->data()), b->size());
if (is_last_transfer) {
last_transfer_finished_[buffer_index] = true;
}
--buffer_transfers_in_flight_[buffer_index];
--transfers_in_flight_;
if (buffer_transfers_in_flight_[buffer_index] == 0 &&
Expand All @@ -1033,6 +1029,12 @@ AbstractAsyncHostToHostMemoryTransferManager::FillRawDataToSubBuffer(
return absl::OkStatus();
}

// Sets last_transfer_finished_[buffer_index] to true while holding mu_.
//
// NOTE(review): this method only records the flag. The completion check that
// is visible in FillRawDataToSubBuffer runs inside the scheduled transfer
// closure, so confirm that the buffer is still made available to consumers
// when this is called after the last in-flight transfer for `buffer_index`
// has already finished.
// NOTE(review): `buffer_index` is used unchecked as an index — presumably
// callers guarantee it is in range; verify.
void AbstractAsyncHostToHostMemoryTransferManager::MarkBufferCompletion(
int buffer_index) {
absl::MutexLock l(&mu_);
last_transfer_finished_[buffer_index] = true;
}

void AbstractAsyncHostToHostMemoryTransferManager::SetBufferError(
int buffer_index, absl::Status error) {
absl::MutexLock l(&mu_);
Expand Down
6 changes: 4 additions & 2 deletions xla/pjrt/cpu/abstract_tfrt_cpu_buffer.h
Original file line number Diff line number Diff line change
Expand Up @@ -342,7 +342,9 @@ class AbstractAsyncHostToHostMemoryTransferManager

absl::Status TransferRawDataToSubBuffer(
int buffer_index, const void* data, int64_t offset, int64_t transfer_size,
bool is_last_transfer, absl::AnyInvocable<void() &&> on_done) override;
absl::AnyInvocable<void() &&> on_done) override;

void MarkBufferCompletion(int buffer_index) override;

void SetBufferError(int buffer_index, absl::Status error) override;

Expand Down Expand Up @@ -373,7 +375,7 @@ class AbstractAsyncHostToHostMemoryTransferManager
absl::Status FillRawDataToSubBuffer(
int buffer_index,
absl::AnyInvocable<void(void* data, int64_t size)> fill_fn,
bool is_last_transfer, absl::AnyInvocable<void() &&> on_done);
absl::AnyInvocable<void() &&> on_done);

mutable absl::Mutex mu_;
// The number of transfers that are currently in flight.
Expand Down
34 changes: 10 additions & 24 deletions xla/pjrt/gpu/se_gpu_pjrt_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -304,8 +304,7 @@ class AsyncHostToDeviceTransferManager
// Call cleanup once the transfer has finished on the stream.
auto cleanup = [this, buffer_index, stream, on_done = std::move(on_done),
event = std::move(event).value()]() mutable {
CleanUp(buffer_index, std::move(event), stream,
/*is_last_transfer=*/true, std::move(on_done));
CleanUp(buffer_index, std::move(event), stream, std::move(on_done));
};
auto status = stream->DoHostCallback(std::move(cleanup));
if (!status.ok()) {
Expand All @@ -325,13 +324,12 @@ class AsyncHostToDeviceTransferManager
absl::AnyInvocable<void() &&> on_done) override {
return TransferRawDataToSubBuffer(buffer_index, data.data(),
/*offset=*/0, data.size(),
/*is_last_transfer=*/true,
std::move(on_done));
}

absl::Status TransferRawDataToSubBuffer(
int buffer_index, const void* data, int64_t offset, int64_t transfer_size,
bool is_last_transfer, absl::AnyInvocable<void() &&> on_done) override {
absl::AnyInvocable<void() &&> on_done) override {
auto* stream = device_->local_device_state()->host_to_device_stream();

auto* client =
Expand Down Expand Up @@ -363,9 +361,6 @@ class AsyncHostToDeviceTransferManager
"already been fully transferred",
buffer_index);
}
if (is_last_transfer) {
last_transfer_started_[buffer_index] = true;
}
DCHECK(buffer_ptrs_[buffer_index]);
if (buffer_ptrs_[buffer_index]->device_memory().empty()) {
return InvalidArgument(
Expand Down Expand Up @@ -418,14 +413,18 @@ class AsyncHostToDeviceTransferManager
event.value());

auto cleanup = [this, buffer_index, event = std::move(event).value(),
stream, is_last_transfer, on_done = std::move(on_done),
stream, on_done = std::move(on_done),
staging_buffer = std::move(staging_buffer)]() mutable {
CleanUp(buffer_index, std::move(event), stream, is_last_transfer,
std::move(on_done));
CleanUp(buffer_index, std::move(event), stream, std::move(on_done));
};
return stream->DoHostCallback(std::move(cleanup));
}

// Sets last_transfer_started_[buffer_index] to true while holding mu_.
//
// NOTE(review): only the flag is set here. Nothing in this method releases
// buffer_ptrs_[buffer_index], decrements remaining_buffer_count_, or sets a
// sequencing event on definition_events_[buffer_index] — confirm that this
// happens elsewhere once the transfers for `buffer_index` have drained,
// otherwise consumers may never see the buffer become ready.
void MarkBufferCompletion(int buffer_index) override {
absl::MutexLock l(&mu_);
last_transfer_started_[buffer_index] = true;
}

void SetBufferError(int buffer_index, absl::Status error) override {
{
absl::MutexLock l(&mu_);
Expand Down Expand Up @@ -471,24 +470,11 @@ class AsyncHostToDeviceTransferManager
PjRtStreamExecutorDevice* device_; // not owned.

void CleanUp(int buffer_index, EventPool::Handle event, se::Stream* stream,
bool is_last_transfer, absl::AnyInvocable<void() &&> on_done) {
absl::AnyInvocable<void() &&> on_done) {
{
absl::MutexLock l(&mu_);

CHECK_GT(transfers_in_flight_, 0);
--transfers_in_flight_;
if (is_last_transfer) {
// Drop our reference to the TrackedDeviceBuffer for this buffer.
CHECK(buffer_ptrs_[buffer_index]);
buffer_ptrs_[buffer_index] = nullptr;
CHECK_GT(remaining_buffer_count_, 0);
--remaining_buffer_count_;
definition_events_[buffer_index]->SetSequencingEvent(std::move(event),
stream);
if (remaining_buffer_count_ == 0) {
VLOG(1) << "TransferLiteralToBuffer for all buffers is done.";
}
}
}

// Call on_done after finishing all housekeeping and releasing the lock.
Expand Down
6 changes: 3 additions & 3 deletions xla/pjrt/pjrt_c_api_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -752,13 +752,12 @@ class PjRtCApiAsyncHostToDeviceTransferManager
int buffer_index, absl::string_view data,
absl::AnyInvocable<void() &&> on_done) override {
return TransferRawDataToSubBuffer(buffer_index, data.data(), 0, data.size(),
/*is_last_transfer=*/true,
std::move(on_done));
}

absl::Status TransferRawDataToSubBuffer(
int buffer_index, const void* data, int64_t offset, int64_t transfer_size,
bool is_last_transfer, absl::AnyInvocable<void() &&> on_done) override {
absl::AnyInvocable<void() &&> on_done) override {
PJRT_AsyncHostToDeviceTransferManager_TransferData_Args args;
args.struct_size =
PJRT_AsyncHostToDeviceTransferManager_TransferData_Args_STRUCT_SIZE;
Expand All @@ -768,7 +767,6 @@ class PjRtCApiAsyncHostToDeviceTransferManager
args.data = data;
args.offset = offset;
args.transfer_size = transfer_size;
args.is_last_transfer = is_last_transfer;
const PJRT_Api* api = c_client_->pjrt_c_api();
RETURN_STATUS_IF_PJRT_ERROR(
api->PJRT_AsyncHostToDeviceTransferManager_TransferData(&args), api);
Expand All @@ -780,6 +778,8 @@ class PjRtCApiAsyncHostToDeviceTransferManager
return absl::OkStatus();
}

// No-op: this override does nothing and does not call into the PJRT C API.
// NOTE(review): presumably the C API has no corresponding entry point yet —
// confirm that buffers created through this transfer manager still become
// available to consumers without the completion signal being forwarded.
void MarkBufferCompletion(int buffer_index) override {}

void SetBufferError(int buffer_index, absl::Status error) override {
PJRT_AsyncHostToDeviceTransferManager_SetBufferError_Args args;
args.struct_size =
Expand Down
18 changes: 9 additions & 9 deletions xla/pjrt/pjrt_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -732,17 +732,17 @@ class PjRtClient {
// Transfers 'data' into a sub-buffer of buffer_index starting at offset, of
// length transfer_size. 'data' must be already laid out in the correct
// on-device format, for example returned by a call to
// buffer->CopyRawToHost. If is_last_transfer is false then the buffer
// remains unavailable to consumers after the transfer completes. If
// is_last_transfer is true then the buffer becomes available to consumers
// after the transfer completes, and no transfer calls (or SetBufferError
// calls) into buffer_index can be made after this call. on_done is called
// when the transfer is complete but before the buffers are made available
// to their consumers. 'data' must remain in scope until on_done is called.
// buffer->CopyRawToHost. on_done is called when the transfer is complete
// but before the buffers are made available to their consumers.
// 'data' must remain in scope until on_done is called.
virtual absl::Status TransferRawDataToSubBuffer(
int buffer_index, const void* data, int64_t offset,
int64_t transfer_size, bool is_last_transfer,
absl::AnyInvocable<void() &&> on_done) = 0;
int64_t transfer_size, absl::AnyInvocable<void() &&> on_done) = 0;

// Marks the data transfer for buffer `buffer_index` as complete, so that the
// buffer becomes available to consumers. After this call no further transfer
// calls (or SetBufferError calls) may be made into `buffer_index`.
// NOTE(review): the contract does not say whether transfers that are still
// in flight for `buffer_index` must finish before the buffer is released to
// consumers — confirm and document the intended ordering.
virtual void MarkBufferCompletion(int buffer_index) = 0;

// Indicates that a specific buffer should result in an error status. No
// transfer calls (or further SetBufferError calls) into buffer_index can
Expand Down
Loading