Skip to content

Commit

Permalink
Add control dependencies to conflicting collectives when decomposing …
Browse files Browse the repository at this point in the history
…collective-permute into send/recv

This is to ensure that conflicting collectives are not scheduled in between the decomposed send/recv ops, which would cause deadlocks.

PiperOrigin-RevId: 721906813
  • Loading branch information
frgossen authored and Google-ML-Automation committed Jan 31, 2025
1 parent 811a86b commit ef55151
Show file tree
Hide file tree
Showing 5 changed files with 703 additions and 17 deletions.
3 changes: 3 additions & 0 deletions xla/service/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -300,11 +300,13 @@ cc_library(
"//xla:shape_util",
"//xla:xla_data_proto_cc",
"//xla/hlo/ir:hlo",
"//xla/hlo/parser:hlo_parser",
"//xla/hlo/pass:hlo_pass",
"//xla/service/gpu:backend_configs_cc",
"//xla/tsl/platform:errors",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/log",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
Expand All @@ -322,6 +324,7 @@ xla_cc_test(
"//xla/hlo/utils:hlo_matchers",
"//xla/tsl/lib/core:status_test_util",
"//xla/tsl/platform:statusor",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:string_view",
Expand Down
3 changes: 3 additions & 0 deletions xla/service/collective_ops_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -757,6 +757,9 @@ bool IsNonFusionCollective(const HloInstruction* instruction) {
case HloOpcode::kAsyncUpdate:
case HloOpcode::kAsyncDone:
return IsNonFusionCollective(instruction->async_wrapped_instruction());
case HloOpcode::kSend:
case HloOpcode::kRecv:
return !Cast<HloSendRecvInstruction>(instruction)->is_host_transfer();
default:
return false;
}
Expand Down
29 changes: 28 additions & 1 deletion xla/service/collective_ops_utils_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,9 @@ TEST(CollectiveOpsUtilsTest, CollectiveWithChannelId) {
%param0 = f32[512]{0} parameter(0)
%copy0 = f32[512]{0} copy(param0)
%reshape0 = f32[1,1,512]{2,0,1} reshape(f32[512]{0} %copy0)
%all-gather = f32[1,4,512]{2,0,1} all-gather(f32[1,1,512]{2,0,1} %reshape0), channel_id=3621, replica_groups={{0,1,2,3}}, dimensions={1}, use_global_device_ids=true
%all-gather = f32[1,4,512]{2,0,1} all-gather(f32[1,1,512]{2,0,1} %reshape0),
channel_id=3621, replica_groups={{0,1,2,3}}, dimensions={1},
use_global_device_ids=true
%copy1 = f32[1,4,512]{2,0,1} copy(all-gather)
ROOT root = f32[1,4,512]{2,1,0} copy(%copy1)
})";
Expand All @@ -117,6 +119,31 @@ TEST(CollectiveOpsUtilsTest, CollectiveWithChannelId) {
EXPECT_EQ(IsOrHasCollectiveWithChannelId(all_gather), all_gather);
}

// Checks that send and recv instructions which do not request a host transfer
// are classified as non-fusion collectives by IsNonFusionCollective.
TEST(CollectiveOpsUtilsTest, IsNonFusionCollectiveSendRecv) {
  absl::string_view hlo_string = R"(
HloModule module
ENTRY entry_computation {
data = f32[64] parameter(0)
tok = token[] after-all()
recv_ctx = (f32[64], u32[], token[]) recv(tok), channel_id=2,
frontend_attributes={_xla_send_recv_source_target_pairs={{3,0}}}
send_ctx = (f32[64], u32[], token[]) send(tok, data), channel_id=2,
frontend_attributes={_xla_send_recv_source_target_pairs={{3,0}}}
ROOT root = tuple(send_ctx, recv_ctx)
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnUnverifiedModule(hlo_string));

  HloInstruction* recv_ctx =
      module->entry_computation()->GetInstructionWithName("recv_ctx");
  HloInstruction* send_ctx =
      module->entry_computation()->GetInstructionWithName("send_ctx");

  // Abort the test with a readable failure if a lookup by name comes back
  // empty, rather than handing a null pointer to IsNonFusionCollective.
  ASSERT_NE(recv_ctx, nullptr);
  ASSERT_NE(send_ctx, nullptr);

  EXPECT_TRUE(IsNonFusionCollective(recv_ctx));
  EXPECT_TRUE(IsNonFusionCollective(send_ctx));
}

TEST(CollectiveOpsUtilsTest, CollectiveWithChannelId2) {
ReplicaGroup group;
for (int64_t i = 0; i < 8; i++) {
Expand Down
Loading

0 comments on commit ef55151

Please sign in to comment.