
[NVIDIA GPU] [XLA_GPU_MS_COLLECTIVE] Support round-robin runtime stream assignment #22450

Open
wants to merge 12 commits into main

Conversation

terryysun (Contributor)

With #19026, LHS can overlap appropriate async collectives without deadlock. This PR adds support at runtime where we leverage the overlapping schedule produced by LHS and perform a round-robin stream assignment for collectives.
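At a high level, round-robin assignment cycles collective-start ops over a small fixed pool of streams and keeps each done op on the stream of its start. A minimal standalone sketch of that idea, with illustrative names only (these are not the classes this PR adds to XLA):

    // Illustrative sketch only; types and names are placeholders, not this PR's code.
    #include <cstdint>
    #include <unordered_map>

    class RoundRobinCollectiveStreamAssigner {
     public:
      RoundRobinCollectiveStreamAssigner(int64_t num_collective_streams,
                                         int64_t first_stream_id)
          : num_streams_(num_collective_streams),
            first_stream_id_(first_stream_id) {}

      // Picks the next stream in the pool for an async collective start and
      // remembers the choice for the matching done.
      int64_t AssignStart(int64_t start_op_id) {
        int64_t stream = first_stream_id_ + next_;
        next_ = (next_ + 1) % num_streams_;
        assigned_[start_op_id] = stream;
        return stream;
      }

      // The done op must run on the same stream as its start.
      int64_t StreamForDone(int64_t start_op_id) const {
        return assigned_.at(start_op_id);
      }

     private:
      int64_t num_streams_;
      int64_t first_stream_id_;
      int64_t next_ = 0;
      std::unordered_map<int64_t, int64_t> assigned_;
    };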

@terryysun terryysun added the kokoro:force-run Forces CI to rerun label Feb 7, 2025
@kokoro-team kokoro-team removed the kokoro:force-run Forces CI to rerun label Feb 7, 2025
@terryysun terryysun requested a review from frgossen February 7, 2025 03:55
@terryysun terryysun changed the title from "[NVIDIA GPU] [Async Collective Multi-streaming] Support round-robin runtime stream assignment" to "[NVIDIA GPU] [XLA_GPU_MS_COLLECTIVE] Support round-robin runtime stream assignment" Feb 13, 2025
frgossen (Member) left a comment:

Thank you for working on this. Multiple streams for collectives will be great to have!

This is a big PR and I think at least the stream assignment, the runtime integration, and some util functions could be three separate PRs.

One thing I am wondering is if we could do this in the latency-hiding scheduler. The scheduler models resources and we could have one per collective stream. That way the scheduler would perform the stream assignment and compose well with it. I'm not saying that would be better but it might be worth thinking about since you seem to run into issues with the LHS (implied by some comments).
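For reference, "one resource per collective stream" would look roughly like this: the scheduler only lets a collective-start begin when a stream resource is free, and the index of the acquired resource doubles as the stream assignment. This is a generic sketch, not the actual latency-hiding scheduler or AsyncTracker API:

    // Generic sketch of per-stream resources; not XLA's scheduler API.
    #include <cstdint>
    #include <vector>

    class CollectiveStreamResources {
     public:
      explicit CollectiveStreamResources(int64_t num_streams)
          : in_use_(num_streams, false) {}

      // Returns a free stream index, or -1 if all streams are busy, in which
      // case the scheduler would delay the collective-start.
      int64_t TryAcquire() {
        for (int64_t i = 0; i < static_cast<int64_t>(in_use_.size()); ++i) {
          if (!in_use_[i]) {
            in_use_[i] = true;
            return i;
          }
        }
        return -1;
      }

      // Called when the matching collective-done is scheduled.
      void Release(int64_t index) { in_use_[index] = false; }

     private:
      std::vector<bool> in_use_;
    };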

 inline constexpr int64_t kAsyncStreamTotal =
-    static_cast<int64_t>(AsyncStreamKind::kMemCpyP2P) + 1;
+    std::max(static_cast<int64_t>(AsyncStreamKind::kMemCpyP2P) + 1, (int64_t)7);
Member:

Can you extract this 7 as a constant with a meaningful name?
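One possible shape, with hypothetical names derived from the "4 (max compute stream) + 2 (max collective stream) + 1" comment quoted below; the actual naming is up to the author, and this assumes the surrounding header's AsyncStreamKind and an <algorithm> include:

    // Hypothetical names; values taken from the code comment in this PR.
    inline constexpr int64_t kMaxComputeStreams = 4;
    inline constexpr int64_t kMaxCollectiveStreams = 2;
    // 4 compute streams + 2 collective streams + 1 default stream.
    inline constexpr int64_t kMinAsyncStreamTotal =
        kMaxComputeStreams + kMaxCollectiveStreams + 1;

    inline constexpr int64_t kAsyncStreamTotal =
        std::max(static_cast<int64_t>(AsyncStreamKind::kMemCpyP2P) + 1,
                 kMinAsyncStreamTotal);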

@@ -36,6 +36,21 @@ limitations under the License.
#include "xla/side_effect_util.h"

namespace xla::gpu {
namespace {
static bool is_async_collective(const HloInstruction* instruction) {
Member:

There may be something like this in collective_ops_utils.cc. If not, that could be a good place to add it.

@@ -36,6 +36,21 @@ limitations under the License.
#include "xla/side_effect_util.h"

namespace xla::gpu {
namespace {
Member:

No need for an anonymous namespace for static functions.

@@ -40,8 +40,10 @@ enum class AsyncStreamKind : int64_t {
kMemCpyP2P = 3, // Stream for MemCpyP2P
};

// Taking the maximum of max stream kind + 1 and 4 (max compute stream) + 2 (max
// collective stream) + 1;
Member:

This is not clear to me. Is kAsyncStreamTotal a static upper bound? How is that enforced when the number of collective streams can be set via the flag xla_gpu_experimental_parallel_collective_overlap_limit? Am I misunderstanding this?
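If kAsyncStreamTotal is meant as a static upper bound, one way to enforce it against the flag is an explicit check where the stream assignment is built. This is only a sketch: it assumes a Status-returning context, assumes the DebugOptions accessor follows the flag name, and kNumReservedCollectiveStreams is a hypothetical constant, not code from this PR.

    // Sketch only: reject flag values that exceed the statically reserved pool.
    const int64_t overlap_limit =
        debug_options.xla_gpu_experimental_parallel_collective_overlap_limit();
    TF_RET_CHECK(overlap_limit <= kNumReservedCollectiveStreams)
        << "xla_gpu_experimental_parallel_collective_overlap_limit="
        << overlap_limit << " exceeds the " << kNumReservedCollectiveStreams
        << " collective streams reserved in kAsyncStreamTotal.";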

std::make_unique<NcclThunkType>(thunk_info, instr, buffers);

const ExecutionStreamAssignment& stream_assignment =
ir_emitter_context.execution_stream_assignment();
Member:

Why is this passed separately? Can we add this to the thunks when emitting them? I think we do something like that for send and recv already.

static bool is_async_collective(const HloInstruction* instruction) {
if (instruction->IsAsynchronous()) {
auto opcode = instruction->async_wrapped_opcode();
return opcode == HloOpcode::kAllGather || opcode == HloOpcode::kAllReduce ||
Member:

I don't think this covers all collectives. If you move this to collective_ops_utils, we can test it too
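A fuller version, as it might look in collective_ops_utils; which opcodes should actually be multi-streamed is a design decision for the PR, and the exact opcode set available depends on the XLA version:

    // Sketch of a more complete check; whether each of these should be
    // multi-streamed is a separate question.
    bool IsAsyncCollective(const HloInstruction* instruction) {
      if (!instruction->IsAsynchronous()) return false;
      switch (instruction->async_wrapped_opcode()) {
        case HloOpcode::kAllGather:
        case HloOpcode::kAllReduce:
        case HloOpcode::kAllToAll:
        case HloOpcode::kCollectiveBroadcast:
        case HloOpcode::kCollectivePermute:
        case HloOpcode::kReduceScatter:
          return true;
        default:
          return false;
      }
    }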

@@ -45,7 +60,11 @@ ExecutionStreamAssignment::ExecutionStreamAssignment(
// on the entrypoint computation will be assigned `ExecutionStreamId(0)`, and
// each invocation of `async-start` will result in the target computation
// being assigned a new `ExecutionStreamId`.
-  ExecutionStreamId next_stream_id = ExecutionStreamId(1);
+  ExecutionStreamId next_compute_stream_id = ExecutionStreamId(1);
+  ExecutionStreamId next_collective_stream_id =
Member:

Why is the type of the next collective stream id ExecutionStreamId?

@@ -405,7 +405,19 @@ GpuThunkAotCompilationResult::LoadExecutable(
compiler->BufferSizeBytesFunction(),
/*can_share_buffer=*/nullptr));

-  ExecutionStreamAssignment execution_stream_assignment(hlo_module.get());
+  ExecutionStreamAssignment execution_stream_assignment(
Member:

This looks like duplicate code.

@@ -126,19 +126,43 @@ bool IsAsyncPair(const HloInstruction& from, const HloInstruction& target) {
return IsGpuAsyncStart(from) && IsGpuAsyncDone(target);
}

// Util function for getting replica groups from different data structures.
static std::vector<std::vector<int64_t>> get_replica_groups(
Member:

I think there is something like this in the collective ops utils. Can you move it there and test it separately?
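For reference, the kind of helper that could live in collective_ops_utils, unifying the two representations (ReplicaGroup lists for most collectives, source-target pairs for collective-permute). The name and the DynCast-based dispatch are assumptions, not this PR's code:

    // Sketch only: normalize replica groups across collective ops.
    std::vector<std::vector<int64_t>> GetReplicaGroupIdLists(
        const HloInstruction* instruction) {
      std::vector<std::vector<int64_t>> groups;
      if (auto* permute =
              DynCast<HloCollectivePermuteInstruction>(instruction)) {
        // Collective-permute stores (source, target) pairs rather than groups.
        for (const auto& pair : permute->source_target_pairs()) {
          groups.push_back({pair.first, pair.second});
        }
      } else if (auto* collective =
                     DynCast<HloCollectiveInstruction>(instruction)) {
        for (const ReplicaGroup& group : collective->replica_groups()) {
          groups.emplace_back(group.replica_ids().begin(),
                              group.replica_ids().end());
        }
      }
      return groups;
    }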

std::vector<int64_t> ids({pair.first, pair.second});
return ids;
});
} else {
Member:

Does this work for all collectives except collective-permute start? Can we test that?
