Plumb layout through the creation of PjRtArrays.
PiperOrigin-RevId: 724096815
emilyfertig authored and Google-ML-Automation committed Feb 13, 2025
1 parent 1c73e14 commit 3f0b6c9
Showing 6 changed files with 92 additions and 33 deletions.
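For orientation (not part of the commit): after this change, every PjRtArray::Create overload takes the array's layout explicitly instead of deriving it lazily from the first PjRtBuffer. A minimal caller sketch under the new signature, assuming an in-scope IFRT client, dtype, shape, sharding, and non-empty buffers (variable names are illustrative):

    // Pass the first shard's layout through explicitly; Create now stores it
    // on the PjRtArray instead of querying the buffers on each layout() call.
    std::shared_ptr<const xla::PjRtLayout> layout = buffers.front()->layout();
    TF_ASSIGN_OR_RETURN(
        tsl::RCReference<xla::ifrt::PjRtArray> array,
        xla::ifrt::PjRtArray::Create(client, dtype, std::move(shape),
                                     std::move(sharding), std::move(buffers),
                                     std::move(layout)));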
65 changes: 43 additions & 22 deletions xla/python/pjrt_ifrt/pjrt_array.cc
@@ -31,6 +31,7 @@ limitations under the License.
 #include "absl/strings/str_join.h"
 #include "absl/types/span.h"
 #include "llvm/Support/Casting.h"
+#include "xla/layout.h"
 #include "xla/literal.h"
 #include "xla/pjrt/host_memory_spaces.h"
 #include "xla/pjrt/pjrt_client.h"
@@ -144,18 +145,22 @@ MemoryKind MakeMemoryKindFromPjRtBuffer(PjRtBuffer* pjrt_buffer) {
 
 absl::StatusOr<tsl::RCReference<PjRtArray>> PjRtArray::Create(
     PjRtCompatibleClient* client, DType dtype, Shape shape,
-    std::shared_ptr<const Sharding> sharding, PjRtBuffers pjrt_buffers) {
+    std::shared_ptr<const Sharding> sharding, PjRtBuffers pjrt_buffers,
+    std::shared_ptr<const PjRtLayout> layout) {
   TF_RETURN_IF_ERROR(ValidateArrayCreationInput(sharding, pjrt_buffers));
   return tsl::MakeRef<PjRtArray>(client, dtype, std::move(shape),
-                                 std::move(sharding), std::move(pjrt_buffers));
+                                 std::move(sharding), std::move(pjrt_buffers),
+                                 std::move(layout));
 }
 
 absl::StatusOr<tsl::RCReference<PjRtArray>> PjRtArray::Create(
     PjRtCompatibleClient* client, DType dtype, DynamicShape dynamic_shape,
-    std::shared_ptr<const Sharding> sharding, PjRtBuffers pjrt_buffers) {
+    std::shared_ptr<const Sharding> sharding, PjRtBuffers pjrt_buffers,
+    std::shared_ptr<const PjRtLayout> layout) {
   TF_RETURN_IF_ERROR(ValidateArrayCreationInput(sharding, pjrt_buffers));
   return tsl::MakeRef<PjRtArray>(client, dtype, std::move(dynamic_shape),
-                                 std::move(sharding), std::move(pjrt_buffers));
+                                 std::move(sharding), std::move(pjrt_buffers),
+                                 std::move(layout));
 }
 
 absl::StatusOr<tsl::RCReference<PjRtArray>> PjRtArray::Create(
@@ -166,9 +171,15 @@ absl::StatusOr<tsl::RCReference<PjRtArray>> PjRtArray::Create(
                       client->LookupPjRtDevice(pjrt_buffer->device()));
   auto sharding = SingleDeviceSharding::Create(
       device, MakeMemoryKindFromPjRtBuffer(pjrt_buffer.get()));
-  return tsl::MakeRef<PjRtArray>(client, dtype, std::move(shape),
-                                 std::move(sharding),
-                                 PjRtBuffers({std::move(pjrt_buffer)}));
+  std::shared_ptr<const PjRtLayout> layout;
+  if (pjrt_buffer->on_device_shape().has_layout()) {
+    layout = pjrt_buffer->layout();
+  } else {
+    layout = std::make_shared<xla::PjRtLayout>(xla::Layout());
+  }
+  return tsl::MakeRef<PjRtArray>(
+      client, dtype, std::move(shape), std::move(sharding),
+      PjRtBuffers({std::move(pjrt_buffer)}), std::move(layout));
 }
 
 absl::StatusOr<tsl::RCReference<Array>> PjRtArray::FullyReplicatedShard(
@@ -214,8 +225,12 @@ absl::StatusOr<tsl::RCReference<PjRtArray>> PjRtArray::Create(
       BasicDeviceList::Create(std::move(devices)), memory_kind,
       /*shape=*/shape,
       /*shard_shapes=*/shapes);
+  if (pjrt_buffers.empty()) {
+    return InvalidArgument("PjRtBuffers must be non-empty.");
+  }
+  auto layout = pjrt_buffers.front()->layout();
   return PjRtArray::Create(client, dtype, std::move(shape), std::move(sharding),
-                           std::move(pjrt_buffers));
+                           std::move(pjrt_buffers), std::move(layout));
 }
 
 absl::StatusOr<tsl::RCReference<PjRtArray>> PjRtArray::Create(
@@ -247,28 +262,37 @@ absl::StatusOr<tsl::RCReference<PjRtArray>> PjRtArray::Create(
       BasicDeviceList::Create(std::move(devices)), memory_kind,
       /*dynamic_shape=*/dynamic_shape,
       /*shard_dynamic_shapes=*/dynamic_shapes);
+  if (pjrt_buffers.empty()) {
+    return InvalidArgument("PjRtBuffers must be non-empty.");
+  }
+  auto layout = pjrt_buffers.front()->layout();
   return PjRtArray::Create(client, dtype, std::move(dynamic_shape),
-                           std::move(sharding), std::move(pjrt_buffers));
+                           std::move(sharding), std::move(pjrt_buffers),
+                           std::move(layout));
 }
 
 PjRtArray::PjRtArray(PjRtCompatibleClient* client, DType dtype, Shape shape,
                      std::shared_ptr<const Sharding> sharding,
-                     PjRtBuffers pjrt_buffers)
+                     PjRtBuffers pjrt_buffers,
+                     std::shared_ptr<const PjRtLayout> layout)
     : client_(client),
       dtype_(dtype),
       shape_(std::move(shape)),
      sharding_(std::move(sharding)),
-      pjrt_buffers_(std::move(pjrt_buffers)) {}
+      pjrt_buffers_(std::move(pjrt_buffers)),
+      layout_(std::move(layout)) {}
 
 PjRtArray::PjRtArray(PjRtCompatibleClient* client, DType dtype,
                      DynamicShape dynamic_shape,
                      std::shared_ptr<const Sharding> sharding,
-                     PjRtBuffers pjrt_buffers)
+                     PjRtBuffers pjrt_buffers,
+                     std::shared_ptr<const PjRtLayout> layout)
     : client_(client),
       dtype_(dtype),
       shape_(std::move(dynamic_shape)),
       sharding_(std::move(sharding)),
-      pjrt_buffers_(std::move(pjrt_buffers)) {}
+      pjrt_buffers_(std::move(pjrt_buffers)),
+      layout_(std::move(layout)) {}
 
 absl::StatusOr<std::vector<tsl::RCReference<Array>>>
 PjRtArray::DisassembleIntoSingleDeviceArrays(ArrayCopySemantics semantics) {
@@ -306,7 +330,7 @@ PjRtArray::DisassembleIntoSingleDeviceArrays(
         PjRtArray::Create(client_, dtype_,
                           std::move(shape_and_shardings[i].first),
                           std::move(shape_and_shardings[i].second),
-                          std::move(buffers)));
+                          std::move(buffers), layout_));
     result.push_back(std::move(array));
   }
   return absl::OkStatus();
@@ -507,7 +531,8 @@ absl::StatusOr<tsl::RCReference<Array>> PjRtArray::Copy(
   return std::visit(
       [this, &new_sharding, &buffers](const auto& shape) {
         return PjRtArray::Create(client_, dtype_, shape,
-                                 std::move(new_sharding), std::move(buffers));
+                                 std::move(new_sharding), std::move(buffers),
+                                 layout_);
       },
      shape_);
 }
@@ -554,21 +579,17 @@ std::string PjRtArray::DebugString() const {
                          sharding_->DebugString(), layout_str);
 }
 
-// TODO(b/330198879): populate layout at construction instead of accessing PJRT
-// buffer directly for consistency with Pathways.
 absl::StatusOr<std::shared_ptr<const PjRtLayout>> PjRtArray::layout() const {
-  CHECK(!pjrt_buffers_.empty());
-  std::shared_ptr<const PjRtLayout> layout = pjrt_buffers_[0]->layout();
 #ifndef NDEBUG
   for (int i = 1; i < pjrt_buffers_.size(); ++i) {
     std::shared_ptr<const PjRtLayout> layout_i = pjrt_buffers_[i]->layout();
-    DCHECK(*layout == *layout_i)
+    DCHECK(*layout_ == *layout_i)
         << "PjRtArray has mismatched layouts across shards! "
-        << "shard 0: " << layout->ToString() << ", shard " << i << ": "
+        << "shard 0: " << layout_->ToString() << ", shard " << i << ": "
         << layout_i->ToString();
   }
 #endif
-  return layout;
+  return layout_;
 }
 
 }  // namespace ifrt
13 changes: 9 additions & 4 deletions xla/python/pjrt_ifrt/pjrt_array.h
@@ -70,12 +70,14 @@ class PjRtArray final
   // General array construction (with static shape).
   static absl::StatusOr<tsl::RCReference<PjRtArray>> Create(
       PjRtCompatibleClient* client, DType dtype, Shape shape,
-      std::shared_ptr<const Sharding> sharding, PjRtBuffers pjrt_buffers);
+      std::shared_ptr<const Sharding> sharding, PjRtBuffers pjrt_buffers,
+      std::shared_ptr<const PjRtLayout> layout);
 
   // General array construction (with dynamic shape).
   static absl::StatusOr<tsl::RCReference<PjRtArray>> Create(
       PjRtCompatibleClient* client, DType dtype, DynamicShape dynamic_shape,
-      std::shared_ptr<const Sharding> sharding, PjRtBuffers pjrt_buffers);
+      std::shared_ptr<const Sharding> sharding, PjRtBuffers pjrt_buffers,
+      std::shared_ptr<const PjRtLayout> layout);
 
   // Shorthand for a single-shard array construction.
   static absl::StatusOr<tsl::RCReference<PjRtArray>> Create(
@@ -184,11 +186,13 @@ class PjRtArray final
 
  private:
   PjRtArray(PjRtCompatibleClient* client, DType dtype, Shape shape,
-            std::shared_ptr<const Sharding> sharding, PjRtBuffers pjrt_buffers);
+            std::shared_ptr<const Sharding> sharding, PjRtBuffers pjrt_buffers,
+            std::shared_ptr<const PjRtLayout> layout);
 
   PjRtArray(PjRtCompatibleClient* client, DType dtype,
             DynamicShape dynamic_shape,
-            std::shared_ptr<const Sharding> sharding, PjRtBuffers pjrt_buffers);
+            std::shared_ptr<const Sharding> sharding, PjRtBuffers pjrt_buffers,
+            std::shared_ptr<const PjRtLayout> layout);
 
   template <typename T, typename... Args>
   friend tsl::RCReference<T> tsl::MakeRef(Args&&... args);
@@ -198,6 +202,7 @@ class PjRtArray final
   std::variant<Shape, DynamicShape> shape_;
   std::shared_ptr<const Sharding> sharding_;
   PjRtBuffers pjrt_buffers_;
+  std::shared_ptr<const PjRtLayout> layout_;
 };
 
 }  // namespace ifrt
20 changes: 18 additions & 2 deletions xla/python/pjrt_ifrt/pjrt_client.cc
@@ -964,8 +964,12 @@ absl::StatusOr<tsl::RCReference<Array>> PjRtClient::MakeArrayFromHostBuffer(
     }
     buffers.push_back(std::move(buffer));
   }
+  if (buffers.empty()) {
+    return InvalidArgument("Buffers must be non-empty.");
+  }
+  auto layout = buffers.front()->layout();
   return PjRtArray::Create(this, dtype, std::move(shape), std::move(sharding),
-                           std::move(buffers));
+                           std::move(buffers), std::move(layout));
 }
 
 absl::StatusOr<tsl::RCReference<Array>>
@@ -1070,8 +1074,20 @@ PjRtClient::AssembleArrayFromSingleDeviceArrays(
       break;
     }
   }
+  std::shared_ptr<const PjRtLayout> layout;
+  // DO NOT SUBMIT: Should this check nullptr instead?
+  if (!buffers.empty() && buffers.front()->on_device_shape().has_layout()) {
+    layout = buffers.front()->layout();
+  } else if (dtype.kind() == DType::kToken) {
+    layout = nullptr;
+  } else {
+    TF_ASSIGN_OR_RETURN(layout,
+                        GetDefaultLayout(dtype, shape.dims(),
+                                         sharding->devices()->devices().front(),
+                                         sharding->memory_kind()));
+  }
   return PjRtArray::Create(this, dtype, std::move(shape), std::move(sharding),
-                           std::move(buffers));
+                           std::move(buffers), std::move(layout));
 }
 
 absl::StatusOr<std::vector<tsl::RCReference<Array>>> PjRtClient::CopyArrays(
22 changes: 19 additions & 3 deletions xla/python/pjrt_ifrt/pjrt_executable.cc
@@ -33,11 +33,13 @@ limitations under the License.
 #include "mlir/IR/BuiltinOps.h"
 #include "xla/hlo/ir/hlo_sharding.h"
 #include "xla/hlo/translate/mhlo_to_hlo/type_to_shape.h"
+#include "xla/layout.h"
 #include "xla/pjrt/host_callback.h"
 #include "xla/pjrt/pjrt_client.h"
 #include "xla/pjrt/pjrt_compiler.h"
 #include "xla/pjrt/pjrt_executable.h"
 #include "xla/pjrt/pjrt_future.h"
+#include "xla/pjrt/pjrt_layout.h"
 #include "xla/primitive_util.h"
 #include "xla/python/ifrt/array.h"
 #include "xla/python/ifrt/device.h"
@@ -658,6 +660,20 @@ PjRtLoadedExecutable::Execute(
   // memory_kind shares the same Sharding object.
   absl::flat_hash_map<MemoryKind, std::shared_ptr<const Sharding>>
       single_device_shardings;
+  auto maybe_layouts = GetOutputLayouts();
+  std::vector<std::shared_ptr<const xla::PjRtLayout>> layouts;
+  if (absl::IsUnimplemented(maybe_layouts.status())) {
+    layouts.reserve(num_outputs);
+    for (int i = 0; i < num_outputs; ++i) {
+      layouts.push_back(nullptr);
+      // DO NOT SUBMIT: Should this be nullopt instead?
+      // layouts.push_back(std::make_shared<xla::PjRtLayout>(xla::Layout()));
+    }
+  } else {
+    TF_RETURN_IF_ERROR(maybe_layouts.status());
+    layouts = *std::move(maybe_layouts);
+  }
+
 for (int i = 0; i < num_outputs; ++i) {
     PjRtArray::PjRtBuffers buffers;
     buffers.reserve(num_computations);
@@ -698,9 +714,9 @@ PjRtLoadedExecutable::Execute(
     } else {
       sharding = output_shardings_[i];
     }
-    outputs.push_back(*PjRtArray::Create(client_, output_dtypes_[i],
-                                         output_shapes_[i], std::move(sharding),
-                                         std::move(buffers)));
+    outputs.push_back(*PjRtArray::Create(
+        client_, output_dtypes_[i], output_shapes_[i], std::move(sharding),
+        std::move(buffers), std::move(layouts[i])));
   }
 
   ExecuteResult result;
3 changes: 2 additions & 1 deletion xla/python/pjrt_ifrt/pjrt_remap.cc
@@ -134,7 +134,8 @@ PjRtCompatibleClientRemapArrays(
         PjRtArray::Create(client, plan.output_specs[i].dtype,
                           plan.output_specs[i].shape,
                           plan.output_specs[i].sharding,
-                          std::move(out_buffers_list[i])));
+                          std::move(out_buffers_list[i]),
+                          plan.output_specs[i].layout));
     output_arrays.push_back(std::move(output_array));
   }
   return output_arrays;
2 changes: 1 addition & 1 deletion xla/python/transfer/py_socket_transfer.cc
@@ -356,7 +356,7 @@ void RegisterTransferServerTypes(nanobind::module_& m) {
     }
     auto arr = xla::ValueOrThrow(xla::ifrt::PjRtArray::Create(
         ifrt_client, avals[i].dtype, avals[i].shape, avals[i].sharding,
-        std::move(buffers)));
+        std::move(buffers), avals[i].layout));
     out.push_back(xla::PyArray::MakeFromIfrtArrayAndSharding(
         py_client, traceback, std::move(arr), shardings[i], false, true,
         /*skip_checks=*/false));
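For illustration (not part of the commit): the fallback added to AssembleArrayFromSingleDeviceArrays above resolves a layout in three ordered cases — a concrete layout on the first buffer wins, token arrays carry no layout, and everything else falls back to the client's default layout. A hedged sketch of that logic as a standalone helper; the name ResolveAssembledLayout is hypothetical, while GetDefaultLayout is the call the diff itself makes:

    // Mirrors the layout-resolution order in pjrt_client.cc above.
    absl::StatusOr<std::shared_ptr<const xla::PjRtLayout>>
    ResolveAssembledLayout(const xla::ifrt::PjRtArray::PjRtBuffers& buffers,
                           xla::ifrt::DType dtype, const xla::ifrt::Shape& shape,
                           const xla::ifrt::Sharding& sharding,
                           xla::ifrt::PjRtClient* client) {
      // 1. Prefer the concrete layout already present on the first shard.
      if (!buffers.empty() && buffers.front()->on_device_shape().has_layout()) {
        return buffers.front()->layout();
      }
      // 2. Token arrays have no layout.
      if (dtype.kind() == xla::ifrt::DType::kToken) {
        return std::shared_ptr<const xla::PjRtLayout>(nullptr);
      }
      // 3. Fall back to the client's default layout for this dtype/shape.
      return client->GetDefaultLayout(dtype, shape.dims(),
                                      sharding.devices()->devices().front(),
                                      sharding.memory_kind());
    }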
