Skip to content

Commit

Permalink
[XLA] Allow synchronously invoking computations on separate threads
Browse files Browse the repository at this point in the history
If we allow asynchronously invoking a computation on a separate thread by wrapping the kCall op in a async-{start,done} pair, then we should also allow a synchronous version of kCall to do the same. Same logic applies kCustomCall.

PiperOrigin-RevId: 724086062
  • Loading branch information
vsytch authored and Google-ML-Automation committed Feb 7, 2025
1 parent d836134 commit 1606e5e
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 14 deletions.
4 changes: 4 additions & 0 deletions xla/hlo/analysis/hlo_dataflow_analysis.cc
Original file line number Diff line number Diff line change
Expand Up @@ -789,6 +789,10 @@ bool HloDataflowAnalysis::UpdateRecvDoneValueSet(HloInstruction* recv_done) {

bool HloDataflowAnalysis::UpdateCallValueSet(HloInstruction* call) {
CHECK_EQ(call->opcode(), HloOpcode::kCall);
if (!HloInstruction::IsThreadIncluded(call->to_apply()->execution_thread(),
execution_threads_)) {
return false;
}
InstructionValueSet& value_set = GetInstructionValueSet(call);
InstructionValueSet& root_value_set =
GetInstructionValueSet(call->to_apply()->root_instruction());
Expand Down
13 changes: 8 additions & 5 deletions xla/service/hlo_verifier.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2868,9 +2868,11 @@ class InstructionVerifier : public DfsHloVisitorWithDefault {
}

absl::Status HandleCall(HloInstruction* call) override {
// Allow kCall to contain computations on separate thread.
return CheckCallableInstructionThreadName(
call, /*skip_nested_async_op_check=*/true);
if (opts_.verify_call_nested_computation_thread_name) {
return CheckCallableInstructionThreadName(
call, /*skip_nested_async_op_check=*/true);
}
return absl::OkStatus();
}

absl::Status HandleConditional(HloInstruction* conditional) override {
Expand Down Expand Up @@ -2951,7 +2953,7 @@ class InstructionVerifier : public DfsHloVisitorWithDefault {
}

absl::Status HandleCustomCall(HloInstruction* hlo) override {
if (opts_.verify_custom_call_nested_computation_thread_name) {
if (opts_.verify_call_nested_computation_thread_name) {
// Allow kCustomCall to contain computations on separate thread.
return CheckCallableInstructionThreadName(
hlo, /*skip_nested_async_op_check=*/true);
Expand Down Expand Up @@ -2993,7 +2995,8 @@ class InstructionVerifier : public DfsHloVisitorWithDefault {
}
}

if (instruction->has_to_apply() &&
if (opts_.verify_call_nested_computation_thread_name &&
instruction->has_to_apply() &&
instruction->to_apply()->execution_thread() !=
instruction->parent()->execution_thread()) {
return Internal(
Expand Down
9 changes: 4 additions & 5 deletions xla/service/hlo_verifier.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,8 @@ struct HloVerifierOpts {
return std::move(*this);
}

HloVerifierOpts&& VerifyCustomCallNestedComputationThreadName() {
verify_custom_call_nested_computation_thread_name = true;
HloVerifierOpts&& VerifyCallNestedComputationThreadName() {
verify_call_nested_computation_thread_name = true;
return std::move(*this);
}

Expand Down Expand Up @@ -137,9 +137,8 @@ struct HloVerifierOpts {
// Check that reshape is a physical bitcast.
bool verify_reshape_is_bitcast = false;

// Check that custom call's called computations have same thread name as
// parent computation.
bool verify_custom_call_nested_computation_thread_name = true;
// Check that called computations have same thread name as parent computation.
bool verify_call_nested_computation_thread_name = false;

// Check device numbers in sharding verification.
bool verify_sharding_device_numbers = true;
Expand Down
16 changes: 12 additions & 4 deletions xla/service/hlo_verifier_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,10 @@ TEST_F(HloVerifierTest, CheckCallThreadMismatch) {
)";
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(hlo));

auto status = verifier().Run(module.get()).status();
auto status =
HloVerifier{HloVerifierOpts{}.VerifyCallNestedComputationThreadName()}
.Run(module.get())
.status();
ASSERT_FALSE(status.ok());
EXPECT_THAT(status.message(),
HasSubstr("mycall top_apply computation execution thread does "
Expand Down Expand Up @@ -2263,8 +2266,14 @@ TEST_F(HloVerifierTest, FusionNestedComputationThreadVerifier) {
)";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnUnverifiedModule(kModuleStr));

auto status =
HloVerifier{HloVerifierOpts{}.VerifyCallNestedComputationThreadName()}
.Run(module.get())
.status();
ASSERT_FALSE(status.ok());
EXPECT_THAT(
verifier().Run(module.get()).status().message(),
status.message(),
HasSubstr("crs0 top_apply computation execution thread does not match "
"(parallel_thread vs main)"));
}
Expand Down Expand Up @@ -3003,8 +3012,7 @@ TEST_F(HloVerifierTest, VerifyCustomCallThread) {

TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(hlo));
auto status =
HloVerifier{
HloVerifierOpts{}.VerifyCustomCallNestedComputationThreadName()}
HloVerifier{HloVerifierOpts{}.VerifyCallNestedComputationThreadName()}
.Run(module.get())
.status();
ASSERT_FALSE(status.ok());
Expand Down

0 comments on commit 1606e5e

Please sign in to comment.