[XLA:LatencyHidingScheduler] Let GetResourcesFromInstruction return a complete list of resources used by instructions in a while loop. This will make async `done` and while ops have similar priority (in terms of occupying resource types) and avoid delaying the while loops only because they cross the overlap limit.

This CL also fixes the double counting of a resource in `GetNumResourcesPerInstruction` because of multiple async `done` ops in the while body.
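As a rough, standalone illustration of the first change (a sketch only; the names below are made up and are not XLA's API): the while op now reports each (resource, usage) pair contributed by its body at most once, so several async `done` ops of the same collective type make the loop occupy that resource type once rather than once per op.

// Minimal standalone model of the per-while resource aggregation; Resource,
// Usage, and AggregateBodyResources are illustrative stand-ins, not XLA types.
#include <cstdint>
#include <iostream>
#include <set>
#include <utility>
#include <vector>

enum class Usage { kNoResource, kOccupy, kRelease };
using Resource = int64_t;
using ResourcesVector = std::vector<std::pair<Resource, Usage>>;

// Collects the resources used by a while body, keeping each (resource, usage)
// pair at most once.
ResourcesVector AggregateBodyResources(
    const std::vector<ResourcesVector>& per_instruction_resources) {
  ResourcesVector result;
  std::set<std::pair<Resource, Usage>> seen;
  for (const ResourcesVector& rv : per_instruction_resources) {
    for (const auto& pair : rv) {
      if (seen.insert(pair).second) {
        result.push_back(pair);
      }
    }
  }
  return result;
}

int main() {
  const Resource kCollectivePermute = 0;
  // Two collective-permute `done` ops in the body occupy the same resource
  // type; the while op should report it once, not twice.
  const std::vector<ResourcesVector> body = {
      {{kCollectivePermute, Usage::kOccupy}},
      {{kCollectivePermute, Usage::kOccupy}}};
  std::cout << AggregateBodyResources(body).size() << "\n";  // prints 1
  return 0;
}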

PiperOrigin-RevId: 721121051
seherellis authored and Google-ML-Automation committed Jan 30, 2025
1 parent f95ad4d commit c8675f8
Showing 2 changed files with 103 additions and 0 deletions.
37 changes: 37 additions & 0 deletions xla/service/latency_hiding_scheduler.cc
@@ -341,6 +341,34 @@ ResourcesVector AsyncTracker::GetResourcesFromInstructionImpl(
                  : std::make_pair(ResourceTypeToIndex(ResourceType::kSendRecv),
                                   ResourceUsageType::kResourceOccupy)};
    }
    case HloOpcode::kWhile: {
      ResourcesVector result;
      absl::flat_hash_set<int64_t> seen_occupied_resources;
      absl::flat_hash_set<int64_t> seen_released_resources;
      absl::flat_hash_set<int64_t> seen_no_resource;
      for (const HloInstruction* instr : hlo.while_body()->instructions()) {
        ResourcesVector rv = GetResourcesFromInstructionImpl(*instr);
        if (rv.empty()) {
          continue;
        }
        for (const auto& [resource, usage] : rv) {
          if (usage == ResourceUsageType::kResourceOccupy &&
              !seen_occupied_resources.contains(resource)) {
            seen_occupied_resources.insert(resource);
            result.push_back(std::make_pair(resource, usage));
          } else if (usage == ResourceUsageType::kResourceRelease &&
                     !seen_released_resources.contains(resource)) {
            seen_released_resources.insert(resource);
            result.push_back(std::make_pair(resource, usage));
          } else if (usage == ResourceUsageType::kNoResource &&
                     !seen_no_resource.contains(resource)) {
            seen_no_resource.insert(resource);
            result.push_back(std::make_pair(resource, usage));
          }
        }
      }
      return result;
    }
    default:
      return ResourcesVector{};
  }
@@ -370,10 +398,19 @@ AsyncTracker::RecursivelyComputeResourceMap(
  }
  per_opcode_map = std::make_unique<absl::flat_hash_map<int64_t, int64_t>>();
  auto* m = per_opcode_map.get();
  absl::flat_hash_set<int64_t> seen_resources_per_comp;
  for (HloInstruction* instr : computation->instructions()) {
    absl::flat_hash_set<int64_t> seen_resources_per_inst;
    if (IsSupportedAsyncDone(*instr)) {
      for (const auto& resource : GetResourcesFromInstruction(*instr)) {
        if (seen_resources_per_comp.contains(resource.first)) {
          continue;
        }
        ++(*m)[resource.first];
        seen_resources_per_inst.insert(resource.first);
      }
      for (const auto& resource : seen_resources_per_inst) {
        seen_resources_per_comp.insert(resource);
      }
    }
    for (const HloComputation* called_comp : instr->called_computations()) {
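For the double-counting fix above, a similarly hedged standalone sketch (again with made-up names, not XLA's API): when the resource map behind GetNumResourcesPerInstruction is built, a resource occupied by several async `done` ops in the same computation is now counted once instead of once per op.

// Minimal standalone model of the per-computation dedup; Resource and
// CountResources are illustrative stand-ins, not XLA types.
#include <cstdint>
#include <iostream>
#include <map>
#include <set>
#include <vector>

using Resource = int64_t;

// Counts each resource at most once per computation, even if several async
// `done` ops in the computation occupy the same resource type.
std::map<Resource, int64_t> CountResources(
    const std::vector<std::vector<Resource>>& per_done_op_resources) {
  std::map<Resource, int64_t> counts;
  std::set<Resource> seen_in_computation;
  for (const std::vector<Resource>& resources : per_done_op_resources) {
    for (Resource r : resources) {
      if (seen_in_computation.insert(r).second) {
        ++counts[r];
      }
    }
  }
  return counts;
}

int main() {
  const Resource kCollectivePermute = 0;
  // Two collective-permute `done` ops in a while body used to be counted as
  // two occupiers of the resource; with the fix they count as one.
  const auto counts =
      CountResources({{kCollectivePermute}, {kCollectivePermute}});
  std::cout << counts.at(kCollectivePermute) << "\n";  // prints 1
  return 0;
}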
66 changes: 66 additions & 0 deletions xla/service/latency_hiding_scheduler_test.cc
@@ -3970,4 +3970,70 @@ TEST_F(LatencyHidingSchedulerTest, RaggedAllToAll) {
            GetIndex(new_instruction_sequence, "ra2a-done"));
}

TEST_F(LatencyHidingSchedulerTest, WhileWithCompleteResourceList) {
  constexpr absl::string_view hlo_string = R"(
HloModule module, is_scheduled=true

while_cond {
  param = (f32[16,64,256]{2,1,0}, f32[16,64,256]{2,1,0}, pred[]) parameter(0)
  ROOT gte = pred[] get-tuple-element(param), index=2
}

while_body {
  param = (f32[16,64,256]{2,1,0}, f32[16,64,256]{2,1,0}, pred[]) parameter(0)
  gte0 = f32[16,64,256]{2,1,0} get-tuple-element(param), index=0
  gte1 = f32[16,64,256]{2,1,0} get-tuple-element(param), index=1
  gte2 = pred[] get-tuple-element(param), index=2
  cps0 = (f32[16,64,256]{2,1,0}, f32[16,64,256]{2,1,0}, u32[], u32[]) collective-permute-start(gte1), source_target_pairs={{0,1},{1,2},{2,3},{3,0}}
  cpd0 = f32[16,64,256]{2,1,0} collective-permute-done(cps0)
  c = f32[16,256,256]{2,1,0} convolution(gte0, gte0), window={size=16 stride=15 lhs_dilate=16}, dim_labels=0fb_0io->0fb
  slice = f32[16,64,256]{2,1,0} slice(c), slice={[0:16], [0:64], [0:256]}
  add = f32[16,64,256]{2,1,0} add(gte0, slice)
  ROOT tuple = (f32[16,64,256]{2,1,0}, f32[16,64,256]{2,1,0}, pred[]) tuple(add, cpd0, gte2)
}

ENTRY entry {
  p0 = f32[64,1024]{1,0} parameter(0)
  p1 = f32[16,64,256]{2,1,0} parameter(1)
  p2 = f32[16,64,256]{2,1,0} parameter(2)
  p3 = pred[] parameter(3)
  cps1 = (f32[16,64,256]{2,1,0}, f32[16,64,256]{2,1,0}) collective-permute-start(p1), source_target_pairs={{0,1},{1,2},{2,3},{3,0}}
  cpd1 = f32[16,64,256]{2,1,0} collective-permute-done(cps1)
  cps2 = (f32[64,1024]{1,0}, f32[64,1024]{1,0}) collective-permute-start(p0), source_target_pairs={{0,1},{1,2},{2,3},{3,0}}
  tuple = (f32[16,64,256]{2,1,0}, f32[16,64,256]{2,1,0}, pred[]) tuple(cpd1, p2, p3)
  while = (f32[16,64,256]{2,1,0}, f32[16,64,256]{2,1,0}, pred[]) while(tuple), condition=while_cond, body=while_body
  cpd2 = f32[64,1024]{1,0} collective-permute-done(cps2)
  gte = f32[16,64,256]{2,1,0} get-tuple-element(while), index=0
  ROOT tuple1 = (f32[16,64,256]{2,1,0}, f32[64,1024]{1,0}) tuple(gte, cpd2)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseHloText(hlo_string));
  HloSchedule& module_schedule = hlo_module->schedule();
  EXPECT_TRUE(hlo_module->has_entry_computation());
  auto sched_config = GetDefaultSchedConfig();
  sched_config.aggressive_scheduling_policies = true;
  sched_config.collective_permute_overlap_limit = 1;
  EXPECT_TRUE(RunScheduler(hlo_module.get(), sched_config,
                           std::make_unique<TestLatencyEstimator>())
                  .ok());
  EXPECT_TRUE(hlo_module->has_entry_computation());

  std::vector<HloInstruction*> new_instruction_sequence =
      module_schedule.sequence(hlo_module->entry_computation()).instructions();
  if (VLOG_IS_ON(1)) {
    for (auto* new_i : new_instruction_sequence) {
      VLOG(1) << new_i->ToString();
    }
  }
  // Without the complete resource list assigned to the while, cpd2 would be
  // prioritized (and hence scheduled after the while) even though the while
  // has a higher async depth. With the complete resource list, the while has
  // a priority similar to cpd2's in terms of the kScheduleDone rule, so the
  // kAsyncDepth rule gets to prioritize scheduling the while. This prevents
  // needlessly delaying blocking while ops and hence helps reduce the live
  // ranges of their data-dependent instructions.
  EXPECT_LT(GetIndex(new_instruction_sequence, "cpd2"),
            GetIndex(new_instruction_sequence, "while"));
}
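The comment in the test describes a tie-break: once the while also occupies the collective-permute resource, the kScheduleDone-style rule no longer separates it from cpd2, and the async-depth rule decides. Below is a toy model of that ordering, under the simplifying assumption of a two-level tie-break; Candidate and PickPreferred are hypothetical and are not the scheduler's real rule list.

// Toy two-level tie-break; a hypothetical simplification, not the actual
// LatencyHidingScheduler candidate comparison.
#include <iostream>
#include <string>

struct Candidate {
  std::string name;
  bool occupies_async_resource;  // e.g. an async done op, or now also a while
  int async_depth;               // larger when more async work sits behind it
};

// Returns the candidate the tie-break prefers in this simplified model.
const Candidate& PickPreferred(const Candidate& a, const Candidate& b) {
  // Level 1 (kScheduleDone-like): prefer an op that occupies an async
  // resource. With the complete resource list, both cpd2 and the while do,
  // so this level no longer separates them.
  if (a.occupies_async_resource != b.occupies_async_resource) {
    return a.occupies_async_resource ? a : b;
  }
  // Level 2 (kAsyncDepth-like): prefer the higher async depth.
  return a.async_depth >= b.async_depth ? a : b;
}

int main() {
  const Candidate cpd2{"cpd2", /*occupies_async_resource=*/true,
                       /*async_depth=*/0};
  const Candidate while_op{"while", /*occupies_async_resource=*/true,
                           /*async_depth=*/1};
  std::cout << PickPreferred(cpd2, while_op).name << "\n";  // prints "while"
  return 0;
}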

} // namespace xla
