Annotate decomposed send/recv and conflicting collectives to run them in parallel #22183

Merged 1 commit on Feb 6, 2025
39 changes: 36 additions & 3 deletions xla/service/collective_permute_decomposer.cc
@@ -379,6 +379,25 @@ static std::vector<HloInstruction*> FindAllConflictingCollectives(
return FindAllConflictingCollectives(computation, seed_collectives);
}

static void AddCollectiveStreamAnnotationP2P(
std::vector<HloInstruction*>& instructions) {
xla::FrontendAttributes attributes;
(*attributes.mutable_map())[kCollectiveStreamAttrName] = kCollectiveStreamP2P;
for (HloInstruction* instr : instructions) {
instr->add_frontend_attributes(attributes);
}
}

static void AddCollectiveStreamAnnotationP2P(
std::vector<DecomposedCp>& decomposed) {
std::vector<HloInstruction*> instructions;
for (DecomposedCp& cp : decomposed) {
instructions.push_back(cp.send);
instructions.push_back(cp.recv);
}
AddCollectiveStreamAnnotationP2P(instructions);
}

// Inserts control dependencies to enforce send/recv chain order.
// The order protects from a potential deadlock when every device tries to
// execute recv with no devices executing send - if there are no constraints,
@@ -492,9 +511,13 @@ absl::StatusOr<bool> CollectivePermuteDecomposer::Run(
} // for MakeInstructionPostOrder

// Find all collectives conflicting with the collective permutes that we
// want to decompose. This is needed to add control dependencies to these
// conflicting collectives so that they cannot move in between the
// decomposed send/recv, which would lead to deadlocks.
// want to decompose. We need this information to achieve two things:
// 1. We want to run these in parallel with non-conflicting collectives,
// e.g. those used for inner sharding strategies. The annotation allows us to
// later execute them on a separate stream.
// 2. We want to add control dependencies to these conflicting collectives
// so that they cannot move in between the decomposed send/recv, which would
// lead to deadlocks.
std::vector<HloInstruction*> conflicing_collectives =
FindAllConflictingCollectives(computation, cps_to_decompose);
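
A minimal illustration of the two mechanisms described in the comment above, added here for clarity and not part of this diff: the stream annotation is attached through frontend attributes, and control dependencies keep a conflicting collective from being scheduled inside the decomposed send/recv chain. The function name and the instruction pointers are hypothetical placeholders; add_frontend_attributes and AddControlDependencyTo are existing HloInstruction APIs, and the includes of the surrounding file are assumed.

absl::Status AnnotateAndOrderConflictingCollective(
    xla::HloInstruction* recv_done, xla::HloInstruction* send_done,
    xla::HloInstruction* conflicting_collective) {
  // Route the conflicting collective onto the dedicated P2P stream.
  xla::FrontendAttributes stream_attrs;
  (*stream_attrs.mutable_map())[xla::kCollectiveStreamAttrName] =
      xla::kCollectiveStreamP2P;
  conflicting_collective->add_frontend_attributes(stream_attrs);

  // recv-done must run before send-done, and the conflicting collective must
  // follow the send/recv chain so it cannot be scheduled in between.
  TF_RETURN_IF_ERROR(recv_done->AddControlDependencyTo(send_done));
  TF_RETURN_IF_ERROR(
      send_done->AddControlDependencyTo(conflicting_collective));
  return absl::OkStatus();
}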

@@ -518,6 +541,16 @@
deco_post_order.push_back(decomposed_ops);
}

// Move all decomposed and conflicting collectives to a separate stream for
// p2p communication. This will allow for overlap of pipeline parallelism
// with other inner sharding strategies. We can remove this when XLA:GPU
// supports multi-stream collectives more generally.
if (pipeline_parallelism_opt_level_ !=
DebugOptions::PIPELINE_PARALLELISM_OPT_LEVEL_DISABLE) {
AddCollectiveStreamAnnotationP2P(conflicing_collectives);
AddCollectiveStreamAnnotationP2P(deco_post_order);
}

// Enforce order of send/recv pairs at the beginning of the loop body. Also
// enforce all other conflicting collectives to follow the send/recv chain
// so that these cannot be scheduled in between the send/recv, which would
4 changes: 4 additions & 0 deletions xla/service/collective_permute_decomposer.h
@@ -26,6 +26,10 @@ limitations under the License.

namespace xla {

inline constexpr absl::string_view kCollectiveStreamAttrName =
"_xla_gpu_collective_stream";
inline constexpr absl::string_view kCollectiveStreamP2P = "p2p";

// CollectivePermuteDecomposer is a pass that (1) converts CollectivePermute
// operations without any cycle in their (source, target) relationship to
// Send/Recv, and (2) annotates the Send/Recv for pipelining with a frontend
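
A short sketch, added for illustration and not part of this diff, of how a later GPU pass might test for the annotation these constants define. The helper name and include paths are assumptions; frontend_attributes(), contains(), and at() follow the usage in this PR's tests.

#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/service/collective_permute_decomposer.h"

// Hypothetical helper: returns true if the decomposer routed `instr` onto the
// dedicated P2P collective stream.
bool IsP2PStreamCollective(const xla::HloInstruction& instr) {
  const auto& attrs = instr.frontend_attributes().map();
  return attrs.contains(xla::kCollectiveStreamAttrName) &&
         attrs.at(xla::kCollectiveStreamAttrName) == xla::kCollectiveStreamP2P;
}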
96 changes: 96 additions & 0 deletions xla/service/collective_permute_decomposer_test.cc
@@ -673,8 +673,10 @@ TEST_F(DecomposerTest, OneSendRecvWithOneConflictingCollectivePermute) {
// is waiting on its send/recv peer while its peer expects to perform a
// collective-permute.
HloInstruction* cp_cycle = FindInstruction(module.get(), "cp_cycle");
HloInstruction* cp_fwd_recv = FindInstruction(module.get(), "cp_fwd-recv");
HloInstruction* cp_fwd_recv_done =
FindInstruction(module.get(), "cp_fwd-recv-done");
HloInstruction* cp_fwd_send = FindInstruction(module.get(), "cp_fwd-send");
HloInstruction* cp_fwd_send_done =
FindInstruction(module.get(), "cp_fwd-send-done");
ASSERT_THAT(cp_cycle, NotNull());
@@ -683,6 +685,17 @@
EXPECT_THAT(cp_fwd_send_done->control_predecessors(),
ElementsAre(cp_fwd_recv_done));
EXPECT_THAT(cp_cycle->control_predecessors(), ElementsAre(cp_fwd_send_done));

// Expect all conflicting collectives to be annotated with the collective
// stream attribute.
EXPECT_EQ(
cp_fwd_recv->frontend_attributes().map().at(kCollectiveStreamAttrName),
kCollectiveStreamP2P);
EXPECT_EQ(
cp_fwd_send->frontend_attributes().map().at(kCollectiveStreamAttrName),
kCollectiveStreamP2P);
EXPECT_EQ(cp_cycle->frontend_attributes().map().at(kCollectiveStreamAttrName),
kCollectiveStreamP2P);
}

TEST_F(DecomposerTest, OneSendRecvWithOneConflictingAllReduce) {
@@ -744,8 +757,10 @@ TEST_F(DecomposerTest, OneSendRecvWithOneConflictingAllReduce) {
// all-reduce. This is to avoid deadlocks where one device is waiting on its
// send/recv peer while its peer expects to perform an all-reduce.
HloInstruction* ar = FindInstruction(module.get(), "ar");
HloInstruction* cp_fwd_recv = FindInstruction(module.get(), "cp_fwd-recv");
HloInstruction* cp_fwd_recv_done =
FindInstruction(module.get(), "cp_fwd-recv-done");
HloInstruction* cp_fwd_send = FindInstruction(module.get(), "cp_fwd-send");
HloInstruction* cp_fwd_send_done =
FindInstruction(module.get(), "cp_fwd-send-done");
ASSERT_THAT(ar, NotNull());
@@ -754,6 +769,17 @@
EXPECT_THAT(cp_fwd_send_done->control_predecessors(),
ElementsAre(cp_fwd_recv_done));
EXPECT_THAT(ar->control_predecessors(), ElementsAre(cp_fwd_send_done));

// Expect all conflicting collectives to be annotated with the collective
// stream attribute.
EXPECT_EQ(
cp_fwd_recv->frontend_attributes().map().at(kCollectiveStreamAttrName),
kCollectiveStreamP2P);
EXPECT_EQ(
cp_fwd_send->frontend_attributes().map().at(kCollectiveStreamAttrName),
kCollectiveStreamP2P);
EXPECT_EQ(ar->frontend_attributes().map().at(kCollectiveStreamAttrName),
kCollectiveStreamP2P);
}

TEST_F(DecomposerTest, OneSendRecvWithConflictingSendRecv) {
@@ -817,8 +843,10 @@ TEST_F(DecomposerTest, OneSendRecvWithConflictingSendRecv) {
// different send/recv communication.
HloInstruction* conflicting_recv = FindInstruction(module.get(), "recv_ctx");
HloInstruction* conflicting_send = FindInstruction(module.get(), "send_ctx");
HloInstruction* cp_fwd_recv = FindInstruction(module.get(), "cp_fwd-recv");
HloInstruction* cp_fwd_recv_done =
FindInstruction(module.get(), "cp_fwd-recv-done");
HloInstruction* cp_fwd_send = FindInstruction(module.get(), "cp_fwd-send");
HloInstruction* cp_fwd_send_done =
FindInstruction(module.get(), "cp_fwd-send-done");
ASSERT_THAT(conflicting_recv, NotNull());
@@ -831,6 +859,21 @@
ElementsAre(cp_fwd_send_done));
EXPECT_THAT(conflicting_send->control_predecessors(),
ElementsAre(cp_fwd_send_done));

// Expect all conflicting collectives to be annotated with the collective
// stream attribute.
EXPECT_EQ(
cp_fwd_recv->frontend_attributes().map().at(kCollectiveStreamAttrName),
kCollectiveStreamP2P);
EXPECT_EQ(
cp_fwd_send->frontend_attributes().map().at(kCollectiveStreamAttrName),
kCollectiveStreamP2P);
EXPECT_EQ(conflicting_recv->frontend_attributes().map().at(
kCollectiveStreamAttrName),
kCollectiveStreamP2P);
EXPECT_EQ(conflicting_send->frontend_attributes().map().at(
kCollectiveStreamAttrName),
kCollectiveStreamP2P);
}

TEST_F(DecomposerTest, OneSendRecvWithNonConflictingAllReduce) {
@@ -892,8 +935,10 @@ TEST_F(DecomposerTest, OneSendRecvWithNonConflictingAllReduce) {
// non-conflicting all-reduce. These collectives will not deadlock as their
// NCCL cliques overlap in no more than one rank.
HloInstruction* ar = FindInstruction(module.get(), "ar");
HloInstruction* cp_fwd_recv = FindInstruction(module.get(), "cp_fwd-recv");
HloInstruction* cp_fwd_recv_done =
FindInstruction(module.get(), "cp_fwd-recv-done");
HloInstruction* cp_fwd_send = FindInstruction(module.get(), "cp_fwd-send");
HloInstruction* cp_fwd_send_done =
FindInstruction(module.get(), "cp_fwd-send-done");
ASSERT_THAT(ar, NotNull());
@@ -902,6 +947,21 @@
EXPECT_THAT(cp_fwd_send_done->control_predecessors(),
ElementsAre(cp_fwd_recv_done));
EXPECT_THAT(ar->control_predecessors(), ElementsAre());

// Expect all conflicting collectives to be annotated with the collective
// stream attribute.
EXPECT_EQ(
cp_fwd_recv->frontend_attributes().map().at(kCollectiveStreamAttrName),
kCollectiveStreamP2P);
EXPECT_EQ(
cp_fwd_send->frontend_attributes().map().at(kCollectiveStreamAttrName),
kCollectiveStreamP2P);

// Expect all non-conflicting collectives to not be annotated with the
// collective stream attribute.
EXPECT_FALSE(
ar->frontend_attributes().map().contains(kCollectiveStreamAttrName));
}

TEST_F(DecomposerTest, OneSendRecvWithConflictingAndNonConflictingCollectives) {
@@ -976,8 +1036,10 @@ TEST_F(DecomposerTest, OneSendRecvWithConflictingAndNonConflictingCollectives) {
HloInstruction* cp_cycle = FindInstruction(module.get(), "cp_cycle");
HloInstruction* ar = FindInstruction(module.get(), "ar");
HloInstruction* arc = FindInstruction(module.get(), "arc");
HloInstruction* cp_fwd_recv = FindInstruction(module.get(), "cp_fwd-recv");
HloInstruction* cp_fwd_recv_done =
FindInstruction(module.get(), "cp_fwd-recv-done");
HloInstruction* cp_fwd_send = FindInstruction(module.get(), "cp_fwd-send");
HloInstruction* cp_fwd_send_done =
FindInstruction(module.get(), "cp_fwd-send-done");
ASSERT_THAT(cp_cycle, NotNull());
@@ -990,6 +1052,24 @@
EXPECT_THAT(cp_cycle->control_predecessors(), ElementsAre(cp_fwd_send_done));
EXPECT_THAT(ar->control_predecessors(), ElementsAre());
EXPECT_THAT(arc->control_predecessors(), ElementsAre(cp_fwd_send_done));

// Expect all conflicting collectives to be annotated with the collective
// stream attribute.
EXPECT_EQ(
cp_fwd_recv->frontend_attributes().map().at(kCollectiveStreamAttrName),
kCollectiveStreamP2P);
EXPECT_EQ(
cp_fwd_send->frontend_attributes().map().at(kCollectiveStreamAttrName),
kCollectiveStreamP2P);
EXPECT_EQ(cp_cycle->frontend_attributes().map().at(kCollectiveStreamAttrName),
kCollectiveStreamP2P);
EXPECT_EQ(arc->frontend_attributes().map().at(kCollectiveStreamAttrName),
kCollectiveStreamP2P);

// Expect all non-conflicting collectives to not be annotated with the
// collective stream attribute.
EXPECT_FALSE(
ar->frontend_attributes().map().contains(kCollectiveStreamAttrName));
}

TEST_F(DecomposerTest, OneSendRecvWithIndirectlyConflictingCollectives) {
@@ -1057,8 +1137,10 @@ TEST_F(DecomposerTest, OneSendRecvWithIndirectlyConflictingCollectives) {
// collective-permute.
HloInstruction* cp_cycle = FindInstruction(module.get(), "cp_cycle");
HloInstruction* cp_cycle2 = FindInstruction(module.get(), "cp_cycle2");
HloInstruction* cp_fwd_recv = FindInstruction(module.get(), "cp_fwd-recv");
HloInstruction* cp_fwd_recv_done =
FindInstruction(module.get(), "cp_fwd-recv-done");
HloInstruction* cp_fwd_send = FindInstruction(module.get(), "cp_fwd-send");
HloInstruction* cp_fwd_send_done =
FindInstruction(module.get(), "cp_fwd-send-done");
ASSERT_THAT(cp_cycle, NotNull());
@@ -1069,6 +1151,20 @@
ElementsAre(cp_fwd_recv_done));
EXPECT_THAT(cp_cycle->control_predecessors(), ElementsAre(cp_fwd_send_done));
EXPECT_THAT(cp_cycle2->control_predecessors(), ElementsAre(cp_fwd_send_done));

// Expect all conflicting collectives to be annotated with the collective
// stream attribute.
EXPECT_EQ(
cp_fwd_recv->frontend_attributes().map().at(kCollectiveStreamAttrName),
kCollectiveStreamP2P);
EXPECT_EQ(
cp_fwd_send->frontend_attributes().map().at(kCollectiveStreamAttrName),
kCollectiveStreamP2P);
EXPECT_EQ(cp_cycle->frontend_attributes().map().at(kCollectiveStreamAttrName),
kCollectiveStreamP2P);
EXPECT_EQ(
cp_cycle2->frontend_attributes().map().at(kCollectiveStreamAttrName),
kCollectiveStreamP2P);
}

} // namespace