From 3f0b6c9928392c19cff618def88cacc6231065e1 Mon Sep 17 00:00:00 2001 From: Emily Fertig Date: Thu, 6 Feb 2025 15:29:31 -0800 Subject: [PATCH] Plumb layout through the creation of PjRtArrays. PiperOrigin-RevId: 724096815 --- xla/python/pjrt_ifrt/pjrt_array.cc | 65 +++++++++++++++-------- xla/python/pjrt_ifrt/pjrt_array.h | 13 +++-- xla/python/pjrt_ifrt/pjrt_client.cc | 20 ++++++- xla/python/pjrt_ifrt/pjrt_executable.cc | 22 ++++++-- xla/python/pjrt_ifrt/pjrt_remap.cc | 3 +- xla/python/transfer/py_socket_transfer.cc | 2 +- 6 files changed, 92 insertions(+), 33 deletions(-) diff --git a/xla/python/pjrt_ifrt/pjrt_array.cc b/xla/python/pjrt_ifrt/pjrt_array.cc index b8ba14293830cb..cf8754c4d666e3 100644 --- a/xla/python/pjrt_ifrt/pjrt_array.cc +++ b/xla/python/pjrt_ifrt/pjrt_array.cc @@ -31,6 +31,7 @@ limitations under the License. #include "absl/strings/str_join.h" #include "absl/types/span.h" #include "llvm/Support/Casting.h" +#include "xla/layout.h" #include "xla/literal.h" #include "xla/pjrt/host_memory_spaces.h" #include "xla/pjrt/pjrt_client.h" @@ -144,18 +145,22 @@ MemoryKind MakeMemoryKindFromPjRtBuffer(PjRtBuffer* pjrt_buffer) { absl::StatusOr> PjRtArray::Create( PjRtCompatibleClient* client, DType dtype, Shape shape, - std::shared_ptr sharding, PjRtBuffers pjrt_buffers) { + std::shared_ptr sharding, PjRtBuffers pjrt_buffers, + std::shared_ptr layout) { TF_RETURN_IF_ERROR(ValidateArrayCreationInput(sharding, pjrt_buffers)); return tsl::MakeRef(client, dtype, std::move(shape), - std::move(sharding), std::move(pjrt_buffers)); + std::move(sharding), std::move(pjrt_buffers), + std::move(layout)); } absl::StatusOr> PjRtArray::Create( PjRtCompatibleClient* client, DType dtype, DynamicShape dynamic_shape, - std::shared_ptr sharding, PjRtBuffers pjrt_buffers) { + std::shared_ptr sharding, PjRtBuffers pjrt_buffers, + std::shared_ptr layout) { TF_RETURN_IF_ERROR(ValidateArrayCreationInput(sharding, pjrt_buffers)); return tsl::MakeRef(client, dtype, 
std::move(dynamic_shape), - std::move(sharding), std::move(pjrt_buffers)); + std::move(sharding), std::move(pjrt_buffers), + std::move(layout)); } absl::StatusOr> PjRtArray::Create( @@ -166,9 +171,15 @@ absl::StatusOr> PjRtArray::Create( client->LookupPjRtDevice(pjrt_buffer->device())); auto sharding = SingleDeviceSharding::Create( device, MakeMemoryKindFromPjRtBuffer(pjrt_buffer.get())); - return tsl::MakeRef(client, dtype, std::move(shape), - std::move(sharding), - PjRtBuffers({std::move(pjrt_buffer)})); + std::shared_ptr layout; + if (pjrt_buffer->on_device_shape().has_layout()) { + layout = pjrt_buffer->layout(); + } else { + layout = std::make_shared(xla::Layout()); + } + return tsl::MakeRef( + client, dtype, std::move(shape), std::move(sharding), + PjRtBuffers({std::move(pjrt_buffer)}), std::move(layout)); } absl::StatusOr> PjRtArray::FullyReplicatedShard( @@ -214,8 +225,12 @@ absl::StatusOr> PjRtArray::Create( BasicDeviceList::Create(std::move(devices)), memory_kind, /*shape=*/shape, /*shard_shapes=*/shapes); + if (pjrt_buffers.empty()) { + return InvalidArgument("PjRtBuffers must be non-empty."); + } + auto layout = pjrt_buffers.front()->layout(); return PjRtArray::Create(client, dtype, std::move(shape), std::move(sharding), - std::move(pjrt_buffers)); + std::move(pjrt_buffers), std::move(layout)); } absl::StatusOr> PjRtArray::Create( @@ -247,28 +262,37 @@ absl::StatusOr> PjRtArray::Create( BasicDeviceList::Create(std::move(devices)), memory_kind, /*dynamic_shape=*/dynamic_shape, /*shard_dynamic_shapes=*/dynamic_shapes); + if (pjrt_buffers.empty()) { + return InvalidArgument("PjRtBuffers must be non-empty."); + } + auto layout = pjrt_buffers.front()->layout(); return PjRtArray::Create(client, dtype, std::move(dynamic_shape), - std::move(sharding), std::move(pjrt_buffers)); + std::move(sharding), std::move(pjrt_buffers), + std::move(layout)); } PjRtArray::PjRtArray(PjRtCompatibleClient* client, DType dtype, Shape shape, std::shared_ptr sharding, - 
PjRtBuffers pjrt_buffers) + PjRtBuffers pjrt_buffers, + std::shared_ptr layout) : client_(client), dtype_(dtype), shape_(std::move(shape)), sharding_(std::move(sharding)), - pjrt_buffers_(std::move(pjrt_buffers)) {} + pjrt_buffers_(std::move(pjrt_buffers)), + layout_(std::move(layout)) {} PjRtArray::PjRtArray(PjRtCompatibleClient* client, DType dtype, DynamicShape dynamic_shape, std::shared_ptr sharding, - PjRtBuffers pjrt_buffers) + PjRtBuffers pjrt_buffers, + std::shared_ptr layout) : client_(client), dtype_(dtype), shape_(std::move(dynamic_shape)), sharding_(std::move(sharding)), - pjrt_buffers_(std::move(pjrt_buffers)) {} + pjrt_buffers_(std::move(pjrt_buffers)), + layout_(std::move(layout)) {} absl::StatusOr>> PjRtArray::DisassembleIntoSingleDeviceArrays(ArrayCopySemantics semantics) { @@ -306,7 +330,7 @@ PjRtArray::DisassembleIntoSingleDeviceArrays( PjRtArray::Create(client_, dtype_, std::move(shape_and_shardings[i].first), std::move(shape_and_shardings[i].second), - std::move(buffers))); + std::move(buffers), layout_)); result.push_back(std::move(array)); } return absl::OkStatus(); @@ -507,7 +531,8 @@ absl::StatusOr> PjRtArray::Copy( return std::visit( [this, &new_sharding, &buffers](const auto& shape) { return PjRtArray::Create(client_, dtype_, shape, - std::move(new_sharding), std::move(buffers)); + std::move(new_sharding), std::move(buffers), + layout_); }, shape_); } @@ -554,21 +579,17 @@ std::string PjRtArray::DebugString() const { sharding_->DebugString(), layout_str); } -// TODO(b/330198879): populate layout at construction instead of accessing PJRT -// buffer directly for consistency with Pathways. 
absl::StatusOr> PjRtArray::layout() const { - CHECK(!pjrt_buffers_.empty()); - std::shared_ptr layout = pjrt_buffers_[0]->layout(); #ifndef NDEBUG for (int i = 1; i < pjrt_buffers_.size(); ++i) { std::shared_ptr layout_i = pjrt_buffers_[i]->layout(); - DCHECK(*layout == *layout_i) + DCHECK(*layout_ == *layout_i) << "PjRtArray has mismatched layouts across shards! " - << "shard 0: " << layout->ToString() << ", shard " << i << ": " + << "shard 0: " << layout_->ToString() << ", shard " << i << ": " << layout_i->ToString(); } #endif - return layout; + return layout_; } } // namespace ifrt diff --git a/xla/python/pjrt_ifrt/pjrt_array.h b/xla/python/pjrt_ifrt/pjrt_array.h index 7a88f708248393..eca91242c42a05 100644 --- a/xla/python/pjrt_ifrt/pjrt_array.h +++ b/xla/python/pjrt_ifrt/pjrt_array.h @@ -70,12 +70,14 @@ class PjRtArray final // General array construction (with static shape). static absl::StatusOr> Create( PjRtCompatibleClient* client, DType dtype, Shape shape, - std::shared_ptr sharding, PjRtBuffers pjrt_buffers); + std::shared_ptr sharding, PjRtBuffers pjrt_buffers, + std::shared_ptr layout); // General array construction (with dynamic shape). static absl::StatusOr> Create( PjRtCompatibleClient* client, DType dtype, DynamicShape dynamic_shape, - std::shared_ptr sharding, PjRtBuffers pjrt_buffers); + std::shared_ptr sharding, PjRtBuffers pjrt_buffers, + std::shared_ptr layout); // Shorthand for a single-shard array construction. 
static absl::StatusOr> Create( @@ -184,11 +186,13 @@ class PjRtArray final private: PjRtArray(PjRtCompatibleClient* client, DType dtype, Shape shape, - std::shared_ptr sharding, PjRtBuffers pjrt_buffers); + std::shared_ptr sharding, PjRtBuffers pjrt_buffers, + std::shared_ptr layout); PjRtArray(PjRtCompatibleClient* client, DType dtype, DynamicShape dynamic_shape, - std::shared_ptr sharding, PjRtBuffers pjrt_buffers); + std::shared_ptr sharding, PjRtBuffers pjrt_buffers, + std::shared_ptr layout); template friend tsl::RCReference tsl::MakeRef(Args&&... args); @@ -198,6 +202,7 @@ class PjRtArray final std::variant shape_; std::shared_ptr sharding_; PjRtBuffers pjrt_buffers_; + std::shared_ptr layout_; }; } // namespace ifrt diff --git a/xla/python/pjrt_ifrt/pjrt_client.cc b/xla/python/pjrt_ifrt/pjrt_client.cc index 745265daf2d568..57bf7187510bb2 100644 --- a/xla/python/pjrt_ifrt/pjrt_client.cc +++ b/xla/python/pjrt_ifrt/pjrt_client.cc @@ -964,8 +964,12 @@ absl::StatusOr> PjRtClient::MakeArrayFromHostBuffer( } buffers.push_back(std::move(buffer)); } + if (buffers.empty()) { + return InvalidArgument("Buffers must be non-empty."); + } + auto layout = buffers.front()->layout(); return PjRtArray::Create(this, dtype, std::move(shape), std::move(sharding), - std::move(buffers)); + std::move(buffers), std::move(layout)); } absl::StatusOr> @@ -1070,8 +1074,20 @@ PjRtClient::AssembleArrayFromSingleDeviceArrays( break; } } + std::shared_ptr layout; + // TODO(emilyaf): Decide whether this should test for a nullptr layout instead of has_layout(). 
+ if (!buffers.empty() && buffers.front()->on_device_shape().has_layout()) { + layout = buffers.front()->layout(); + } else if (dtype.kind() == DType::kToken) { + layout = nullptr; + } else { + TF_ASSIGN_OR_RETURN(layout, + GetDefaultLayout(dtype, shape.dims(), + sharding->devices()->devices().front(), + sharding->memory_kind())); + } return PjRtArray::Create(this, dtype, std::move(shape), std::move(sharding), - std::move(buffers)); + std::move(buffers), std::move(layout)); } absl::StatusOr>> PjRtClient::CopyArrays( diff --git a/xla/python/pjrt_ifrt/pjrt_executable.cc b/xla/python/pjrt_ifrt/pjrt_executable.cc index 3b9ccbdfebee2a..a435cc9925f445 100644 --- a/xla/python/pjrt_ifrt/pjrt_executable.cc +++ b/xla/python/pjrt_ifrt/pjrt_executable.cc @@ -33,11 +33,13 @@ limitations under the License. #include "mlir/IR/BuiltinOps.h" #include "xla/hlo/ir/hlo_sharding.h" #include "xla/hlo/translate/mhlo_to_hlo/type_to_shape.h" +#include "xla/layout.h" #include "xla/pjrt/host_callback.h" #include "xla/pjrt/pjrt_client.h" #include "xla/pjrt/pjrt_compiler.h" #include "xla/pjrt/pjrt_executable.h" #include "xla/pjrt/pjrt_future.h" +#include "xla/pjrt/pjrt_layout.h" #include "xla/primitive_util.h" #include "xla/python/ifrt/array.h" #include "xla/python/ifrt/device.h" @@ -658,6 +660,20 @@ PjRtLoadedExecutable::Execute( // memory_kind shares the same Sharding object. absl::flat_hash_map> single_device_shardings; + auto maybe_layouts = GetOutputLayouts(); + std::vector> layouts; + if (absl::IsUnimplemented(maybe_layouts.status())) { + layouts.reserve(num_outputs); + for (int i = 0; i < num_outputs; ++i) { + layouts.push_back(nullptr); + // TODO(emilyaf): Decide whether a missing layout should be represented as nullopt instead of nullptr. 
+ // layouts.push_back(std::make_shared(xla::Layout())); + } + } else { + TF_RETURN_IF_ERROR(maybe_layouts.status()); + layouts = *std::move(maybe_layouts); + } + for (int i = 0; i < num_outputs; ++i) { PjRtArray::PjRtBuffers buffers; buffers.reserve(num_computations); @@ -698,9 +714,9 @@ PjRtLoadedExecutable::Execute( } else { sharding = output_shardings_[i]; } - outputs.push_back(*PjRtArray::Create(client_, output_dtypes_[i], - output_shapes_[i], std::move(sharding), - std::move(buffers))); + outputs.push_back(*PjRtArray::Create( + client_, output_dtypes_[i], output_shapes_[i], std::move(sharding), + std::move(buffers), std::move(layouts[i]))); } ExecuteResult result; diff --git a/xla/python/pjrt_ifrt/pjrt_remap.cc b/xla/python/pjrt_ifrt/pjrt_remap.cc index ff9925f0a61574..0cfa241e4fba71 100644 --- a/xla/python/pjrt_ifrt/pjrt_remap.cc +++ b/xla/python/pjrt_ifrt/pjrt_remap.cc @@ -134,7 +134,8 @@ PjRtCompatibleClientRemapArrays( PjRtArray::Create(client, plan.output_specs[i].dtype, plan.output_specs[i].shape, plan.output_specs[i].sharding, - std::move(out_buffers_list[i]))); + std::move(out_buffers_list[i]), + plan.output_specs[i].layout)); output_arrays.push_back(std::move(output_array)); } return output_arrays; diff --git a/xla/python/transfer/py_socket_transfer.cc b/xla/python/transfer/py_socket_transfer.cc index 64c2322947da69..6267e892a11ac6 100644 --- a/xla/python/transfer/py_socket_transfer.cc +++ b/xla/python/transfer/py_socket_transfer.cc @@ -356,7 +356,7 @@ void RegisterTransferServerTypes(nanobind::module_& m) { } auto arr = xla::ValueOrThrow(xla::ifrt::PjRtArray::Create( ifrt_client, avals[i].dtype, avals[i].shape, avals[i].sharding, - std::move(buffers))); + std::move(buffers), avals[i].layout)); out.push_back(xla::PyArray::MakeFromIfrtArrayAndSharding( py_client, traceback, std::move(arr), shardings[i], false, true, /*skip_checks=*/false));