Skip to content

Commit

Permalink
[XLA:LatencyHidingScheduler] Let GetResourcesFromInstruction return…
Browse files Browse the repository at this point in the history
… a complete list of resources used by instructions in a while loop. This makes async `done` ops and while ops have similar priority (in terms of the resource types they occupy) and avoids delaying while loops merely because they cross the overlap limit, even though they have a higher async depth.

This CL also fixes the double counting of a resource in `GetNumResourcesPerInstruction` because of multiple async `done` ops in the while body.

PiperOrigin-RevId: 721121051
  • Loading branch information
seherellis authored and Google-ML-Automation committed Feb 7, 2025
1 parent 5783c91 commit 2a69876
Show file tree
Hide file tree
Showing 2 changed files with 133 additions and 7 deletions.
57 changes: 57 additions & 0 deletions xla/service/latency_hiding_scheduler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,34 @@ ResourcesVector AsyncTracker::GetResourcesFromInstructionImpl(
: std::make_pair(ResourceTypeToIndex(ResourceType::kSendRecv),
ResourceUsageType::kResourceOccupy)};
}
case HloOpcode::kWhile: {
ResourcesVector result;
absl::flat_hash_set<int64_t> seen_occupied_resources;
absl::flat_hash_set<int64_t> seen_released_resources;
absl::flat_hash_set<int64_t> seen_no_resource;
for (const HloInstruction* instr : hlo.while_body()->instructions()) {
ResourcesVector rv = GetResourcesFromInstructionImpl(*instr);
if (rv.empty()) {
continue;
}
for (const auto& [resource, usage] : rv) {
if (usage == ResourceUsageType::kResourceOccupy &&
!seen_occupied_resources.contains(resource)) {
seen_occupied_resources.insert(resource);
result.push_back(std::make_pair(resource, usage));
} else if (usage == ResourceUsageType::kResourceRelease &&
!seen_released_resources.contains(resource)) {
seen_released_resources.insert(resource);
result.push_back(std::make_pair(resource, usage));
} else if (usage == ResourceUsageType::kNoResource &&
!seen_no_resource.contains(resource)) {
seen_no_resource.insert(resource);
result.push_back(std::make_pair(resource, usage));
}
}
}
return result;
}
default:
return ResourcesVector{};
}
Expand All @@ -361,6 +389,17 @@ int64_t AsyncTracker::GetNumResourcesPerInstruction(
instr);
}

// Returns the number of "occupy"-type resources used by the instructions in
// the given computation. If multiple instructions in the computation have the
// exact same resource usage, only one of them is counted. For example, if a
// while loop contains two async all-gathers, the returned map will have a
// count of 1 for all-gather for the while instruction. This is because there
// is no proof that those all-gathers will overlap each other, and over-
// counting such resources causes the while op to not be scheduled due to the
// resource limits (checked in scheduling_node_crosses_overlap_limit).
//
// If an instruction uses multiple instances of the same "occupy"-type
// resource, that number is respected and returned in the resulting map.
const absl::flat_hash_map<int64_t, int64_t>&
AsyncTracker::RecursivelyComputeResourceMap(
const HloComputation* computation) const {
Expand All @@ -370,18 +409,30 @@ AsyncTracker::RecursivelyComputeResourceMap(
}
per_opcode_map = std::make_unique<absl::flat_hash_map<int64_t, int64_t>>();
auto* m = per_opcode_map.get();
absl::flat_hash_set<int64_t> seen_resources_per_comp;
for (HloInstruction* instr : computation->instructions()) {
absl::flat_hash_set<int64_t> seen_resources_per_inst;
if (IsSupportedAsyncDone(*instr)) {
for (const auto& resource : GetResourcesFromInstruction(*instr)) {
if (seen_resources_per_comp.contains(resource.first)) {
continue;
}
++(*m)[resource.first];
seen_resources_per_inst.insert(resource.first);
}
}
for (const HloComputation* called_comp : instr->called_computations()) {
for (auto& called_per_opcode_pair :
RecursivelyComputeResourceMap(called_comp)) {
if (seen_resources_per_comp.contains(called_per_opcode_pair.first)) {
continue;
}
(*m)[called_per_opcode_pair.first] += called_per_opcode_pair.second;
seen_resources_per_inst.insert(called_per_opcode_pair.first);
}
}
seen_resources_per_comp.insert(seen_resources_per_inst.begin(),
seen_resources_per_inst.end());
}
return *m;
}
Expand All @@ -406,6 +457,9 @@ int64_t AsyncTracker::GetNumResourcesPerInstruction(
auto opcode_it = map.find(resource_type);
if (opcode_it != map.end()) {
num_resources += opcode_it->second;
// We can return early if we have found the resource we are looking for.
// There is no need to check each called computation.
break;
}
}
return num_resources;
Expand Down Expand Up @@ -1829,6 +1883,9 @@ absl::StatusOr<HloGraphNode::TimeCost> DefaultSchedulerCore::ScheduleNode(
for (HloEdge& edge : n->GetPredecessors()) {
const int64_t current_outdegree = edge.Target().GetOutdegree();
// Node is not ready yet. Decrease the outdegree and continue.
if (n->GetInstr().name() == "gte1.1")
LOG(INFO) << "SEHER edge: " << edge.Target().GetInstr().name()
<< " outdegree: " << current_outdegree;
if (current_outdegree != 1) {
edge.Target().SetOutdegree(current_outdegree - 1);
continue;
Expand Down
83 changes: 76 additions & 7 deletions xla/service/latency_hiding_scheduler_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2138,7 +2138,7 @@ ENTRY entry {
HloSchedule& module_schedule = hlo_module->schedule();
EXPECT_TRUE(hlo_module->has_entry_computation());
auto sched_config = GetDefaultSchedConfig();
sched_config.collective_permute_overlap_limit = 2;
sched_config.collective_permute_overlap_limit = 1;
TF_EXPECT_OK(RunScheduler(hlo_module.get(), sched_config));
EXPECT_TRUE(hlo_module->has_entry_computation());

Expand Down Expand Up @@ -2210,7 +2210,7 @@ ENTRY entry {
HloSchedule& module_schedule = hlo_module->schedule();
EXPECT_TRUE(hlo_module->has_entry_computation());
auto sched_config = GetDefaultSchedConfig();
sched_config.collective_permute_overlap_limit = 2;
sched_config.collective_permute_overlap_limit = 1;
TF_EXPECT_OK(RunScheduler(hlo_module.get(), sched_config));
EXPECT_TRUE(hlo_module->has_entry_computation());

Expand All @@ -2222,8 +2222,8 @@ ENTRY entry {
}
}

// Do not overlap if the sum of collectives inside the loop + the collective
// we are trying to overlap would go beyond the overlap limit.
// Since there is at least one collective permute in the while op, overlapping
// it with the outer collective permute is not possible for the limit of 1.
EXPECT_GT(GetIndex(new_instruction_sequence, "collective-permute-start.2"),
GetIndex(new_instruction_sequence, "while"));
}
Expand Down Expand Up @@ -2267,7 +2267,7 @@ ENTRY entry {
HloSchedule& module_schedule = hlo_module->schedule();
EXPECT_TRUE(hlo_module->has_entry_computation());
auto sched_config = GetDefaultSchedConfig();
sched_config.collective_permute_overlap_limit = 3;
sched_config.collective_permute_overlap_limit = 2;
TF_EXPECT_OK(RunScheduler(hlo_module.get(), sched_config));
EXPECT_TRUE(hlo_module->has_entry_computation());

Expand All @@ -2279,8 +2279,11 @@ ENTRY entry {
}
}

// Do not overlap if the sum of collectives inside the loop + the collective
// we are trying to overlap would go beyond the overlap limit.
// This is optimistic in the sense that the inner collective permute is not
// going to overlap the while in the same computation, so the maximum number
// of concurrent collective permutes inside while is still 1. Therefore, we
// can overlap the outer collective permute with the while op because we are
// still obeying the limit of 2.
EXPECT_LT(GetIndex(new_instruction_sequence, "collective-permute-start.2"),
GetIndex(new_instruction_sequence, "while"));
}
Expand Down Expand Up @@ -4036,4 +4039,70 @@ TEST_F(LatencyHidingSchedulerTest, InvalidAnnotationOverlap) {
GetIndex(new_instruction_sequence, "agd0")));
}

TEST_F(LatencyHidingSchedulerTest, WhileWithCompleteResourceList) {
  // A while loop whose body occupies a collective-permute resource, competing
  // with collective permutes in the entry computation under an overlap limit
  // of 1.
  constexpr absl::string_view kHloText = R"(
HloModule module, is_scheduled=true
while_cond {
param = (f32[16,64,256]{2,1,0}, f32[16,64,256]{2,1,0}, pred[]) parameter(0)
ROOT gte = pred[] get-tuple-element(param), index=2
}
while_body {
param = (f32[16,64,256]{2,1,0}, f32[16,64,256]{2,1,0}, pred[]) parameter(0)
gte0 = f32[16,64,256]{2,1,0} get-tuple-element(param), index=0
gte1 = f32[16,64,256]{2,1,0} get-tuple-element(param), index=1
gte2 = pred[] get-tuple-element(param), index=2
cps0 = (f32[16,64,256]{2,1,0}, f32[16,64,256]{2,1,0}, u32[], u32[]) collective-permute-start(gte1), source_target_pairs={{0,1},{1,2},{2,3},{3,0}}
cpd0 = f32[16,64,256]{2,1,0} collective-permute-done(cps0)
c = f32[16,256,256]{2,1,0} convolution(gte0, gte0), window={size=16 stride=15 lhs_dilate=16}, dim_labels=0fb_0io->0fb
slice = f32[16,64,256]{2,1,0} slice(c), slice={[0:16], [0:64], [0:256]}
add = f32[16,64,256]{2,1,0} add(gte0, slice)
ROOT tuple = (f32[16,64,256]{2,1,0}, f32[16,64,256]{2,1,0}, pred[]) tuple(add, cpd0, gte2)
}
ENTRY entry {
p0 = f32[64,1024]{1,0} parameter(0)
p1 = f32[16,64,256]{2,1,0} parameter(1)
p2 = f32[16,64,256]{2,1,0} parameter(2)
p3 = pred[] parameter(3)
cps1 = (f32[16,64,256]{2,1,0}, f32[16,64,256]{2,1,0}) collective-permute-start(p1), source_target_pairs={{0,1},{1,2},{2,3},{3,0}}
cpd1 = f32[16,64,256]{2,1,0} collective-permute-done(cps1)
cps2 = (f32[64,1024]{1,0}, f32[64,1024]{1,0}) collective-permute-start(p0), source_target_pairs={{0,1},{1,2},{2,3},{3,0}}
tuple = (f32[16,64,256]{2,1,0}, f32[16,64,256]{2,1,0}, pred[]) tuple(cpd1, p2, p3)
while = (f32[16,64,256]{2,1,0}, f32[16,64,256]{2,1,0}, pred[]) while(tuple), condition=while_cond, body=while_body
cpd2 = f32[64,1024]{1,0} collective-permute-done(cps2)
gte = f32[16,64,256]{2,1,0} get-tuple-element(while), index=0
ROOT tuple1 = (f32[16,64,256]{2,1,0}, f32[64,1024]{1,0}) tuple(gte, cpd2)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloText(kHloText));
  HloSchedule& schedule = module->schedule();
  EXPECT_TRUE(module->has_entry_computation());

  // Schedule with aggressive policies and a collective-permute overlap limit
  // of 1, using the test latency estimator.
  auto config = GetDefaultSchedConfig();
  config.aggressive_scheduling_policies = true;
  config.collective_permute_overlap_limit = 1;
  EXPECT_TRUE(RunScheduler(module.get(), config,
                           std::make_unique<TestLatencyEstimator>())
                  .ok());
  EXPECT_TRUE(module->has_entry_computation());

  std::vector<HloInstruction*> sequence =
      schedule.sequence(module->entry_computation()).instructions();
  if (VLOG_IS_ON(1)) {
    for (HloInstruction* instr : sequence) {
      VLOG(1) << instr->ToString();
    }
  }
  // Without the complete resource list assigned to the while op, cpd2 would be
  // prioritized (and hence scheduled after the while) even though the while
  // has a higher async depth. With the complete resources assigned, the while
  // has a similar priority to cpd2 under the kScheduleDone rule, so the
  // kAsyncDepth rule prioritizes scheduling the while. This prevents needless
  // delaying of blocker while ops and hence helps reduce the live ranges of
  // their data-dependent instructions.
  EXPECT_LT(GetIndex(sequence, "cpd2"), GetIndex(sequence, "while"));
}

} // namespace xla

0 comments on commit 2a69876

Please sign in to comment.