Commit

Add control dependencies to conflicting collectives when decomposing collective-permute into send/recv

This is to ensure that conflicting collectives are not scheduled in between the decomposed send/recv ops, which would cause deadlocks.

PiperOrigin-RevId: 721906813
frgossen authored and Google-ML-Automation committed Jan 31, 2025
1 parent 27a43c1 commit 112f7e8
Showing 4 changed files with 694 additions and 17 deletions.
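
A minimal sketch of the mechanism (editorial illustration, not part of this commit's diff; the helper name and include list are assumptions): the pass pins each conflicting collective behind the last op of a decomposed send/recv chain via HloInstruction::AddControlDependencyTo, the same API used throughout the changes below.

#include "absl/status/status.h"
#include "xla/hlo/ir/hlo_instruction.h"

namespace xla {

// Hypothetical helper, for illustration only: force `conflicting` to be
// scheduled after `last_in_chain`, the final op of a decomposed send/recv
// chain. A control edge adds no data flow; it only constrains the scheduler.
absl::Status OrderAfterSendRecvChain(HloInstruction* last_in_chain,
                                     HloInstruction* conflicting) {
  return last_in_chain->AddControlDependencyTo(conflicting);
}

}  // namespace xla
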
3 changes: 3 additions & 0 deletions xla/service/collective_ops_utils.cc
@@ -757,6 +757,9 @@ bool IsNonFusionCollective(const HloInstruction* instruction) {
case HloOpcode::kAsyncUpdate:
case HloOpcode::kAsyncDone:
return IsNonFusionCollective(instruction->async_wrapped_instruction());
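// Device-to-device send/recv, e.g. produced by decomposing a
// collective-permute, is cross-device communication and is therefore treated
// as a collective; host transfers are not.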
case HloOpcode::kSend:
case HloOpcode::kRecv:
return !Cast<HloSendRecvInstruction>(instruction)->is_host_transfer();
default:
return false;
}
29 changes: 28 additions & 1 deletion xla/service/collective_ops_utils_test.cc
@@ -104,7 +104,9 @@ TEST(CollectiveOpsUtilsTest, CollectiveWithChannelId) {
%param0 = f32[512]{0} parameter(0)
%copy0 = f32[512]{0} copy(param0)
%reshape0 = f32[1,1,512]{2,0,1} reshape(f32[512]{0} %copy0)
%all-gather = f32[1,4,512]{2,0,1} all-gather(f32[1,1,512]{2,0,1} %reshape0), channel_id=3621, replica_groups={{0,1,2,3}}, dimensions={1}, use_global_device_ids=true
%all-gather = f32[1,4,512]{2,0,1} all-gather(f32[1,1,512]{2,0,1} %reshape0),
channel_id=3621, replica_groups={{0,1,2,3}}, dimensions={1},
use_global_device_ids=true
%copy1 = f32[1,4,512]{2,0,1} copy(all-gather)
ROOT root = f32[1,4,512]{2,1,0} copy(%copy1)
})";
@@ -117,6 +119,31 @@ TEST(CollectiveOpsUtilsTest, CollectiveWithChannelId) {
EXPECT_EQ(IsOrHasCollectiveWithChannelId(all_gather), all_gather);
}

TEST(CollectiveOpsUtilsTest, IsNonFusionCollectiveSendRecv) {
absl::string_view hlo_string = R"(
HloModule module
ENTRY entry_computation {
data = f32[64] parameter(0)
tok = token[] after-all()
recv_ctx = (f32[64], u32[], token[]) recv(tok), channel_id=2,
frontend_attributes={_xla_send_recv_source_target_pairs={{3,0}}}
send_ctx = (f32[64], u32[], token[]) send(tok, data), channel_id=2,
frontend_attributes={_xla_send_recv_source_target_pairs={{3,0}}}
ROOT root = tuple(send_ctx, recv_ctx)
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnUnverifiedModule(hlo_string));

HloInstruction *recv_ctx =
module->entry_computation()->GetInstructionWithName("recv_ctx");
HloInstruction *send_ctx =
module->entry_computation()->GetInstructionWithName("send_ctx");

EXPECT_TRUE(IsNonFusionCollective(recv_ctx));
EXPECT_TRUE(IsNonFusionCollective(send_ctx));
}

TEST(CollectiveOpsUtilsTest, CollectiveWithChannelId2) {
ReplicaGroup group;
for (int64_t i = 0; i < 8; i++) {
223 changes: 210 additions & 13 deletions xla/service/collective_permute_decomposer.cc
@@ -45,13 +45,12 @@ limitations under the License.
#include "xla/xla_data.pb.h"

namespace xla {
namespace {

// Returns true if the CollectivePermute instruction should be transformed
// to Send/Recv. We currently limit the transformation to CollectivePermute
// operations without any cycle in their (source, target) relationship,
// with only one input and without any context data.
bool ShouldDecompose(
static bool ShouldDecompose(
const HloCollectivePermuteInstruction& collective_permute,
int64_t threshold_in_bytes, const CallGraph& call_graph,
DebugOptions::PipelineParallelismOptLevel pipeline_parallelism_opt_level) {
@@ -92,21 +91,28 @@ bool ShouldDecompose(
// Returns true for a pipelineable collective-permute. As a simple heuristic,
// currently only pipeline a collective-permute with a loop input as its send
// data.
bool MayPipeline(const HloCollectivePermuteInstruction& collective_permute) {
static bool MayPipeline(
const HloCollectivePermuteInstruction& collective_permute) {
const HloInstruction* data = collective_permute.operand(0);
return (data->opcode() == HloOpcode::kGetTupleElement &&
data->operand(0)->opcode() == HloOpcode::kParameter);
}

namespace {

// Contains source-target pairs from the permute operation and send and recv
// instructions it was decomposed to.
struct DecomposedCp {
HloInstruction* send;
HloInstruction* recv;
HloInstruction* send_done;
HloInstruction* recv_done;
std::vector<std::pair<int64_t, int64_t>> source_target_pairs;
};

xla::FrontendAttributes ExtractFrontendAttributes(
} // namespace

static xla::FrontendAttributes ExtractFrontendAttributes(
const HloCollectivePermuteInstruction& cp) {
const xla::FrontendAttributes& old_attributes = cp.frontend_attributes();
xla::FrontendAttributes attributes;
@@ -123,7 +129,7 @@ xla::FrontendAttributes ExtractFrontendAttributes(
// the value of the attribute represents the runtime stream to execute the
// instruction. Without the frontend attribute, the collective-permute will not
// be pipelined.
absl::StatusOr<DecomposedCp> DecomposeCollectivePermute(
static absl::StatusOr<DecomposedCp> DecomposeCollectivePermute(
HloCollectivePermuteInstruction* cp, HloComputation* computation,
const std::string& pipeline_decision) {
absl::string_view cp_name = cp->name();
@@ -170,6 +176,7 @@ absl::StatusOr<DecomposedCp> DecomposeCollectivePermute(
// assure that we initiate receival before initiating sending and that receive
// done is executed after send is initiated.
TF_RETURN_IF_ERROR(recv->AddControlDependencyTo(send));
TF_RETURN_IF_ERROR(recv_done->AddControlDependencyTo(send_done));
TF_RETURN_IF_ERROR(send->AddControlDependencyTo(recv_done));
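// Together, these control edges enforce the per-permute ordering
// recv -> send -> recv-done -> send-done.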

if (!pipeline_decision.empty()) {
@@ -178,15 +185,16 @@ absl::StatusOr<DecomposedCp> DecomposeCollectivePermute(
recv->set_frontend_attribute(kSendRecvPipelineAttr, pipeline_decision);
recv_done->set_frontend_attribute(kSendRecvPipelineAttr, pipeline_decision);
}
return DecomposedCp{send, recv, cp->source_target_pairs()};
return DecomposedCp{send, recv, send_done, recv_done,
cp->source_target_pairs()};
}

// Checks whether the two collective-permutes form a forward cycle or a backward
// cycle for pipelining. If the two collective-permutes form a cycle, returns
// a pair of the collective-permutes with the one for the backward edge of the
// cycle as the first entry in the pair.
std::optional<std::pair<HloCollectivePermuteInstruction*,
HloCollectivePermuteInstruction*>>
static std::optional<std::pair<HloCollectivePermuteInstruction*,
HloCollectivePermuteInstruction*>>
CheckCyclePatterns(HloCollectivePermuteInstruction* cp0,
HloCollectivePermuteInstruction* cp1) {
const SourceTargetPairs cp0_pairs(cp0->source_target_pairs());
@@ -204,26 +212,198 @@ CheckCyclePatterns(HloCollectivePermuteInstruction* cp0,
return std::nullopt;
}

namespace {

struct AbstractReplicaGroups {
// Holds groups of abstract replica ids.
std::vector<absl::flat_hash_set<int64_t>> groups;

// Maps abstract replica id to index in groups.
std::vector<int64_t> index_map;

int64_t get_index(int64_t replica_id) {
while (index_map.size() <= replica_id) index_map.push_back(-1);
return index_map[replica_id];
}

void set_index(int64_t replica_id, int64_t index) {
while (index_map.size() <= replica_id) index_map.push_back(-1);
index_map[replica_id] = index;
}

void merge_groups(int64_t replica_id, int64_t other_replica_id) {
if (get_index(replica_id) == -1 && get_index(other_replica_id) == -1) {
set_index(replica_id, groups.size());
set_index(other_replica_id, groups.size());
groups.push_back({replica_id, other_replica_id});
return;
}
if (get_index(replica_id) == get_index(other_replica_id)) return;
if (get_index(replica_id) == -1) {
std::swap(replica_id, other_replica_id);
}
CHECK_NE(get_index(replica_id), -1);
if (get_index(other_replica_id) == -1) {
set_index(other_replica_id, get_index(replica_id));
groups[get_index(replica_id)].insert(other_replica_id);
return;
}
CHECK(get_index(replica_id) != -1 && get_index(other_replica_id) != -1 &&
get_index(replica_id) != get_index(other_replica_id));
auto& other_set = groups[get_index(other_replica_id)];
for (int64_t replica_id_in_other_set : other_set) {
groups[get_index(replica_id)].insert(replica_id_in_other_set);
set_index(replica_id_in_other_set, get_index(replica_id));
}
other_set.clear();
}
};

} // namespace

static bool IsConflictingAbstractReplicaGroups(AbstractReplicaGroups& lhs,
AbstractReplicaGroups& rhs) {
std::vector<int64_t> frequency(lhs.groups.size(), 0);
for (auto& rhs_group : rhs.groups) {
std::fill(frequency.begin(), frequency.end(), 0);
for (int64_t rhs_replica_id : rhs_group) {
int64_t i = lhs.get_index(rhs_replica_id);
if (i == -1) continue;
if (++frequency[i] >= 2) return true;
}
}
return false;
}
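
// Editorial illustration, not part of this commit's diff: AbstractReplicaGroups
// is a small union-find over replica ids, and two group sets conflict exactly
// when some group on one side contains at least two ids that land in the same
// group on the other side. The hypothetical function below only demonstrates
// the two helpers defined above.
static void IllustrateConflictingGroupsExample() {
  AbstractReplicaGroups lhs;
  lhs.merge_groups(0, 1);
  lhs.merge_groups(2, 3);  // lhs groups: {{0, 1}, {2, 3}}

  AbstractReplicaGroups rhs;
  rhs.merge_groups(1, 2);  // rhs groups: {{1, 2}}

  // Ids 1 and 2 fall into different lhs groups, so there is no conflict.
  CHECK(!IsConflictingAbstractReplicaGroups(lhs, rhs));

  // Coarsening lhs into the single group {0, 1, 2, 3} puts ids 1 and 2 of the
  // rhs group into the same lhs group, so the two sets now conflict.
  lhs.merge_groups(1, 2);
  CHECK(IsConflictingAbstractReplicaGroups(lhs, rhs));
}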

static void GetAbstractReplicaGroups(HloInstruction* instr,
AbstractReplicaGroups& groups) {
// Abstract from source-target pairs of collective-permute to abstract replica
// groups.
if (instr->opcode() == HloOpcode::kCollectivePermute) {
auto* cp = Cast<HloCollectivePermuteInstruction>(instr);
for (auto& [i, j] : cp->source_target_pairs()) {
groups.merge_groups(i, j);
}
return;
}

// Abstract from source-target pairs of send/recv to abstract replica groups.
auto add_replica_group = [&groups](const ReplicaGroup& replica_group) {
auto& ids = replica_group.replica_ids();
if (ids.empty()) return;
int64_t leader_id = ids[0];
for (int64_t i = 1; i < ids.size(); ++i) {
groups.merge_groups(leader_id, ids[i]);
}
};
if (instr->opcode() == HloOpcode::kSend ||
instr->opcode() == HloOpcode::kRecv) {
auto* sr = Cast<HloSendRecvInstruction>(instr);
CHECK(!sr->is_host_transfer());
std::optional<std::string> source_target_pairs_str =
sr->frontend_attributes().map().at(kSendRecvSourceTargetPairsAttr);
CHECK(source_target_pairs_str.has_value());
absl::StatusOr<std::vector<ReplicaGroup>> source_target_pairs =
ParseReplicaGroupsOnly(*source_target_pairs_str);
CHECK(source_target_pairs.ok() && "Expect valid source_target_pairs");
for (auto& replica_group : *source_target_pairs) {
add_replica_group(replica_group);
}
return;
}

// Convert normal replica groups to abstract replica groups.
for (auto& replica_group : GetCollectiveReplicaGroups(instr)) {
add_replica_group(replica_group);
}
}
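// For example: a collective-permute with source-target pairs {0->1, 1->2}
// abstracts to the single group {0, 1, 2}; a device send/recv carrying
// _xla_send_recv_source_target_pairs={{3,0}} abstracts to {0, 3}; and an
// all-gather with replica_groups={{0,1,2,3}} abstracts to {0, 1, 2, 3}.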

static std::vector<HloInstruction*> FindAllConflictingCollectives(
const HloComputation* computation,
std::vector<HloInstruction*>& seed_collectives) {
absl::flat_hash_set<HloInstruction*> seen;

// Get the supremum of all abstract replica groups of the seed collectives
// we're starting with.
AbstractReplicaGroups abstract_replica_groups_supremum;
for (HloInstruction* instr : seed_collectives) {
GetAbstractReplicaGroups(instr, abstract_replica_groups_supremum);
seen.insert(instr);
}

// Try finding more and more conflicting collectives until we reach a
// fixpoint. This is needed because we may get a coarser supremum with each
// new conflicting collective.
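// For example: with a seed send/recv chain over devices {0, 1}, an all-reduce
// over replica_groups={{0,1,2}} conflicts immediately and coarsens the
// supremum to {0, 1, 2}; an all-gather over replica_groups={{1,2}} shares only
// one id with the original {0, 1} and is not conflicting at first, but it does
// conflict with the coarsened {0, 1, 2} and is therefore found in a later
// iteration.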
std::vector<HloInstruction*> conflicing_collectives;
bool fixpoint_reached;
do {
fixpoint_reached = true;

// Look at each collective in the computation.
for (HloInstruction* instr : computation->MakeInstructionPostOrder()) {
// Skip if not a collective or already considered for the supremum.
if (!IsNonFusionCollective(instr) || seen.contains(instr)) continue;

// Check whether this collective conflicts with the current coarsest
// abstract replica groups. If it does, add it to the conflicting
// collectives and update the supremum.
AbstractReplicaGroups groups;
GetAbstractReplicaGroups(instr, groups);
if (IsConflictingAbstractReplicaGroups(
groups, abstract_replica_groups_supremum)) {
conflicing_collectives.push_back(instr);
GetAbstractReplicaGroups(instr, abstract_replica_groups_supremum);
seen.insert(instr);
fixpoint_reached = false;
}
}
} while (!fixpoint_reached);

return conflicing_collectives;
}

static std::vector<HloInstruction*> FindAllConflictingCollectives(
HloComputation* computation,
const std::vector<HloCollectivePermuteInstruction*>& cps) {
std::vector<HloInstruction*> seed_collectives;
for (HloCollectivePermuteInstruction* cp : cps) {
seed_collectives.push_back(static_cast<HloInstruction*>(cp));
}
return FindAllConflictingCollectives(computation, seed_collectives);
}

// Inserts control dependencies to enforce send/recv chain order.
// The order protects from a potential deadlock when every device tries to
// execute recv with no devices executing send - if there are no constraints,
// the scheduler is free to schedule all recv ops first.
// deco_post_order is expected to be post order within a computation.
// TODO b/388072780 add second heuristic to enforce back edge before the forward
// edge for max performance.
// TODO(b/392684119): Also add control dependencies to conflicting collectives
// other than send/recv.
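// For two consecutive decomposed collective-permutes A and B (in post order),
// this adds the control edges A.send -> B.recv and A.send-done -> B.recv-done.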
absl::Status EnforceOrderOfSendRecvChains(
static absl::Status EnforceOrderOfSendRecvChain(
std::vector<DecomposedCp>& deco_post_order) {
for (size_t i = 1; i < deco_post_order.size(); ++i) {
DecomposedCp& cur = deco_post_order[i];
DecomposedCp& prev = deco_post_order[i - 1];
TF_RETURN_IF_ERROR(prev.send->AddControlDependencyTo(cur.recv));
TF_RETURN_IF_ERROR(prev.send_done->AddControlDependencyTo(cur.recv_done));
}
return absl::OkStatus();
}

} // namespace
static absl::Status EnforceOrderOfSendRecvChainRelativeToConflictingCollectives(
std::vector<DecomposedCp>& deco_post_order,
std::vector<HloInstruction*> conflicting_collectives) {
// Find last collective in send/recv chain.
if (deco_post_order.empty()) return absl::OkStatus();
HloInstruction* last_in_chain = deco_post_order.back().send_done;

// Add control dependencies from chain to all conflicting collectives.
for (HloInstruction* instr : conflicting_collectives) {
TF_RETURN_IF_ERROR(last_in_chain->AddControlDependencyTo(instr));
}

return absl::OkStatus();
}

absl::StatusOr<bool> CollectivePermuteDecomposer::Run(
HloModule* module,
@@ -262,6 +442,7 @@ absl::StatusOr<bool> CollectivePermuteDecomposer::Run(
while_bodies.insert(instr->while_body());
continue;
}

if (instr->opcode() != HloOpcode::kCollectivePermute) {
continue;
}
@@ -303,6 +484,13 @@ absl::StatusOr<bool> CollectivePermuteDecomposer::Run(
}
} // for MakeInstructionPostOrder

// Find all collectives conflicting with the collective permutes that we
// want to decompose. This is needed to add control dependencies to these
// conflicting collectives so that they cannot move in between the
// decomposed send/recv, which would lead to deadlocks.
std::vector<HloInstruction*> conflicing_collectives =
FindAllConflictingCollectives(computation, cps_to_decompose);

// cps to decompose were collected post order, similarly we will collect
// the decomposed send/recv pairs.
std::vector<DecomposedCp> deco_post_order;
@@ -321,7 +509,16 @@ absl::StatusOr<bool> CollectivePermuteDecomposer::Run(
DecomposeCollectivePermute(cp, computation, pipeline_decision));
deco_post_order.push_back(decomposed_ops);
}
TF_RETURN_IF_ERROR(EnforceOrderOfSendRecvChains(deco_post_order));

// Enforce order of send/recv pairs at the beginning of the loop body. Also
// enforce all other conflicting collectives to follow the send/recv chain
// so that these cannot be scheduled in between the send/recv, which would
// also lead to deadlocks.
TF_RETURN_IF_ERROR(EnforceOrderOfSendRecvChain(deco_post_order));
TF_RETURN_IF_ERROR(
EnforceOrderOfSendRecvChainRelativeToConflictingCollectives(
deco_post_order, conflicing_collectives));

if (!cps_to_decompose.empty()) {
changed = true;
}
(The diff of the fourth changed file did not load and is not shown here.)
