Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[XLA:LatencyHidingScheduler] Let GetResourcesFromInstruction return a complete list of resources used by instructions in a while loop. This will make async done and while ops have similar priority (in terms of occupying resource types) and avoid delaying the while loops only because they cross the overlap limit (even though they have a higher async depth). #22119

Merged
merged 1 commit into from
Feb 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 57 additions & 0 deletions xla/service/latency_hiding_scheduler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,34 @@ ResourcesVector AsyncTracker::GetResourcesFromInstructionImpl(
: std::make_pair(ResourceTypeToIndex(ResourceType::kSendRecv),
ResourceUsageType::kResourceOccupy)};
}
// For a while op, report the union of resources used by the instructions in
// its body, so the while competes for the same resource types as the async
// ops it contains. Each (resource, usage) pair is emitted at most once: there
// is no proof that two same-resource async ops inside the loop overlap each
// other, so de-duplicating avoids inflating the while's resource footprint
// past the overlap limits.
// NOTE(review): only while_body() is scanned — assumes the condition
// computation contains no resource-using ops; confirm.
case HloOpcode::kWhile: {
ResourcesVector result;
// One "already emitted" set per usage type, keyed by resource index.
absl::flat_hash_set<int64_t> seen_occupied_resources;
absl::flat_hash_set<int64_t> seen_released_resources;
absl::flat_hash_set<int64_t> seen_no_resource;
for (const HloInstruction* instr : hlo.while_body()->instructions()) {
// Recursive: a nested while contributes its own aggregated resource list
// via this same case.
ResourcesVector rv = GetResourcesFromInstructionImpl(*instr);
if (rv.empty()) {
continue;
}
for (const auto& [resource, usage] : rv) {
// Emit each (resource, usage) combination only on first sighting.
if (usage == ResourceUsageType::kResourceOccupy &&
!seen_occupied_resources.contains(resource)) {
seen_occupied_resources.insert(resource);
result.push_back(std::make_pair(resource, usage));
} else if (usage == ResourceUsageType::kResourceRelease &&
!seen_released_resources.contains(resource)) {
seen_released_resources.insert(resource);
result.push_back(std::make_pair(resource, usage));
} else if (usage == ResourceUsageType::kNoResource &&
!seen_no_resource.contains(resource)) {
seen_no_resource.insert(resource);
result.push_back(std::make_pair(resource, usage));
}
}
}
return result;
}
default:
return ResourcesVector{};
}
Expand All @@ -361,6 +389,17 @@ int64_t AsyncTracker::GetNumResourcesPerInstruction(
instr);
}

// Returns the number of "occupy"-type resources used by the instructions in
// the given computation. If multiple instructions in the computation have
// exactly the same resource usage, only one of them is counted. For example,
// if a while loop contains two async all-gathers, the map returned for the
// while instruction will report 1 for all-gather. This is because there is no
// proof that those all-gathers will overlap each other, and over-counting
// such resources would cause the while op not to be scheduled due to the
// resource limits (checked in scheduling_node_crosses_overlap_limit).
//
// If a single instruction uses multiple instances of the same "occupy"-type
// resource, that count is respected and returned in the resulting map.
const absl::flat_hash_map<int64_t, int64_t>&
AsyncTracker::RecursivelyComputeResourceMap(
const HloComputation* computation) const {
Expand All @@ -370,18 +409,30 @@ AsyncTracker::RecursivelyComputeResourceMap(
}
per_opcode_map = std::make_unique<absl::flat_hash_map<int64_t, int64_t>>();
auto* m = per_opcode_map.get();
absl::flat_hash_set<int64_t> seen_resources_per_comp;
for (HloInstruction* instr : computation->instructions()) {
absl::flat_hash_set<int64_t> seen_resources_per_inst;
if (IsSupportedAsyncDone(*instr)) {
for (const auto& resource : GetResourcesFromInstruction(*instr)) {
if (seen_resources_per_comp.contains(resource.first)) {
continue;
}
++(*m)[resource.first];
seen_resources_per_inst.insert(resource.first);
}
}
for (const HloComputation* called_comp : instr->called_computations()) {
for (auto& called_per_opcode_pair :
RecursivelyComputeResourceMap(called_comp)) {
if (seen_resources_per_comp.contains(called_per_opcode_pair.first)) {
continue;
}
(*m)[called_per_opcode_pair.first] += called_per_opcode_pair.second;
seen_resources_per_inst.insert(called_per_opcode_pair.first);
}
}
seen_resources_per_comp.insert(seen_resources_per_inst.begin(),
seen_resources_per_inst.end());
}
return *m;
}
Expand All @@ -406,6 +457,9 @@ int64_t AsyncTracker::GetNumResourcesPerInstruction(
auto opcode_it = map.find(resource_type);
if (opcode_it != map.end()) {
num_resources += opcode_it->second;
// We can return early if we have found the resource we are looking for.
// There is no need to check each called computation.
break;
}
}
return num_resources;
Expand Down Expand Up @@ -1829,6 +1883,9 @@ absl::StatusOr<HloGraphNode::TimeCost> DefaultSchedulerCore::ScheduleNode(
for (HloEdge& edge : n->GetPredecessors()) {
const int64_t current_outdegree = edge.Target().GetOutdegree();
// Node is not ready yet. Decrease the outdegree and continue.
if (n->GetInstr().name() == "gte1.1")
LOG(INFO) << "SEHER edge: " << edge.Target().GetInstr().name()
<< " outdegree: " << current_outdegree;
if (current_outdegree != 1) {
edge.Target().SetOutdegree(current_outdegree - 1);
continue;
Expand Down
83 changes: 76 additions & 7 deletions xla/service/latency_hiding_scheduler_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2138,7 +2138,7 @@ ENTRY entry {
HloSchedule& module_schedule = hlo_module->schedule();
EXPECT_TRUE(hlo_module->has_entry_computation());
auto sched_config = GetDefaultSchedConfig();
sched_config.collective_permute_overlap_limit = 2;
sched_config.collective_permute_overlap_limit = 1;
TF_EXPECT_OK(RunScheduler(hlo_module.get(), sched_config));
EXPECT_TRUE(hlo_module->has_entry_computation());

Expand Down Expand Up @@ -2210,7 +2210,7 @@ ENTRY entry {
HloSchedule& module_schedule = hlo_module->schedule();
EXPECT_TRUE(hlo_module->has_entry_computation());
auto sched_config = GetDefaultSchedConfig();
sched_config.collective_permute_overlap_limit = 2;
sched_config.collective_permute_overlap_limit = 1;
TF_EXPECT_OK(RunScheduler(hlo_module.get(), sched_config));
EXPECT_TRUE(hlo_module->has_entry_computation());

Expand All @@ -2222,8 +2222,8 @@ ENTRY entry {
}
}

// Do not overlap if the sum of collectives inside the loop + the collective
// we are trying to overlap would go beyond the overlap limit.
// Since there is at least one collective permute in the while op, overlapping
// it with the outer collective permute is not possible for the limit of 1.
EXPECT_GT(GetIndex(new_instruction_sequence, "collective-permute-start.2"),
GetIndex(new_instruction_sequence, "while"));
}
Expand Down Expand Up @@ -2267,7 +2267,7 @@ ENTRY entry {
HloSchedule& module_schedule = hlo_module->schedule();
EXPECT_TRUE(hlo_module->has_entry_computation());
auto sched_config = GetDefaultSchedConfig();
sched_config.collective_permute_overlap_limit = 3;
sched_config.collective_permute_overlap_limit = 2;
TF_EXPECT_OK(RunScheduler(hlo_module.get(), sched_config));
EXPECT_TRUE(hlo_module->has_entry_computation());

Expand All @@ -2279,8 +2279,11 @@ ENTRY entry {
}
}

// Do not overlap if the sum of collectives inside the loop + the collective
// we are trying to overlap would go beyond the overlap limit.
// This is optimistic in the sense that the inner collective permute will not
// overlap the while op in its own computation, so the maximum number of
// concurrent collective permutes inside the while is still 1. Therefore, the
// outer collective permute can overlap the while op while still obeying the
// limit of 2.
EXPECT_LT(GetIndex(new_instruction_sequence, "collective-permute-start.2"),
GetIndex(new_instruction_sequence, "while"));
}
Expand Down Expand Up @@ -4036,4 +4039,70 @@ TEST_F(LatencyHidingSchedulerTest, InvalidAnnotationOverlap) {
GetIndex(new_instruction_sequence, "agd0")));
}

// Verifies that a while op is assigned the complete list of resources used by
// the instructions in its body, so that it has a priority similar to async
// done ops occupying the same resource types and is not needlessly delayed by
// the overlap limit.
TEST_F(LatencyHidingSchedulerTest, WhileWithCompleteResourceList) {
  constexpr absl::string_view hlo_string = R"(
HloModule module, is_scheduled=true
while_cond {
param = (f32[16,64,256]{2,1,0}, f32[16,64,256]{2,1,0}, pred[]) parameter(0)
ROOT gte = pred[] get-tuple-element(param), index=2
}
while_body {
param = (f32[16,64,256]{2,1,0}, f32[16,64,256]{2,1,0}, pred[]) parameter(0)
gte0 = f32[16,64,256]{2,1,0} get-tuple-element(param), index=0
gte1 = f32[16,64,256]{2,1,0} get-tuple-element(param), index=1
gte2 = pred[] get-tuple-element(param), index=2
cps0 = (f32[16,64,256]{2,1,0}, f32[16,64,256]{2,1,0}, u32[], u32[]) collective-permute-start(gte1), source_target_pairs={{0,1},{1,2},{2,3},{3,0}}
cpd0 = f32[16,64,256]{2,1,0} collective-permute-done(cps0)
c = f32[16,256,256]{2,1,0} convolution(gte0, gte0), window={size=16 stride=15 lhs_dilate=16}, dim_labels=0fb_0io->0fb
slice = f32[16,64,256]{2,1,0} slice(c), slice={[0:16], [0:64], [0:256]}
add = f32[16,64,256]{2,1,0} add(gte0, slice)
ROOT tuple = (f32[16,64,256]{2,1,0}, f32[16,64,256]{2,1,0}, pred[]) tuple(add, cpd0, gte2)
}
ENTRY entry {
p0 = f32[64,1024]{1,0} parameter(0)
p1 = f32[16,64,256]{2,1,0} parameter(1)
p2 = f32[16,64,256]{2,1,0} parameter(2)
p3 = pred[] parameter(3)
cps1 = (f32[16,64,256]{2,1,0}, f32[16,64,256]{2,1,0}) collective-permute-start(p1), source_target_pairs={{0,1},{1,2},{2,3},{3,0}}
cpd1 = f32[16,64,256]{2,1,0} collective-permute-done(cps1)
cps2 = (f32[64,1024]{1,0}, f32[64,1024]{1,0}) collective-permute-start(p0), source_target_pairs={{0,1},{1,2},{2,3},{3,0}}
tuple = (f32[16,64,256]{2,1,0}, f32[16,64,256]{2,1,0}, pred[]) tuple(cpd1, p2, p3)
while = (f32[16,64,256]{2,1,0}, f32[16,64,256]{2,1,0}, pred[]) while(tuple), condition=while_cond, body=while_body
cpd2 = f32[64,1024]{1,0} collective-permute-done(cps2)
gte = f32[16,64,256]{2,1,0} get-tuple-element(while), index=0
ROOT tuple1 = (f32[16,64,256]{2,1,0}, f32[64,1024]{1,0}) tuple(gte, cpd2)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseHloText(hlo_string));
  HloSchedule& module_schedule = hlo_module->schedule();
  EXPECT_TRUE(hlo_module->has_entry_computation());
  auto sched_config = GetDefaultSchedConfig();
  sched_config.aggressive_scheduling_policies = true;
  sched_config.collective_permute_overlap_limit = 1;
  // Use TF_EXPECT_OK for consistency with the other tests in this file.
  TF_EXPECT_OK(RunScheduler(hlo_module.get(), sched_config,
                            std::make_unique<TestLatencyEstimator>()));
  EXPECT_TRUE(hlo_module->has_entry_computation());

  std::vector<HloInstruction*> new_instruction_sequence =
      module_schedule.sequence(hlo_module->entry_computation()).instructions();
  if (VLOG_IS_ON(1)) {
    for (auto* new_i : new_instruction_sequence) {
      VLOG(1) << new_i->ToString();
    }
  }
  // Without the complete resources assigned to the while, cpd2 would be
  // prioritized (and hence scheduled after the while) even though the while
  // has a higher async depth. With the complete resources assigned, the while
  // has a priority similar to cpd2's in terms of the kScheduleDone rule, so
  // the kAsyncDepth rule prioritizes scheduling the while. This prevents
  // needless delaying of blocker while ops and hence helps reduce the live
  // ranges of their data-dependent instructions.
  EXPECT_LT(GetIndex(new_instruction_sequence, "cpd2"),
            GetIndex(new_instruction_sequence, "while"));
}

} // namespace xla
Loading