[V1][Minor] Move scheduler outputs to a separate file (vllm-project#13062)

Signed-off-by: Woosuk Kwon <[email protected]>
WoosukKwon authored Feb 11, 2025
1 parent 91e8767 commit 2ff4857
Showing 4 changed files with 113 additions and 89 deletions.
89 changes: 3 additions & 86 deletions vllm/v1/core/scheduler.py
@@ -1,26 +1,20 @@
# SPDX-License-Identifier: Apache-2.0

from collections import deque
from dataclasses import dataclass
from typing import (TYPE_CHECKING, Deque, Dict, Iterable, List, Optional, Set,
                    Tuple, Union)
from typing import Deque, Dict, Iterable, List, Optional, Set, Tuple, Union

from vllm.config import CacheConfig, LoRAConfig, ModelConfig, SchedulerConfig
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.sampling_params import SamplingParams
from vllm.v1.core.encoder_cache_manager import (EncoderCacheManager,
                                                compute_encoder_budget)
from vllm.v1.core.kv_cache_manager import KVCacheManager
from vllm.v1.core.scheduler_output import (CachedRequestData, NewRequestData,
                                           SchedulerOutput)
from vllm.v1.engine import EngineCoreOutput, EngineCoreOutputs
from vllm.v1.metrics.stats import SchedulerStats
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus

if TYPE_CHECKING:
    from vllm.multimodal import MultiModalKwargs
    from vllm.multimodal.base import PlaceholderRange

logger = init_logger(__name__)


@@ -600,80 +594,3 @@ def make_stats(self) -> SchedulerStats:
            num_waiting_reqs=len(self.waiting),
            gpu_cache_usage=self.kv_cache_manager.usage,
        )


@dataclass
class NewRequestData:

    req_id: str
    prompt_token_ids: List[int]
    prompt: Optional[str]
    mm_inputs: List["MultiModalKwargs"]
    mm_hashes: List[str]
    mm_positions: List["PlaceholderRange"]
    sampling_params: SamplingParams
    block_ids: List[int]
    num_computed_tokens: int
    lora_request: Optional[LoRARequest]

    @classmethod
    def from_request(
        cls,
        request: Request,
        block_ids: List[int],
        num_computed_tokens: int,
    ) -> "NewRequestData":
        return cls(
            req_id=request.request_id,
            prompt_token_ids=request.prompt_token_ids,
            prompt=request.prompt,
            mm_inputs=request.mm_inputs,
            mm_hashes=request.mm_hashes,
            mm_positions=request.mm_positions,
            sampling_params=request.sampling_params,
            block_ids=block_ids,
            num_computed_tokens=num_computed_tokens,
            lora_request=request.lora_request,
        )


@dataclass
class CachedRequestData:

    req_id: str
    # If resumed_from_preemption is False, new_block_ids will be appended to
    # the request's block IDs. If True, new_block_ids will be used as the
    # request's block IDs instead of appending to the existing block IDs.
    resumed_from_preemption: bool
    new_block_ids: List[int]
    num_computed_tokens: int

    @classmethod
    def from_request(
        cls,
        request: Request,
        resumed_from_preemption: bool,
        new_block_ids: List[int],
        num_computed_tokens: int,
    ) -> "CachedRequestData":
        return cls(
            req_id=request.request_id,
            resumed_from_preemption=resumed_from_preemption,
            new_block_ids=new_block_ids,
            num_computed_tokens=num_computed_tokens,
        )


@dataclass
class SchedulerOutput:

    scheduled_new_reqs: List[NewRequestData]
    scheduled_cached_reqs: List[CachedRequestData]

    num_scheduled_tokens: Dict[str, int]
    total_num_scheduled_tokens: int
    scheduled_encoder_inputs: Dict[str, List[int]]
    num_common_prefix_blocks: int

    finished_req_ids: Set[str]
    free_encoder_input_ids: List[Tuple[str, int]]
108 changes: 108 additions & 0 deletions vllm/v1/core/scheduler_output.py
@@ -0,0 +1,108 @@
# SPDX-License-Identifier: Apache-2.0

from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple

if TYPE_CHECKING:
    from vllm.lora.request import LoRARequest
    from vllm.multimodal import MultiModalKwargs
    from vllm.multimodal.base import PlaceholderRange
    from vllm.sampling_params import SamplingParams
    from vllm.v1.request import Request


@dataclass
class NewRequestData:

    req_id: str
    prompt_token_ids: List[int]
    prompt: Optional[str]
    mm_inputs: List["MultiModalKwargs"]
    mm_hashes: List[str]
    mm_positions: List["PlaceholderRange"]
    sampling_params: "SamplingParams"
    block_ids: List[int]
    num_computed_tokens: int
    lora_request: Optional["LoRARequest"]

    @classmethod
    def from_request(
        cls,
        request: "Request",
        block_ids: List[int],
        num_computed_tokens: int,
    ) -> "NewRequestData":
        return cls(
            req_id=request.request_id,
            prompt_token_ids=request.prompt_token_ids,
            prompt=request.prompt,
            mm_inputs=request.mm_inputs,
            mm_hashes=request.mm_hashes,
            mm_positions=request.mm_positions,
            sampling_params=request.sampling_params,
            block_ids=block_ids,
            num_computed_tokens=num_computed_tokens,
            lora_request=request.lora_request,
        )


@dataclass
class CachedRequestData:

    req_id: str
    # If resumed_from_preemption is False, new_block_ids will be appended to
    # the request's block IDs. If True, new_block_ids will be used as the
    # request's block IDs instead of appending to the existing block IDs.
    resumed_from_preemption: bool
    new_block_ids: List[int]
    num_computed_tokens: int

    @classmethod
    def from_request(
        cls,
        request: "Request",
        resumed_from_preemption: bool,
        new_block_ids: List[int],
        num_computed_tokens: int,
    ) -> "CachedRequestData":
        return cls(
            req_id=request.request_id,
            resumed_from_preemption=resumed_from_preemption,
            new_block_ids=new_block_ids,
            num_computed_tokens=num_computed_tokens,
        )


@dataclass
class SchedulerOutput:

    # List of the requests that are scheduled for the first time.
    # We cache the request's data in each worker process, so that we don't
    # need to re-send it every scheduling step.
    scheduled_new_reqs: List[NewRequestData]
    # List of the requests that have been scheduled before.
    # Since the request's data is already cached in the worker processes,
    # we only send the diff to minimize the communication cost.
    scheduled_cached_reqs: List[CachedRequestData]

    # req_id -> num_scheduled_tokens
    # Number of tokens scheduled for each request.
    num_scheduled_tokens: Dict[str, int]
    # Total number of tokens scheduled for all requests.
    # Equal to sum(num_scheduled_tokens.values())
    total_num_scheduled_tokens: int
    # req_id -> encoder input indices that need processing.
    # E.g., if a request has [0, 1], it could mean the vision encoder needs
    # to process the request's 0-th and 1-th images in the current step.
    scheduled_encoder_inputs: Dict[str, List[int]]
    # Number of common prefix blocks for all requests.
    # This can be used for cascade attention.
    num_common_prefix_blocks: int

    # Request IDs that are finished in between the previous and the current
    # steps. This is used to notify the workers about the finished requests
    # so that they can free the cached states for those requests.
    finished_req_ids: Set[str]
    # List of (req_id, encoder_input_index) tuples.
    # Used to free the encoder cache.
    free_encoder_input_ids: List[Tuple[str, int]]
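
The field comments above describe the protocol between the scheduler and the worker processes: new requests carry their full data, previously scheduled requests carry only a diff, and finished request IDs tell the workers to drop cached state. The sketch below illustrates those semantics from the worker side; it is not vLLM's actual GPUModelRunner code, and ReqState, req_states, and apply_scheduler_output are hypothetical names used only for illustration.

# Illustrative sketch of how a worker process could apply a SchedulerOutput,
# following the field comments above. Not vLLM's actual worker code;
# ReqState, req_states, and apply_scheduler_output are hypothetical.
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class ReqState:
    # Hypothetical per-request state cached inside the worker process.
    prompt_token_ids: List[int]
    block_ids: List[int] = field(default_factory=list)
    num_computed_tokens: int = 0


def apply_scheduler_output(req_states: Dict[str, ReqState],
                           scheduler_output: "SchedulerOutput") -> None:
    # New requests: cache the full request data so later steps only need diffs.
    for new_req in scheduler_output.scheduled_new_reqs:
        req_states[new_req.req_id] = ReqState(
            prompt_token_ids=new_req.prompt_token_ids,
            block_ids=list(new_req.block_ids),
            num_computed_tokens=new_req.num_computed_tokens,
        )

    # Cached requests: apply only the diff that was sent.
    for cached_req in scheduler_output.scheduled_cached_reqs:
        state = req_states[cached_req.req_id]
        if cached_req.resumed_from_preemption:
            # A request resumed from preemption gets a fresh set of block IDs.
            state.block_ids = list(cached_req.new_block_ids)
        else:
            # Otherwise the new blocks extend the existing allocation.
            state.block_ids.extend(cached_req.new_block_ids)
        state.num_computed_tokens = cached_req.num_computed_tokens

    # Finished requests: free the cached state.
    for req_id in scheduler_output.finished_req_ids:
        req_states.pop(req_id, None)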
2 changes: 1 addition & 1 deletion vllm/v1/worker/gpu_model_runner.py
@@ -36,7 +36,7 @@
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin

if TYPE_CHECKING:
    from vllm.v1.core.scheduler import SchedulerOutput
    from vllm.v1.core.scheduler_output import SchedulerOutput

logger = init_logger(__name__)

3 changes: 1 addition & 2 deletions vllm/v1/worker/gpu_worker.py
@@ -18,15 +18,14 @@
from vllm.model_executor import set_random_seed
from vllm.platforms import current_platform
from vllm.utils import GiB_bytes
from vllm.v1.core.scheduler import SchedulerOutput
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.worker.gpu_model_runner import GPUModelRunner

logger = init_logger(__name__)

if TYPE_CHECKING:
    from vllm.v1.core.scheduler import SchedulerOutput
    from vllm.v1.core.scheduler_output import SchedulerOutput


class Worker:
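
In both worker files the diff leaves SchedulerOutput imported only inside an `if TYPE_CHECKING:` block (gpu_worker.py also drops its old top-level import), so the class is visible to static type checkers without being imported at runtime. A minimal, standalone illustration of that pattern follows; the module name my_scheduler_output is made up for the example and is not part of vLLM.

# Minimal illustration of the `if TYPE_CHECKING:` import pattern used above.
# The module name `my_scheduler_output` is hypothetical, not part of vLLM.
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Evaluated only by static type checkers, never at runtime.
    from my_scheduler_output import SchedulerOutput


def execute_model(scheduler_output: "SchedulerOutput") -> None:
    # The quoted annotation keeps this valid at runtime even though
    # SchedulerOutput is never imported outside of type checking.
    ...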
