Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove unnecessary pointers in CollectiveDeviceList. #22203

Merged
merged 1 commit into from
Feb 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions xla/hlo/ir/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,7 @@ xla_cc_test(
"//xla:xla_data_proto_cc",
"//xla/service:hlo_proto_cc",
"//xla/tsl/platform:test_main",
"@com_google_absl//absl/log:globals",
"@com_google_googletest//:gtest_main",
],
)
63 changes: 13 additions & 50 deletions xla/hlo/ir/collective_device_list.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,14 @@ limitations under the License.
#include <cstdint>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/log/check.h"
#include "absl/log/log.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/types/span.h"
#include "xla/array.h"
#include "xla/service/hlo.pb.h"
#include "xla/tsl/platform/logging.h" // IWYU pragma: keep
#include "xla/xla_data.pb.h"
Expand Down Expand Up @@ -81,72 +81,35 @@ IotaReplicaGroupList IotaReplicaGroupList::FromProto(
proto.iota_transpose_perm().end()));
}

CollectiveDeviceList::CollectiveDeviceList(
tsl::protobuf::RepeatedPtrField<ReplicaGroup>::const_iterator start,
tsl::protobuf::RepeatedPtrField<ReplicaGroup>::const_iterator end) {
replica_groups_shared_ =
std::make_shared<std::vector<ReplicaGroup>>(start, end);
replica_groups_ = replica_groups_shared_.get();
}

CollectiveDeviceList::CollectiveDeviceList(
absl::Span<const ReplicaGroup> replica_groups) {
replica_groups_shared_ = std::make_shared<std::vector<ReplicaGroup>>(
replica_groups.begin(), replica_groups.end());
replica_groups_ = replica_groups_shared_.get();
}

CollectiveDeviceList::CollectiveDeviceList(
absl::Span<const std::vector<int64_t>> replica_groups) {
auto rg_list = std::make_shared<std::vector<ReplicaGroup>>();
rg_list->reserve(replica_groups.size());
for (auto g : replica_groups) {
auto& group = rg_list->emplace_back();
*group.mutable_replica_ids() = {g.begin(), g.end()};
}
replica_groups_shared_ = std::move(rg_list);
replica_groups_ = replica_groups_shared_.get();
}

CollectiveDeviceList::CollectiveDeviceList() {
replica_groups_shared_ = std::make_shared<std::vector<ReplicaGroup>>();
replica_groups_ = replica_groups_shared_.get();
}

void CollectiveDeviceList::MaybeMaterializeFullReplicaGroupList() const {
if (replica_groups_ != nullptr) {
if (replica_groups_ != nullptr && !replica_groups_->empty()) {
VLOG(10) << "Replica group list already materialized.";
return;
}

DCHECK(iota_replica_group_list_.has_value());
if (!iota_replica_group_list_.has_value()) {
VLOG(1) << "Replica group list not materialized because iota replica group "
"list is not present.";
return;
}
VLOG(10) << "Materializing full replica group list";

auto rg_list = std::make_shared<std::vector<ReplicaGroup>>();
replica_groups_ = std::make_shared<std::vector<ReplicaGroup>>();
const int64_t num_replica_groups =
iota_replica_group_list_->num_replica_groups();
rg_list->reserve(num_replica_groups);
replica_groups_->reserve(num_replica_groups);

auto array = iota_replica_group_list_->ToArray();
Array<int64_t> array = iota_replica_group_list_->ToArray();
// Iota replica group list array must only have 2 dimensions.
DCHECK_EQ(array.num_dimensions(), 2);
const int64_t num_devices_per_group =
iota_replica_group_list_->num_devices_per_group();
DCHECK_EQ(array.end() - array.begin(),
num_devices_per_group * num_replica_groups);
for (auto it = array.begin(), end = array.end(); it != end;
for (auto it = array.begin(); it != array.end();
it += num_devices_per_group) {
*rg_list->emplace_back().mutable_replica_ids() = {
it, it + num_devices_per_group};
auto& group = replica_groups_->emplace_back();
*group.mutable_replica_ids() = {it, it + num_devices_per_group};
}

replica_groups_shared_ = std::move(rg_list);
replica_groups_ = replica_groups_shared_.get();
}

const std::vector<ReplicaGroup>& CollectiveDeviceList::replica_groups() const {
MaybeMaterializeFullReplicaGroupList();
return *replica_groups_;
}

std::string CollectiveDeviceList::ToString(
Expand Down
47 changes: 34 additions & 13 deletions xla/hlo/ir/collective_device_list.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,20 +84,25 @@ class IotaReplicaGroupList {
// replica groups, it may be used to represent these lists in compact forms.
class CollectiveDeviceList {
public:
explicit CollectiveDeviceList(absl::Span<const ReplicaGroup> replica_groups);
explicit CollectiveDeviceList() = default;

explicit CollectiveDeviceList(absl::Span<const ReplicaGroup> replica_groups)
: replica_groups_(std::make_shared<std::vector<ReplicaGroup>>(
replica_groups.begin(), replica_groups.end())) {};

explicit CollectiveDeviceList(
absl::Span<const std::vector<int64_t>> replica_groups);
absl::Span<const std::vector<int64_t>> replica_groups)
: replica_groups_(ToReplicaGroupVector(replica_groups)) {};

// Replica groups are materialized lazily upon first access.
explicit CollectiveDeviceList(
const IotaReplicaGroupList& iota_replica_group_list)
: iota_replica_group_list_(iota_replica_group_list) {}

// TODO(b/316622399): Remove this constructor and its usage as creating an
// empty collective device list has no meaning.
explicit CollectiveDeviceList();

const std::vector<ReplicaGroup>& replica_groups() const;
const std::vector<ReplicaGroup>& replica_groups() const {
MaybeMaterializeFullReplicaGroupList();
return *replica_groups_;
}

const std::optional<IotaReplicaGroupList>& iota_replica_group_list() const {
return iota_replica_group_list_;
Expand All @@ -112,17 +117,33 @@ class CollectiveDeviceList {
static CollectiveDeviceList FromProto(const HloInstructionProto& proto);

private:
// Construct collective device list from replica group start and end
// Construct collective device list from protobuf replica group start and end
// iterators.
CollectiveDeviceList(
tsl::protobuf::RepeatedPtrField<ReplicaGroup>::const_iterator start,
tsl::protobuf::RepeatedPtrField<ReplicaGroup>::const_iterator end);
tsl::protobuf::RepeatedPtrField<ReplicaGroup>::const_iterator end)
: replica_groups_(
std::make_shared<std::vector<ReplicaGroup>>(start, end)) {};

static std::shared_ptr<std::vector<ReplicaGroup>> ToReplicaGroupVector(
absl::Span<const std::vector<int64_t>> replica_groups) {
std::shared_ptr<std::vector<ReplicaGroup>> result =
std::make_shared<std::vector<ReplicaGroup>>();
result->reserve(replica_groups.size());
for (const std::vector<int64_t>& g : replica_groups) {
auto& group = result->emplace_back();
group.mutable_replica_ids()->Add(g.begin(), g.end());
}
return result;
}

// Load replica groups from iota tile assignment if not already done so.
void MaybeMaterializeFullReplicaGroupList() const;
std::optional<IotaReplicaGroupList> iota_replica_group_list_ = std::nullopt;
mutable std::shared_ptr<const std::vector<ReplicaGroup>>
replica_groups_shared_ = nullptr;
mutable const std::vector<ReplicaGroup>* replica_groups_ = nullptr;

std::optional<IotaReplicaGroupList> iota_replica_group_list_;
// shared_ptr for fast copy.
mutable std::shared_ptr<std::vector<ReplicaGroup>> replica_groups_ =
std::make_shared<std::vector<ReplicaGroup>>();
};

} // namespace xla
Expand Down
14 changes: 9 additions & 5 deletions xla/hlo/ir/collective_device_list_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,16 +38,20 @@ CollectiveDeviceListProto CreateDeviceListProto(
}

TEST(CollectiveDeviceListTest, DefaultListToString) {
CollectiveDeviceList list({{1, 2}, {3, 4}});
ASSERT_EQ(list.ToString(), "{{1,2},{3,4}}");
EXPECT_EQ(CollectiveDeviceList().ToString(), "{}");
EXPECT_EQ(CollectiveDeviceList({{1, 2}, {3, 4}}).ToString(), "{{1,2},{3,4}}");
EXPECT_EQ(CollectiveDeviceList({{1, 2, 3, 4, 5, 6, 7}}).ToString(),
"{{1,2,3,4,5,6,7}}");
}

TEST(CollectiveDeviceListTest, DefaultListToString2) {
CollectiveDeviceList list({{1, 2, 3, 4, 5, 6, 7}});
EXPECT_EQ(list.ToString(), "{{1,2,3,4,5,6,7}}");
TEST(CollectiveDeviceListTest, DeepCopy) {
CollectiveDeviceList orig({{1, 2, 3, 4, 5, 6, 7}});
CollectiveDeviceList copy = orig;
EXPECT_EQ(&orig.replica_groups(), &copy.replica_groups());
}

TEST(CollectiveDeviceListTest, DefaultListToProto) {
EXPECT_THAT(CollectiveDeviceList().ToProto().replica_groups().size(), 0);
CollectiveDeviceList list({{1, 2}, {3, 4}});
CollectiveDeviceListProto proto = list.ToProto();
EXPECT_THAT(proto.replica_groups().size(), 2);
Expand Down
Loading