diff --git a/xla/service/collective_ops_utils.cc b/xla/service/collective_ops_utils.cc
index ab2c56dfe0ada6..ba9e6192025b88 100644
--- a/xla/service/collective_ops_utils.cc
+++ b/xla/service/collective_ops_utils.cc
@@ -757,6 +757,9 @@ bool IsNonFusionCollective(const HloInstruction* instruction) {
     case HloOpcode::kAsyncUpdate:
     case HloOpcode::kAsyncDone:
       return IsNonFusionCollective(instruction->async_wrapped_instruction());
+    case HloOpcode::kSend:
+    case HloOpcode::kRecv:
+      return !Cast<HloSendRecvInstruction>(instruction)->is_host_transfer();
     default:
       return false;
   }
diff --git a/xla/service/collective_ops_utils_test.cc b/xla/service/collective_ops_utils_test.cc
index d84e74c09abb5c..c8b24ccf67823d 100644
--- a/xla/service/collective_ops_utils_test.cc
+++ b/xla/service/collective_ops_utils_test.cc
@@ -104,7 +104,9 @@ TEST(CollectiveOpsUtilsTest, CollectiveWithChannelId) {
   %param0 = f32[512]{0} parameter(0)
   %copy0 = f32[512]{0} copy(param0)
   %reshape0 = f32[1,1,512]{2,0,1} reshape(f32[512]{0} %copy0)
-  %all-gather = f32[1,4,512]{2,0,1} all-gather(f32[1,1,512]{2,0,1} %reshape0), channel_id=3621, replica_groups={{0,1,2,3}}, dimensions={1}, use_global_device_ids=true
+  %all-gather = f32[1,4,512]{2,0,1} all-gather(f32[1,1,512]{2,0,1} %reshape0),
+      channel_id=3621, replica_groups={{0,1,2,3}}, dimensions={1},
+      use_global_device_ids=true
   %copy1 = f32[1,4,512]{2,0,1} copy(all-gather)
   ROOT root = f32[1,4,512]{2,1,0} copy(%copy1)
 })";
@@ -117,6 +119,31 @@ TEST(CollectiveOpsUtilsTest, CollectiveWithChannelId) {
   EXPECT_EQ(IsOrHasCollectiveWithChannelId(all_gather), all_gather);
 }
 
+TEST(CollectiveOpsUtilsTest, IsNonFusionCollectiveSendRecv) {
+  absl::string_view hlo_string = R"(
+    HloModule module
+
+    ENTRY entry_computation {
+      data = f32[64] parameter(0)
+      tok = token[] after-all()
+      recv_ctx = (f32[64], u32[], token[]) recv(tok), channel_id=2,
+          frontend_attributes={_xla_send_recv_source_target_pairs={{3,0}}}
+      send_ctx = (f32[64], u32[], token[]) send(tok, data), channel_id=2,
+          frontend_attributes={_xla_send_recv_source_target_pairs={{3,0}}}
+      ROOT root = tuple(send_ctx, recv_ctx)
+    })";
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnUnverifiedModule(hlo_string));
+
+  HloInstruction *recv_ctx =
+      module->entry_computation()->GetInstructionWithName("recv_ctx");
+  HloInstruction *send_ctx =
+      module->entry_computation()->GetInstructionWithName("send_ctx");
+
+  EXPECT_TRUE(IsNonFusionCollective(recv_ctx));
+  EXPECT_TRUE(IsNonFusionCollective(send_ctx));
+}
+
 TEST(CollectiveOpsUtilsTest, CollectiveWithChannelId2) {
   ReplicaGroup group;
   for (int64_t i = 0; i < 8; i++) {
diff --git a/xla/service/collective_permute_decomposer.cc b/xla/service/collective_permute_decomposer.cc
index 109f80b3c2c2d5..9c4187fadeac77 100644
--- a/xla/service/collective_permute_decomposer.cc
+++ b/xla/service/collective_permute_decomposer.cc
@@ -45,13 +45,12 @@ limitations under the License.
 #include "xla/xla_data.pb.h"
 
 namespace xla {
-namespace {
 
 // Returns true if the CollectivePermute instruction should be transformed
 // to Send/Recv. We currently limit the transformation to CollectivePermute
 // operations without any cycle in their (source, target) relationship,
 // with only one input and without any context data.
-bool ShouldDecompose(
+static bool ShouldDecompose(
     const HloCollectivePermuteInstruction& collective_permute,
     int64_t threshold_in_bytes, const CallGraph& call_graph,
     DebugOptions::PipelineParallelismOptLevel pipeline_parallelism_opt_level) {
@@ -92,21 +91,28 @@ bool ShouldDecompose(
 // Returns true for a pipelineable collective-permute. As a simple heuristic,
 // currently only pipeline a collective-permute with a loop input as its send
 // data.
-bool MayPipeline(const HloCollectivePermuteInstruction& collective_permute) {
+static bool MayPipeline(
+    const HloCollectivePermuteInstruction& collective_permute) {
   const HloInstruction* data = collective_permute.operand(0);
   return (data->opcode() == HloOpcode::kGetTupleElement &&
           data->operand(0)->opcode() == HloOpcode::kParameter);
 }
 
+namespace {
+
 // Contains source-target pairs from the permute operation and send and recv
 // instructions it was decomposed to.
 struct DecomposedCp {
   HloInstruction* send;
   HloInstruction* recv;
+  HloInstruction* send_done;
+  HloInstruction* recv_done;
   std::vector<std::pair<int64_t, int64_t>> source_target_pairs;
 };
 
-xla::FrontendAttributes ExtractFrontendAttributes(
+}  // namespace
+
+static xla::FrontendAttributes ExtractFrontendAttributes(
     const HloCollectivePermuteInstruction& cp) {
   const xla::FrontendAttributes& old_attributes = cp.frontend_attributes();
   xla::FrontendAttributes attributes;
@@ -123,7 +129,7 @@ xla::FrontendAttributes ExtractFrontendAttributes(
 // the value of the attribute represents the runtime stream to execute the
 // instruction. Without the frontend attribute, the collective-permute will not
 // be pipelined.
-absl::StatusOr<DecomposedCp> DecomposeCollectivePermute(
+static absl::StatusOr<DecomposedCp> DecomposeCollectivePermute(
     HloCollectivePermuteInstruction* cp, HloComputation* computation,
     const std::string& pipeline_decision) {
   absl::string_view cp_name = cp->name();
@@ -170,6 +176,7 @@ absl::StatusOr<DecomposedCp> DecomposeCollectivePermute(
   // assure that we initiate receival before initiating sending and that receive
   // done is executed after send is initiated.
   TF_RETURN_IF_ERROR(recv->AddControlDependencyTo(send));
+  TF_RETURN_IF_ERROR(recv_done->AddControlDependencyTo(send_done));
   TF_RETURN_IF_ERROR(send->AddControlDependencyTo(recv_done));
 
   if (!pipeline_decision.empty()) {
@@ -178,15 +185,16 @@ absl::StatusOr<DecomposedCp> DecomposeCollectivePermute(
     recv->set_frontend_attribute(kSendRecvPipelineAttr, pipeline_decision);
     recv_done->set_frontend_attribute(kSendRecvPipelineAttr, pipeline_decision);
   }
-  return DecomposedCp{send, recv, cp->source_target_pairs()};
+  return DecomposedCp{send, recv, send_done, recv_done,
+                      cp->source_target_pairs()};
 }
 
 // Checks whether the two collective-permutes for a forward cycle or a backward
 // cycle for pipelining. If the two collective-permutes form a cycle, returns
 // a pair of the collective-permutes with the one for the backward edge of the
 // cycle as the first entry in the pair.
-std::optional<std::pair<HloCollectivePermuteInstruction*,
-                        HloCollectivePermuteInstruction*>>
+static std::optional<std::pair<HloCollectivePermuteInstruction*,
+                               HloCollectivePermuteInstruction*>>
 CheckCyclePatterns(HloCollectivePermuteInstruction* cp0,
                    HloCollectivePermuteInstruction* cp1) {
   const SourceTargetPairs cp0_pairs(cp0->source_target_pairs());
@@ -204,6 +212,166 @@ CheckCyclePatterns(HloCollectivePermuteInstruction* cp0,
   return std::nullopt;
 }
 
+namespace {
+
+struct AbstractReplicaGroups {
+  // Holds groups of abstract replica ids.
+  std::vector<absl::flat_hash_set<int64_t>> groups;
+
+  // Maps abstract replica id to index in groups.
+  std::vector<int64_t> index_map;
+
+  int64_t get_index(int64_t replica_id) {
+    while (index_map.size() <= replica_id) index_map.push_back(-1);
+    return index_map[replica_id];
+  }
+
+  void set_index(int64_t replica_id, int64_t index) {
+    while (index_map.size() <= replica_id) index_map.push_back(-1);
+    index_map[replica_id] = index;
+  }
+
+  void merge_groups(int64_t replica_id, int64_t other_replica_id) {
+    if (get_index(replica_id) == -1 && get_index(other_replica_id) == -1) {
+      set_index(replica_id, groups.size());
+      set_index(other_replica_id, groups.size());
+      groups.push_back({replica_id, other_replica_id});
+      return;
+    }
+    if (get_index(replica_id) == get_index(other_replica_id)) return;
+    if (get_index(replica_id) == -1) {
+      std::swap(replica_id, other_replica_id);
+    }
+    CHECK_NE(get_index(replica_id), -1);
+    if (get_index(other_replica_id) == -1) {
+      set_index(other_replica_id, get_index(replica_id));
+      groups[get_index(replica_id)].insert(other_replica_id);
+      return;
+    }
+    CHECK(get_index(replica_id) != -1 && get_index(other_replica_id) != -1 &&
+          get_index(replica_id) != get_index(other_replica_id));
+    auto& other_set = groups[get_index(other_replica_id)];
+    for (int64_t replica_id_in_other_set : other_set) {
+      groups[get_index(replica_id)].insert(replica_id_in_other_set);
+      set_index(replica_id_in_other_set, get_index(replica_id));
+    }
+    other_set.clear();
+  }
+};
+
+}  // namespace
+
+static bool IsConflictingAbstractReplicaGroups(AbstractReplicaGroups& lhs,
+                                               AbstractReplicaGroups& rhs) {
+  std::vector<int64_t> frequency(lhs.groups.size(), 0);
+  for (auto& rhs_group : rhs.groups) {
+    std::fill(frequency.begin(), frequency.end(), 0);
+    for (int64_t rhs_replica_id : rhs_group) {
+      int64_t i = lhs.get_index(rhs_replica_id);
+      if (i == -1) continue;
+      if (++frequency[i] >= 2) return true;
+    }
+  }
+  return false;
+}
+
+static void GetAbstractReplicaGroups(HloInstruction* instr,
+                                     AbstractReplicaGroups& groups) {
+  // Abstract from source-target pairs of collective-permute to abstract replica
+  // groups.
+  if (instr->opcode() == HloOpcode::kCollectivePermute) {
+    auto* cp = Cast<HloCollectivePermuteInstruction>(instr);
+    for (auto& [i, j] : cp->source_target_pairs()) {
+      groups.merge_groups(i, j);
+    }
+    return;
+  }
+
+  // Abstract from source-target pairs of send/recv to abstract replica groups.
+  auto add_replica_group = [&groups](const ReplicaGroup& replica_group) {
+    auto& ids = replica_group.replica_ids();
+    if (ids.empty()) return;
+    int64_t leader_id = ids[0];
+    for (int64_t i = 1; i < ids.size(); ++i) {
+      groups.merge_groups(leader_id, ids[i]);
+    }
+  };
+  if (instr->opcode() == HloOpcode::kSend ||
+      instr->opcode() == HloOpcode::kRecv) {
+    auto* sr = Cast<HloSendRecvInstruction>(instr);
+    CHECK(!sr->is_host_transfer());
+    std::optional<std::string> source_target_pairs_str =
+        sr->frontend_attributes().map().at(kSendRecvSourceTargetPairsAttr);
+    CHECK(source_target_pairs_str.has_value());
+    absl::StatusOr<std::vector<ReplicaGroup>> source_target_pairs =
+        ParseReplicaGroupsOnly(*source_target_pairs_str);
+    CHECK(source_target_pairs.ok() && "Expect valid source_target_pairs");
+    for (auto& replica_group : *source_target_pairs) {
+      add_replica_group(replica_group);
+    }
+    return;
+  }
+
+  // Convert normal replica groups to abstract replica groups.
+  for (auto& replica_group : GetCollectiveReplicaGroups(instr)) {
+    add_replica_group(replica_group);
+  }
+}
+
+static std::vector<HloInstruction*> FindAllConflictingCollectives(
+    const HloComputation* computation,
+    std::vector<HloInstruction*>& seed_collectives) {
+  absl::flat_hash_set<HloInstruction*> seen;
+
+  // Get the supremum of all abstract replica groups of the seed collectives
+  // we're starting with.
+  AbstractReplicaGroups abstract_replica_groups_supremum;
+  for (HloInstruction* instr : seed_collectives) {
+    GetAbstractReplicaGroups(instr, abstract_replica_groups_supremum);
+    seen.insert(instr);
+  }
+
+  // Try finding more and more conflicting collectives until we reach a
+  // fixpoint. This is needed because we may get a coarser supremum with each
+  // new conflicting collective.
+  std::vector<HloInstruction*> conflicting_collectives;
+  bool fixpoint_reached;
+  do {
+    fixpoint_reached = true;
+
+    // Look at each collective in the computation.
+    for (HloInstruction* instr : computation->MakeInstructionPostOrder()) {
+      // Skip if not a collective or already considered for the supremum.
+      if (!IsNonFusionCollective(instr) || seen.contains(instr)) continue;
+
+      // Check whether this collective conflicts with the coarsest abstract
+      // replica groups. If it does, add it to the conflicting collectives and
+      // update the supremum.
+      AbstractReplicaGroups groups;
+      GetAbstractReplicaGroups(instr, groups);
+      if (IsConflictingAbstractReplicaGroups(
+              groups, abstract_replica_groups_supremum)) {
+        conflicting_collectives.push_back(instr);
+        GetAbstractReplicaGroups(instr, abstract_replica_groups_supremum);
+        seen.insert(instr);
+        fixpoint_reached = false;
+      }
+    }
+  } while (!fixpoint_reached);
+
+  return conflicting_collectives;
+}
+
+static std::vector<HloInstruction*> FindAllConflictingCollectives(
+    HloComputation* computation,
+    const std::vector<HloCollectivePermuteInstruction*>& cps) {
+  std::vector<HloInstruction*> seed_collectives;
+  for (HloCollectivePermuteInstruction* cp : cps) {
+    seed_collectives.push_back(static_cast<HloInstruction*>(cp));
+  }
+  return FindAllConflictingCollectives(computation, seed_collectives);
+}
+
 // Inserts control dependencies to enforce send/recv chain order.
 // The order protects from a potential deadlock when every device tries to
 // execute recv with no devices executing send - if there are no constraints,
@@ -211,19 +379,31 @@ CheckCyclePatterns(HloCollectivePermuteInstruction* cp0,
 // deco_post_order is expected to be post order within a computation.
 // TODO b/388072780 add second hueristic to enforce back edge before the forward
 // edge for max performance.
-// TODO(b/392684119): Also add control dependencies to conflicting collectives
-// other than send/recv.
-absl::Status EnforceOrderOfSendRecvChains(
+static absl::Status EnforceOrderOfSendRecvChain(
     std::vector<DecomposedCp>& deco_post_order) {
   for (size_t i = 1; i < deco_post_order.size(); ++i) {
    DecomposedCp& cur = deco_post_order[i];
    DecomposedCp& prev = deco_post_order[i - 1];
    TF_RETURN_IF_ERROR(prev.send->AddControlDependencyTo(cur.recv));
+    TF_RETURN_IF_ERROR(prev.send_done->AddControlDependencyTo(cur.recv_done));
   }
   return absl::OkStatus();
 }
 
-}  // namespace
+static absl::Status EnforceOrderOfSendRecvChainRelativeToConflictingCollectives(
+    std::vector<DecomposedCp>& deco_post_order,
+    std::vector<HloInstruction*> conflicting_collectives) {
+  // Find last collective in send/recv chain.
+  if (deco_post_order.empty()) return absl::OkStatus();
+  HloInstruction* last_in_chain = deco_post_order.back().send_done;
+
+  // Add control dependencies from chain to all conflicting collectives.
+  for (HloInstruction* instr : conflicting_collectives) {
+    TF_RETURN_IF_ERROR(last_in_chain->AddControlDependencyTo(instr));
+  }
+
+  return absl::OkStatus();
+}
 
 absl::StatusOr<bool> CollectivePermuteDecomposer::Run(
     HloModule* module,
@@ -262,6 +442,7 @@ absl::StatusOr<bool> CollectivePermuteDecomposer::Run(
         while_bodies.insert(instr->while_body());
         continue;
       }
+
       if (instr->opcode() != HloOpcode::kCollectivePermute) {
         continue;
       }
@@ -303,6 +484,13 @@ absl::StatusOr<bool> CollectivePermuteDecomposer::Run(
       }
     }  // for MakeInstructionPostOrder
 
+    // Find all collectives conflicting with the collective permutes that we
+    // want to decompose. This is needed to add control dependencies to these
+    // conflicting collectives so that they cannot move in between the
+    // decomposed send/recv, which would lead to deadlocks.
+    std::vector<HloInstruction*> conflicting_collectives =
+        FindAllConflictingCollectives(computation, cps_to_decompose);
+
     // cps to decompose were collected post order, similarly we will collect
     // the decomposed send/recv pairs.
     std::vector<DecomposedCp> deco_post_order;
@@ -321,7 +509,16 @@ absl::StatusOr<bool> CollectivePermuteDecomposer::Run(
             DecomposeCollectivePermute(cp, computation, pipeline_decision));
       deco_post_order.push_back(decomposed_ops);
     }
-    TF_RETURN_IF_ERROR(EnforceOrderOfSendRecvChains(deco_post_order));
+
+    // Enforce order of send/recv pairs at the beginning of the loop body. Also
+    // enforce all other conflicting collectives to follow the send/recv chain
+    // so that these cannot be scheduled in between the send/recv, which would
+    // also lead to deadlocks.
+    TF_RETURN_IF_ERROR(EnforceOrderOfSendRecvChain(deco_post_order));
+    TF_RETURN_IF_ERROR(
+        EnforceOrderOfSendRecvChainRelativeToConflictingCollectives(
+            deco_post_order, conflicting_collectives));
+
     if (!cps_to_decompose.empty()) {
       changed = true;
     }
diff --git a/xla/service/collective_permute_decomposer_test.cc b/xla/service/collective_permute_decomposer_test.cc
index 0abb791ffbb1f5..8cc1a0a3839e79 100644
--- a/xla/service/collective_permute_decomposer_test.cc
+++ b/xla/service/collective_permute_decomposer_test.cc
@@ -38,6 +38,7 @@ namespace {
 
 using ::testing::ElementsAre;
 using ::testing::HasSubstr;
+using ::testing::NotNull;
 
 using Pass = CollectivePermuteDecomposer;
 
@@ -322,10 +323,9 @@ void EnsureControlDependency(Decomposed cp) {
   EXPECT_EQ(cp.send->operand(1), cp.after_all);
   EXPECT_EQ(cp.recv_done->operand(0), cp.recv);
   EXPECT_EQ(cp.send_done->operand(0), cp.send);
-
-  EXPECT_THAT(cp.send->control_predecessors(), ElementsAre(cp.recv))
+  EXPECT_TRUE(absl::c_contains(cp.send->control_predecessors(), cp.recv))
       << "Send should depend on recv when decoposed";
-  EXPECT_THAT(cp.recv_done->control_predecessors(), ElementsAre(cp.send))
+  EXPECT_TRUE(absl::c_contains(cp.recv_done->control_predecessors(), cp.send))
      << "Recv-done should depend on send when decoposed";
 }
 
@@ -615,5 +615,455 @@ TEST_F(DecomposerTest, BackwardPipeline2) {
       << "Per sequence of select operands, cp_fwd should come before cp_back";
 }
 
+TEST_F(DecomposerTest, OneSendRecvWithOneConflictingCollectivePermute) {
+  absl::string_view hlo = R"(
+    HloModule module
+
+    cond {
+      param = (u32[], f32[64], f32[64]) parameter(0)
+      i = u32[] get-tuple-element(param), index=0
+      n = u32[] constant(2)
+      ROOT result = pred[] compare(i, n), direction=LT
+    }
+
+    body {
+      param = (u32[], f32[64], f32[64]) parameter(0)
+      i = u32[] get-tuple-element(param), index=0
+      data_a = f32[64] get-tuple-element(param), index=1
+      data_b = f32[64] get-tuple-element(param), index=2
+
+      // cp_fwd can be decomposed.
+      cp_fwd = f32[64] collective-permute(data_a), channel_id=1,
+          source_target_pairs={{0,1},{1,2},{2,3}}
+
+      // cp_cycle cannot be decomposed and is conflicting with cp_fwd.
+      cp_cycle = f32[64] collective-permute(data_b), channel_id=2,
+          source_target_pairs={{0,1},{1,2},{2,3},{3,0}}
+      c1 = u32[] constant(1)
+      i_ = u32[] add(i, c1)
+
+      ROOT result = (u32[], f32[64], f32[64]) tuple(i_, cp_fwd, cp_cycle)
+    }
+
+    ENTRY test_computation {
+      c0 = u32[] constant(0)
+      c1 = u32[] constant(1)
+      a = f32[] constant(42.0)
+      data = f32[64] broadcast(a), dimensions={}
+      while_init = (u32[], f32[64], f32[64]) tuple(c0, data, data)
+      ROOT while_result = (u32[], f32[64], f32[64]) while(while_init), body=body,
+          condition=cond
+    })";
+
+  TF_ASSERT_OK_AND_ASSIGN(
+      std::unique_ptr<HloModule> module,
+      RunAndCheckHloRewrite(
+          hlo,
+          Pass(/*threshold_in_bytes=*/0,
+               DebugOptions::PIPELINE_PARALLELISM_OPT_LEVEL_ENABLE),
+          true));
+
+  // Expect the resulting send/recv ops to be strictly ordered before the
+  // remaining collective-permute. This is to avoid deadlocks where one device
+  // is waiting on its send/recv peer while its peer expects to perform a
+  // collective-permute.
+  HloInstruction* cp_cycle = FindInstruction(module.get(), "cp_cycle");
+  HloInstruction* cp_fwd_recv_done =
+      FindInstruction(module.get(), "cp_fwd-recv-done");
+  HloInstruction* cp_fwd_send_done =
+      FindInstruction(module.get(), "cp_fwd-send-done");
+  ASSERT_THAT(cp_cycle, NotNull());
+  ASSERT_THAT(cp_fwd_recv_done, NotNull());
+  ASSERT_THAT(cp_fwd_send_done, NotNull());
+  EXPECT_THAT(cp_fwd_send_done->control_predecessors(),
+              ElementsAre(cp_fwd_recv_done));
+  EXPECT_THAT(cp_cycle->control_predecessors(), ElementsAre(cp_fwd_send_done));
+}
+
+TEST_F(DecomposerTest, OneSendRecvWithOneConflictingAllReduce) {
+  absl::string_view hlo = R"(
+    HloModule module
+
+    add {
+      lhs = f32[] parameter(0)
+      rhs = f32[] parameter(1)
+      ROOT add = f32[] add(lhs, rhs)
+    }
+
+    cond {
+      param = (u32[], f32[64], f32[64]) parameter(0)
+      i = u32[] get-tuple-element(param), index=0
+      n = u32[] constant(2)
+      ROOT result = pred[] compare(i, n), direction=LT
+    }
+
+    body {
+      param = (u32[], f32[64], f32[64]) parameter(0)
+      i = u32[] get-tuple-element(param), index=0
+      data_a = f32[64] get-tuple-element(param), index=1
+      data_b = f32[64] get-tuple-element(param), index=2
+
+      // cp_fwd can be decomposed.
+      cp_fwd = f32[64] collective-permute(data_a), channel_id=1,
+          source_target_pairs={{0,1},{1,2},{2,3}}
+
+      // ar is conflicting with cp_fwd.
+      ar = f32[64] all-reduce(data_b), channel_id=2, replica_groups={{0,1,2,3}},
+          to_apply=add
+
+      c1 = u32[] constant(1)
+      i_ = u32[] add(i, c1)
+
+      ROOT result = (u32[], f32[64], f32[64]) tuple(i_, cp_fwd, ar)
+    }
+
+    ENTRY test_computation {
+      c0 = u32[] constant(0)
+      c1 = u32[] constant(1)
+      a = f32[] constant(42.0)
+      data = f32[64] broadcast(a), dimensions={}
+      while_init = (u32[], f32[64], f32[64]) tuple(c0, data, data)
+      ROOT while_result = (u32[], f32[64], f32[64]) while(while_init), body=body,
+          condition=cond
+    })";
+
+  TF_ASSERT_OK_AND_ASSIGN(
+      std::unique_ptr<HloModule> module,
+      RunAndCheckHloRewrite(
+          hlo,
+          Pass(/*threshold_in_bytes=*/0,
+               DebugOptions::PIPELINE_PARALLELISM_OPT_LEVEL_ENABLE),
+          true));
+
+  // Expect the resulting send/recv ops to be strictly ordered before the
+  // all-reduce. This is to avoid deadlocks where one device is waiting on its
+  // send/recv peer while its peer expects to perform an all-reduce.
+  HloInstruction* ar = FindInstruction(module.get(), "ar");
+  HloInstruction* cp_fwd_recv_done =
+      FindInstruction(module.get(), "cp_fwd-recv-done");
+  HloInstruction* cp_fwd_send_done =
+      FindInstruction(module.get(), "cp_fwd-send-done");
+  ASSERT_THAT(ar, NotNull());
+  ASSERT_THAT(cp_fwd_recv_done, NotNull());
+  ASSERT_THAT(cp_fwd_send_done, NotNull());
+  EXPECT_THAT(cp_fwd_send_done->control_predecessors(),
+              ElementsAre(cp_fwd_recv_done));
+  EXPECT_THAT(ar->control_predecessors(), ElementsAre(cp_fwd_send_done));
+}
+
+TEST_F(DecomposerTest, OneSendRecvWithConflictingSendRecv) {
+  absl::string_view hlo = R"(
+    HloModule module
+
+    cond {
+      param = (u32[], f32[64], f32[64]) parameter(0)
+      i = u32[] get-tuple-element(param), index=0
+      n = u32[] constant(2)
+      ROOT result = pred[] compare(i, n), direction=LT
+    }
+
+    body {
+      param = (u32[], f32[64], f32[64]) parameter(0)
+      i = u32[] get-tuple-element(param), index=0
+      data_a = f32[64] get-tuple-element(param), index=1
+      data_b = f32[64] get-tuple-element(param), index=2
+
+      // cp_fwd can be decomposed.
+      cp_fwd = f32[64] collective-permute(data_a), channel_id=1,
+          source_target_pairs={{0,1},{1,2},{2,3}}
+
+      // These send/recv ops are conflicting with cp_fwd.
+      tok = token[] after-all()
+      recv_ctx = (f32[64], u32[], token[]) recv(tok), channel_id=2,
+          frontend_attributes={_xla_send_recv_source_target_pairs={{3,0}}}
+      recv_done = (f32[64], token[]) recv-done(recv_ctx), channel_id=2
+      send_ctx = (f32[64], u32[], token[]) send(tok, data_b), channel_id=2,
+          frontend_attributes={_xla_send_recv_source_target_pairs={{3,0}}}
+      send_done = token[] send-done(send_ctx), channel_id=2
+      recv_data = f32[64] get-tuple-element(recv_done), index=0
+
+      c1 = u32[] constant(1)
+      i_ = u32[] add(i, c1)
+
+      ROOT result = (u32[], f32[64], f32[64]) tuple(i_, cp_fwd, recv_data)
+    }
+
+    ENTRY test_computation {
+      c0 = u32[] constant(0)
+      c1 = u32[] constant(1)
+      a = f32[] constant(42.0)
+      data = f32[64] broadcast(a), dimensions={}
+      while_init = (u32[], f32[64], f32[64]) tuple(c0, data, data)
+      ROOT while_result = (u32[], f32[64], f32[64]) while(while_init), body=body,
+          condition=cond
+    })";
+
+  TF_ASSERT_OK_AND_ASSIGN(
+      std::unique_ptr<HloModule> module,
+      RunAndCheckHloRewrite(
+          hlo,
+          Pass(/*threshold_in_bytes=*/0,
+               DebugOptions::PIPELINE_PARALLELISM_OPT_LEVEL_ENABLE),
+          true));
+
+  // Expect decomposed send/recv ops to be strictly ordered before the
+  // preexisting send/recv ops. This is to avoid deadlocks where one device is
+  // waiting on its decomposed send/recv peer while its peer is stuck in some
+  // different send/recv communication.
+  HloInstruction* conflicting_recv = FindInstruction(module.get(), "recv_ctx");
+  HloInstruction* conflicting_send = FindInstruction(module.get(), "send_ctx");
+  HloInstruction* cp_fwd_recv_done =
+      FindInstruction(module.get(), "cp_fwd-recv-done");
+  HloInstruction* cp_fwd_send_done =
+      FindInstruction(module.get(), "cp_fwd-send-done");
+  ASSERT_THAT(conflicting_recv, NotNull());
+  ASSERT_THAT(conflicting_send, NotNull());
+  ASSERT_THAT(cp_fwd_recv_done, NotNull());
+  ASSERT_THAT(cp_fwd_send_done, NotNull());
+  EXPECT_THAT(cp_fwd_send_done->control_predecessors(),
+              ElementsAre(cp_fwd_recv_done));
+  EXPECT_THAT(conflicting_recv->control_predecessors(),
+              ElementsAre(cp_fwd_send_done));
+  EXPECT_THAT(conflicting_send->control_predecessors(),
+              ElementsAre(cp_fwd_send_done));
+}
+
+TEST_F(DecomposerTest, OneSendRecvWithNonConflictingAllReduce) {
+  absl::string_view hlo = R"(
+    HloModule module
+
+    add {
+      lhs = f32[] parameter(0)
+      rhs = f32[] parameter(1)
+      ROOT add = f32[] add(lhs, rhs)
+    }
+
+    cond {
+      param = (u32[], f32[64], f32[64]) parameter(0)
+      i = u32[] get-tuple-element(param), index=0
+      n = u32[] constant(2)
+      ROOT result = pred[] compare(i, n), direction=LT
+    }
+
+    body {
+      param = (u32[], f32[64], f32[64]) parameter(0)
+      i = u32[] get-tuple-element(param), index=0
+      data_a = f32[64] get-tuple-element(param), index=1
+      data_b = f32[64] get-tuple-element(param), index=2
+
+      // cp_fwd can be decomposed.
+      cp_fwd = f32[64] collective-permute(data_a), channel_id=1,
+          source_target_pairs={{0,2},{1,3}}
+
+      // ar is non-conflicting with cp_fwd. Cliques overlap in no more than one
+      // rank.
+      ar = f32[64] all-reduce(data_b), channel_id=2, replica_groups={{0,1},{2,3}},
+          to_apply=add
+      c1 = u32[] constant(1)
+      i_ = u32[] add(i, c1)
+
+      ROOT result = (u32[], f32[64], f32[64]) tuple(i_, cp_fwd, ar)
+    }
+
+    ENTRY test_computation {
+      c0 = u32[] constant(0)
+      c1 = u32[] constant(1)
+      a = f32[] constant(42.0)
+      data = f32[64] broadcast(a), dimensions={}
+      while_init = (u32[], f32[64], f32[64]) tuple(c0, data, data)
+      ROOT while_result = (u32[], f32[64], f32[64]) while(while_init), body=body,
+          condition=cond
+    })";
+
+  TF_ASSERT_OK_AND_ASSIGN(
+      std::unique_ptr<HloModule> module,
+      RunAndCheckHloRewrite(
+          hlo,
+          Pass(/*threshold_in_bytes=*/0,
+               DebugOptions::PIPELINE_PARALLELISM_OPT_LEVEL_ENABLE),
+          true));
+
+  // Expect decomposed send/recv ops to be unordered wrt. the preexisting
+  // non-conflicting all-reduce. These collectives will not deadlock as their
+  // NCCL cliques overlap in no more than one rank.
+  HloInstruction* ar = FindInstruction(module.get(), "ar");
+  HloInstruction* cp_fwd_recv_done =
+      FindInstruction(module.get(), "cp_fwd-recv-done");
+  HloInstruction* cp_fwd_send_done =
+      FindInstruction(module.get(), "cp_fwd-send-done");
+  ASSERT_THAT(ar, NotNull());
+  ASSERT_THAT(cp_fwd_recv_done, NotNull());
+  ASSERT_THAT(cp_fwd_send_done, NotNull());
+  EXPECT_THAT(cp_fwd_send_done->control_predecessors(),
+              ElementsAre(cp_fwd_recv_done));
+  EXPECT_THAT(ar->control_predecessors(), ElementsAre());
+}
+
+TEST_F(DecomposerTest, OneSendRecvWithConflictingAndNonConflictingCollectives) {
+  absl::string_view hlo = R"(
+    HloModule module
+
+    add {
+      lhs = f32[] parameter(0)
+      rhs = f32[] parameter(1)
+      ROOT add = f32[] add(lhs, rhs)
+    }
+
+    cond {
+      param = (u32[], f32[64], f32[64], f32[64], f32[64]) parameter(0)
+      i = u32[] get-tuple-element(param), index=0
+      n = u32[] constant(2)
+      ROOT result = pred[] compare(i, n), direction=LT
+    }
+
+    body {
+      param = (u32[], f32[64], f32[64], f32[64], f32[64]) parameter(0)
+      i = u32[] get-tuple-element(param), index=0
+      data_a = f32[64] get-tuple-element(param), index=1
+      data_b = f32[64] get-tuple-element(param), index=2
+
+      // cp_fwd can be decomposed.
+      cp_fwd = f32[64] collective-permute(data_a), channel_id=1,
+          source_target_pairs={{0,2},{1,3}}
+
+      // cp_cycle is conflicting with cp_fwd.
+      cp_cycle = f32[64] collective-permute(data_b), channel_id=2,
+          source_target_pairs={{0,1},{1,2},{2,3},{3,0}}
+
+      // ar is non-conflicting with cp_fwd. Cliques overlap in no more than one
+      // rank.
+      ar = f32[64] all-reduce(data_b), channel_id=3,
+          replica_groups={{0},{1},{2},{3}}, to_apply=add
+
+      // arc is conflicting with cp_fwd.
+      arc = f32[64] all-reduce(data_b), channel_id=4, replica_groups={{0,1,2,3}},
+          to_apply=add
+
+      c1 = u32[] constant(1)
+      i_ = u32[] add(i, c1)
+
+      ROOT result = (u32[], f32[64], f32[64], f32[64], f32[64]) tuple(i_, cp_fwd,
+          cp_cycle, ar, arc)
+    }
+
+    ENTRY test_computation {
+      c0 = u32[] constant(0)
+      c1 = u32[] constant(1)
+      a = f32[] constant(42.0)
+      data = f32[64] broadcast(a), dimensions={}
+      while_init = (u32[], f32[64], f32[64], f32[64], f32[64]) tuple(c0, data,
+          data, data, data)
+      ROOT while_result = (u32[], f32[64], f32[64], f32[64], f32[64])
+          while(while_init), body=body, condition=cond
+    })";
+
+  TF_ASSERT_OK_AND_ASSIGN(
+      std::unique_ptr<HloModule> module,
+      RunAndCheckHloRewrite(
+          hlo,
+          Pass(/*threshold_in_bytes=*/0,
+               DebugOptions::PIPELINE_PARALLELISM_OPT_LEVEL_ENABLE),
+          true));
+
+  // Expect decomposed send/recv ops to be strictly ordered before the
+  // conflicting all-reduce arc and the conflicting collective-permute cp_cycle.
+  // Expect them to be unordered wrt. the non-conflicting all-reduce ar.
+  HloInstruction* cp_cycle = FindInstruction(module.get(), "cp_cycle");
+  HloInstruction* ar = FindInstruction(module.get(), "ar");
+  HloInstruction* arc = FindInstruction(module.get(), "arc");
+  HloInstruction* cp_fwd_recv_done =
+      FindInstruction(module.get(), "cp_fwd-recv-done");
+  HloInstruction* cp_fwd_send_done =
+      FindInstruction(module.get(), "cp_fwd-send-done");
+  ASSERT_THAT(cp_cycle, NotNull());
+  ASSERT_THAT(ar, NotNull());
+  ASSERT_THAT(arc, NotNull());
+  ASSERT_THAT(cp_fwd_recv_done, NotNull());
+  ASSERT_THAT(cp_fwd_send_done, NotNull());
+  EXPECT_THAT(cp_fwd_send_done->control_predecessors(),
+              ElementsAre(cp_fwd_recv_done));
+  EXPECT_THAT(cp_cycle->control_predecessors(), ElementsAre(cp_fwd_send_done));
+  EXPECT_THAT(ar->control_predecessors(), ElementsAre());
+  EXPECT_THAT(arc->control_predecessors(), ElementsAre(cp_fwd_send_done));
+}
+
+TEST_F(DecomposerTest, OneSendRecvWithIndirectlyConflictingCollectives) {
+  absl::string_view hlo = R"(
+    HloModule module
+
+    add {
+      lhs = f32[] parameter(0)
+      rhs = f32[] parameter(1)
+      ROOT add = f32[] add(lhs, rhs)
+    }
+
+    cond {
+      param = (u32[], f32[64], f32[64], f32[64]) parameter(0)
+      i = u32[] get-tuple-element(param), index=0
+      n = u32[] constant(2)
+      ROOT result = pred[] compare(i, n), direction=LT
+    }
+
+    body {
+      param = (u32[], f32[64], f32[64], f32[64]) parameter(0)
+      i = u32[] get-tuple-element(param), index=0
+      data_a = f32[64] get-tuple-element(param), index=1
+      data_b = f32[64] get-tuple-element(param), index=2
+
+      // cp_fwd can be decomposed.
+      cp_fwd = f32[64] collective-permute(data_a), channel_id=1,
+          source_target_pairs={{0,1},{1,2},{2,3}}
+
+      // These collective-permute ops are conflicting with cp_fwd, some through
+      // indirection.
+      cp_cycle = f32[64] collective-permute(data_b), channel_id=2,
+          source_target_pairs={{4,5},{5,6},{6,7},{7,4}}
+      cp_cycle2 = f32[64] collective-permute(data_b), channel_id=3,
+          source_target_pairs={{2,3},{3,4},{4,5},{5,2}}
+
+      c1 = u32[] constant(1)
+      i_ = u32[] add(i, c1)
+
+      ROOT result = (u32[], f32[64], f32[64], f32[64]) tuple(i_, cp_fwd, cp_cycle,
+          cp_cycle2)
+    }
+
+    ENTRY test_computation {
+      c0 = u32[] constant(0)
+      c1 = u32[] constant(1)
+      a = f32[] constant(42.0)
+      data = f32[64] broadcast(a), dimensions={}
+      while_init = (u32[], f32[64], f32[64], f32[64]) tuple(c0, data, data, data)
+      ROOT while_result = (u32[], f32[64], f32[64], f32[64]) while(while_init), body=body, condition=cond
+    })";
+
+  TF_ASSERT_OK_AND_ASSIGN(
+      std::unique_ptr<HloModule> module,
+      RunAndCheckHloRewrite(
+          hlo,
+          Pass(/*threshold_in_bytes=*/0,
+               DebugOptions::PIPELINE_PARALLELISM_OPT_LEVEL_ENABLE),
+          true));
+
+  // Expect send/recv ops to be strictly ordered before the conflicting
+  // collective-permute ops. This is to avoid deadlocks where one device is
+  // waiting on its send/recv peer while its peer is stuck in a different
+  // collective-permute.
+  HloInstruction* cp_cycle = FindInstruction(module.get(), "cp_cycle");
+  HloInstruction* cp_cycle2 = FindInstruction(module.get(), "cp_cycle2");
+  HloInstruction* cp_fwd_recv_done =
+      FindInstruction(module.get(), "cp_fwd-recv-done");
+  HloInstruction* cp_fwd_send_done =
+      FindInstruction(module.get(), "cp_fwd-send-done");
+  ASSERT_THAT(cp_cycle, NotNull());
+  ASSERT_THAT(cp_cycle2, NotNull());
+  ASSERT_THAT(cp_fwd_recv_done, NotNull());
+  ASSERT_THAT(cp_fwd_send_done, NotNull());
+  EXPECT_THAT(cp_fwd_send_done->control_predecessors(),
+              ElementsAre(cp_fwd_recv_done));
+  EXPECT_THAT(cp_cycle->control_predecessors(), ElementsAre(cp_fwd_send_done));
+  EXPECT_THAT(cp_cycle2->control_predecessors(), ElementsAre(cp_fwd_send_done));
+}
+
 }  // namespace
 }  // namespace xla