diff --git a/test/spmd/test_train_spmd_imagenet.py b/test/spmd/test_train_spmd_imagenet.py index bf544b01d73b..021306698fba 100644 --- a/test/spmd/test_train_spmd_imagenet.py +++ b/test/spmd/test_train_spmd_imagenet.py @@ -344,7 +344,7 @@ def train_loop_fn(loader, epoch): if step == profile_step and epoch == profile_epoch: xm.wait_device_ops() import tempfile - xp.trace_detached('127.0.0.1:9012', profile_logdir, profile_duration or 20000) + xp.trace_detached(f'127.0.0.1:{FLAGS.profiler_port}', profile_logdir, profile_duration or 20000) def test_loop_fn(loader, epoch): total_samples, correct = 0, 0 @@ -389,7 +389,7 @@ def test_loop_fn(loader, epoch): if __name__ == '__main__': if FLAGS.profile: - server = xp.start_server(9012) + server = xp.start_server(FLAGS.profiler_port) torch.set_default_dtype(torch.float32) accuracy = train_imagenet()