diff --git a/xla/hlo/ir/BUILD b/xla/hlo/ir/BUILD index 53937cca04c2a..eafa1db09662b 100644 --- a/xla/hlo/ir/BUILD +++ b/xla/hlo/ir/BUILD @@ -322,6 +322,7 @@ xla_cc_test( "//xla:xla_data_proto_cc", "//xla/service:hlo_proto_cc", "//xla/tsl/platform:test_main", + "@com_google_absl//absl/log:globals", "@com_google_googletest//:gtest_main", ], ) diff --git a/xla/hlo/ir/collective_device_list.cc b/xla/hlo/ir/collective_device_list.cc index b52514380349f..2cf7f028a3d13 100644 --- a/xla/hlo/ir/collective_device_list.cc +++ b/xla/hlo/ir/collective_device_list.cc @@ -18,7 +18,6 @@ limitations under the License. #include #include #include -#include #include #include "absl/log/check.h" @@ -26,6 +25,7 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "absl/types/span.h" +#include "xla/array.h" #include "xla/service/hlo.pb.h" #include "xla/tsl/platform/logging.h" // IWYU pragma: keep #include "xla/xla_data.pb.h" @@ -81,72 +81,35 @@ IotaReplicaGroupList IotaReplicaGroupList::FromProto( proto.iota_transpose_perm().end())); } -CollectiveDeviceList::CollectiveDeviceList( - tsl::protobuf::RepeatedPtrField::const_iterator start, - tsl::protobuf::RepeatedPtrField::const_iterator end) { - replica_groups_shared_ = - std::make_shared>(start, end); - replica_groups_ = replica_groups_shared_.get(); -} - -CollectiveDeviceList::CollectiveDeviceList( - absl::Span replica_groups) { - replica_groups_shared_ = std::make_shared>( - replica_groups.begin(), replica_groups.end()); - replica_groups_ = replica_groups_shared_.get(); -} - -CollectiveDeviceList::CollectiveDeviceList( - absl::Span> replica_groups) { - auto rg_list = std::make_shared>(); - rg_list->reserve(replica_groups.size()); - for (auto g : replica_groups) { - auto& group = rg_list->emplace_back(); - *group.mutable_replica_ids() = {g.begin(), g.end()}; - } - replica_groups_shared_ = std::move(rg_list); - replica_groups_ = replica_groups_shared_.get(); -} - -CollectiveDeviceList::CollectiveDeviceList() { - replica_groups_shared_ = std::make_shared>(); - replica_groups_ = replica_groups_shared_.get(); -} - void CollectiveDeviceList::MaybeMaterializeFullReplicaGroupList() const { - if (replica_groups_ != nullptr) { + if (replica_groups_ != nullptr && !replica_groups_->empty()) { VLOG(10) << "Replica group list already materialized."; return; } - - DCHECK(iota_replica_group_list_.has_value()); + if (!iota_replica_group_list_.has_value()) { + VLOG(1) << "Replica group list not materialized because iota replica group " + "list is not present."; + return; + } VLOG(10) << "Materializing full replica group list"; - auto rg_list = std::make_shared>(); + replica_groups_ = std::make_shared>(); const int64_t num_replica_groups = iota_replica_group_list_->num_replica_groups(); - rg_list->reserve(num_replica_groups); + replica_groups_->reserve(num_replica_groups); - auto array = iota_replica_group_list_->ToArray(); + Array array = iota_replica_group_list_->ToArray(); // Iota replica group list array must only have 2 dimensions. DCHECK_EQ(array.num_dimensions(), 2); const int64_t num_devices_per_group = iota_replica_group_list_->num_devices_per_group(); DCHECK_EQ(array.end() - array.begin(), num_devices_per_group * num_replica_groups); - for (auto it = array.begin(), end = array.end(); it != end; + for (auto it = array.begin(); it != array.end(); it += num_devices_per_group) { - *rg_list->emplace_back().mutable_replica_ids() = { - it, it + num_devices_per_group}; + auto& group = replica_groups_->emplace_back(); + *group.mutable_replica_ids() = {it, it + num_devices_per_group}; } - - replica_groups_shared_ = std::move(rg_list); - replica_groups_ = replica_groups_shared_.get(); -} - -const std::vector& CollectiveDeviceList::replica_groups() const { - MaybeMaterializeFullReplicaGroupList(); - return *replica_groups_; } std::string CollectiveDeviceList::ToString( diff --git a/xla/hlo/ir/collective_device_list.h b/xla/hlo/ir/collective_device_list.h index bf69893d9d59c..63611968efb5f 100644 --- a/xla/hlo/ir/collective_device_list.h +++ b/xla/hlo/ir/collective_device_list.h @@ -84,20 +84,25 @@ class IotaReplicaGroupList { // replica groups, it may be used to represent these lists in compact forms. class CollectiveDeviceList { public: - explicit CollectiveDeviceList(absl::Span replica_groups); + explicit CollectiveDeviceList() = default; + + explicit CollectiveDeviceList(absl::Span replica_groups) + : replica_groups_(std::make_shared>( + replica_groups.begin(), replica_groups.end())) {}; explicit CollectiveDeviceList( - absl::Span> replica_groups); + absl::Span> replica_groups) + : replica_groups_(ToReplicaGroupVector(replica_groups)) {}; + // Replica groups are materialized lazily upon first access. explicit CollectiveDeviceList( const IotaReplicaGroupList& iota_replica_group_list) : iota_replica_group_list_(iota_replica_group_list) {} - // TODO(b/316622399): Remove this constructor and its usage as creating an - // empty collective device list has no meaning. - explicit CollectiveDeviceList(); - - const std::vector& replica_groups() const; + const std::vector& replica_groups() const { + MaybeMaterializeFullReplicaGroupList(); + return *replica_groups_; + } const std::optional& iota_replica_group_list() const { return iota_replica_group_list_; @@ -112,17 +117,33 @@ class CollectiveDeviceList { static CollectiveDeviceList FromProto(const HloInstructionProto& proto); private: - // Construct collective device list from replica group start and end + // Construct collective device list from protobuf replica group start and end // iterators. CollectiveDeviceList( tsl::protobuf::RepeatedPtrField::const_iterator start, - tsl::protobuf::RepeatedPtrField::const_iterator end); + tsl::protobuf::RepeatedPtrField::const_iterator end) + : replica_groups_( + std::make_shared>(start, end)) {}; + + static std::shared_ptr> ToReplicaGroupVector( + absl::Span> replica_groups) { + std::shared_ptr> result = + std::make_shared>(); + result->reserve(replica_groups.size()); + for (const std::vector& g : replica_groups) { + auto& group = result->emplace_back(); + group.mutable_replica_ids()->Add(g.begin(), g.end()); + } + return result; + } + // Load replica groups from iota tile assignment if not already done so. void MaybeMaterializeFullReplicaGroupList() const; - std::optional iota_replica_group_list_ = std::nullopt; - mutable std::shared_ptr> - replica_groups_shared_ = nullptr; - mutable const std::vector* replica_groups_ = nullptr; + + std::optional iota_replica_group_list_; + // shared_ptr for fast copy. + mutable std::shared_ptr> replica_groups_ = + std::make_shared>(); }; } // namespace xla diff --git a/xla/hlo/ir/collective_device_list_test.cc b/xla/hlo/ir/collective_device_list_test.cc index 7480fe54b5fa2..79df08bb3fdea 100644 --- a/xla/hlo/ir/collective_device_list_test.cc +++ b/xla/hlo/ir/collective_device_list_test.cc @@ -38,16 +38,20 @@ CollectiveDeviceListProto CreateDeviceListProto( } TEST(CollectiveDeviceListTest, DefaultListToString) { - CollectiveDeviceList list({{1, 2}, {3, 4}}); - ASSERT_EQ(list.ToString(), "{{1,2},{3,4}}"); + EXPECT_EQ(CollectiveDeviceList().ToString(), "{}"); + EXPECT_EQ(CollectiveDeviceList({{1, 2}, {3, 4}}).ToString(), "{{1,2},{3,4}}"); + EXPECT_EQ(CollectiveDeviceList({{1, 2, 3, 4, 5, 6, 7}}).ToString(), + "{{1,2,3,4,5,6,7}}"); } -TEST(CollectiveDeviceListTest, DefaultListToString2) { - CollectiveDeviceList list({{1, 2, 3, 4, 5, 6, 7}}); - EXPECT_EQ(list.ToString(), "{{1,2,3,4,5,6,7}}"); +TEST(CollectiveDeviceListTest, DeepCopy) { + CollectiveDeviceList orig({{1, 2, 3, 4, 5, 6, 7}}); + CollectiveDeviceList copy = orig; + EXPECT_EQ(&orig.replica_groups(), ©.replica_groups()); } TEST(CollectiveDeviceListTest, DefaultListToProto) { + EXPECT_THAT(CollectiveDeviceList().ToProto().replica_groups().size(), 0); CollectiveDeviceList list({{1, 2}, {3, 4}}); CollectiveDeviceListProto proto = list.ToProto(); EXPECT_THAT(proto.replica_groups().size(), 2);