Skip to content

Commit

Permalink
[MultiHostHloRunner] Fix the scope for GPURunnerProfiler
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 724123157
  • Loading branch information
juliagmt-google authored and Google-ML-Automation committed Feb 14, 2025
1 parent adc5799 commit 519055d
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 27 deletions.
2 changes: 2 additions & 0 deletions xla/tools/multihost_hlo_runner/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ cc_library(
"@tsl//tsl/platform:statusor",
] + if_cuda_or_rocm([
"//xla/service:gpu_plugin",
"//xla/backends/profiler/gpu:cupti_tracer",
"//xla/backends/profiler/gpu:device_tracer",
]) + if_cuda([
"//xla/stream_executor:cuda_platform",
] + if_google(
Expand Down
40 changes: 15 additions & 25 deletions xla/tools/multihost_hlo_runner/functional_hlo_runner_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -168,39 +168,29 @@ TEST_F(FunctionalHloRunnerTest, GPUProfilerKeepXSpaceReturnsNonNullXSpace) {

TEST_F(FunctionalHloRunnerTest,
SingleDeviceHloWithGPUProfilerSavesXSpaceToDisk) {
if (IsTestingCpu()) {
GTEST_SKIP() << "GPU-only test";
}

GpuClientOptions gpu_options;
gpu_options.node_id = 0;
gpu_options.num_nodes = 16;
gpu_options.enable_mock_nccl = true;

TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::PjRtClient> client,
GetPjRtClient());
std::string profile_dump_path =
tsl::io::JoinPath(testing::TempDir(), "xspace.pb");
tsl::Env* env = tsl::Env::Default();
tsl::FileSystem* fs = nullptr;
TF_ASSERT_OK(env->GetFileSystemForFile(profile_dump_path, &fs));

FunctionalHloRunner::RawCompileOptions raw_compile_options;
raw_compile_options.xla_gpu_dump_xspace_to = profile_dump_path;

TF_ASSERT_OK_AND_ASSIGN(
xla::PjRtEnvironment pjrt_env,
GetPjRtEnvironmentForGpu("", gpu_options, absl::Seconds(120)));
std::unique_ptr<GPURunnerProfiler> profiler;
FunctionalHloRunner::RunningOptions running_options;
TF_ASSERT_OK_AND_ASSIGN(
auto profiler,
GPURunnerProfiler::Create(profile_dump_path, /*keep_xspace=*/false));
profiler,
GPURunnerProfiler::Create(profile_dump_path, /*keep_xspace=*/true));
running_options.profiler = profiler.get();

running_options.num_repeats = 2;
TF_EXPECT_OK(FunctionalHloRunner::LoadAndRunAndDump(
*pjrt_env.client,
/* debug_options= */ {}, /* preproc_options= */ {}, raw_compile_options,
running_options, {GetHloPath("single_device.hlo")}, InputFormat::kText));
EXPECT_EQ(profiler->GetXSpace(), nullptr);
TF_EXPECT_OK(env->FileExists(profile_dump_path));
*client,
/* debug_options= */ {}, /* preproc_options= */ {},
/* raw_compile_options = */ {}, running_options,
{GetHloPath("single_device.hlo")}, InputFormat::kText));

if (client->platform_name() == "cuda") {
EXPECT_NE(profiler->GetXSpace(), nullptr);
EXPECT_GT(profiler->GetXSpace()->planes_size(), 0);
}
}

TEST_F(FunctionalHloRunnerTest, Sharded2Devices) {
Expand Down
5 changes: 3 additions & 2 deletions xla/tools/multihost_hlo_runner/hlo_runner_main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,7 @@ static absl::Status RunMultihostHloRunner(int argc, char** argv,
QCHECK_LT(opts.gpu_client_mem_fraction, 1.0);

PjRtEnvironment env;
std::unique_ptr<GPURunnerProfiler> gpu_runner_profiler;
if (opts.device_type_str == "gpu") {
xla::GpuClientOptions gpu_options;
gpu_options.node_id = opts.task_id;
Expand All @@ -248,10 +249,10 @@ static absl::Status RunMultihostHloRunner(int argc, char** argv,
// Create a GPURunnerProfiler to profile GPU executions to save xspace data
// to disk.
if (env.client != nullptr && !opts.xla_gpu_dump_xspace_to.empty()) {
TF_ASSIGN_OR_RETURN(auto profiler,
TF_ASSIGN_OR_RETURN(gpu_runner_profiler,
GPURunnerProfiler::Create(opts.xla_gpu_dump_xspace_to,
/*keep_xspace=*/false));
running_options.profiler = profiler.get();
running_options.profiler = gpu_runner_profiler.get();
}
} else if (opts.device_type_str == "host") {
TF_ASSIGN_OR_RETURN(env, xla::GetPjRtEnvironmentForHostCpu());
Expand Down

0 comments on commit 519055d

Please sign in to comment.