vllm.v1.core.sched.async_scheduler

logger module-attribute

logger = init_logger(__name__)

AsyncScheduler

Bases: Scheduler

Source code in vllm/v1/core/sched/async_scheduler.py
class AsyncScheduler(Scheduler):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # Reusable read-only placeholder list for speculative decoding.
        self._spec_token_placeholders: list[int] = [-1] * self.num_spec_tokens

    def _update_after_schedule(self, scheduler_output: SchedulerOutput) -> None:
        super()._update_after_schedule(scheduler_output)
        spec_decode_tokens = scheduler_output.scheduled_spec_decode_tokens
        for req_id in scheduler_output.num_scheduled_tokens:
            request = self.requests[req_id]
            if request.is_prefill_chunk:
                continue

            scheduler_output.pending_structured_output_tokens |= (
                request.use_structured_output and request.num_output_placeholders > 0
            )
            # The request will generate a new token plus num_spec_tokens
            # in this scheduling step.
            cur_num_spec_tokens = len(spec_decode_tokens.get(req_id, ()))
            request.num_output_placeholders += 1 + cur_num_spec_tokens
            # Add placeholders for the new draft/spec tokens.
            # We will update the actual spec token ids in the worker process.
            request.spec_token_ids = self._spec_token_placeholders

    def _update_request_with_output(
        self, request: Request, new_token_ids: list[int]
    ) -> tuple[list[int], bool]:
        if request.discard_latest_async_tokens:
            # If the request was force-preempted in reset_prefix_cache, we
            # should discard the latest async tokens.
            request.discard_latest_async_tokens = False
            return [], False

        status_before_update = request.status
        new_token_ids, stopped = super()._update_request_with_output(
            request, new_token_ids
        )

        # Update the number of output placeholders.
        request.num_output_placeholders -= len(new_token_ids)
        assert request.num_output_placeholders >= 0

        # Cache the new tokens. Preempted requests should be skipped.
        if status_before_update == RequestStatus.RUNNING:
            self.kv_cache_manager.cache_blocks(
                request, request.num_computed_tokens - request.num_output_placeholders
            )
        return new_token_ids, stopped
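
The placeholder bookkeeping above exists because, with asynchronous scheduling, the scheduler plans the next step before the worker has returned the tokens of the previous one. The following is a minimal, self-contained sketch of that accounting, assuming a toy FakeRequest class and hand-picked numbers; it is illustrative only and does not use the vLLM API.

from dataclasses import dataclass

@dataclass
class FakeRequest:
    # Tokens the scheduler has already scheduled but whose ids it has not seen yet.
    num_output_placeholders: int = 0

def after_schedule(req: FakeRequest, num_scheduled_spec_tokens: int) -> None:
    # Mirrors _update_after_schedule: one new token plus any scheduled
    # draft tokens become placeholders.
    req.num_output_placeholders += 1 + num_scheduled_spec_tokens

def after_output(req: FakeRequest, new_token_ids: list[int]) -> None:
    # Mirrors _update_request_with_output: placeholders are resolved by
    # however many real tokens arrived from the worker.
    req.num_output_placeholders -= len(new_token_ids)
    assert req.num_output_placeholders >= 0

req = FakeRequest()
after_schedule(req, num_scheduled_spec_tokens=0)  # step N scheduled          -> 1
after_schedule(req, num_scheduled_spec_tokens=0)  # step N+1 scheduled before
                                                  # step N's output arrives   -> 2
after_output(req, new_token_ids=[17])             # step N's output arrives   -> 1
print(req.num_output_placeholders)                # 1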

_spec_token_placeholders instance-attribute

_spec_token_placeholders: list[int] = [-1] * num_spec_tokens

__init__

__init__(*args, **kwargs) -> None
Source code in vllm/v1/core/sched/async_scheduler.py
def __init__(self, *args, **kwargs) -> None:
    super().__init__(*args, **kwargs)
    # Reusable read-only placeholder list for speculative decoding.
    self._spec_token_placeholders: list[int] = [-1] * self.num_spec_tokens
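
A single placeholder list is shared by every request that receives spec-token placeholders in _update_after_schedule, so it has to be treated as read-only: real draft ids replace it by rebinding the attribute, never by mutating the list in place. A small illustrative sketch of the aliasing (names and values are assumptions, not vLLM code):

placeholders = [-1, -1]              # one shared list, analogous to _spec_token_placeholders
req_a_spec_token_ids = placeholders
req_b_spec_token_ids = placeholders

# Correct: rebind to a fresh list once the real draft ids are known.
req_a_spec_token_ids = [311, 742]

# Incorrect: in-place mutation would leak into every aliased request.
# req_b_spec_token_ids[0] = 311

print(req_b_spec_token_ids)          # [-1, -1] -- unaffected by the rebinding above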

_update_after_schedule

_update_after_schedule(
    scheduler_output: SchedulerOutput,
) -> None
Source code in vllm/v1/core/sched/async_scheduler.py
def _update_after_schedule(self, scheduler_output: SchedulerOutput) -> None:
    super()._update_after_schedule(scheduler_output)
    spec_decode_tokens = scheduler_output.scheduled_spec_decode_tokens
    for req_id in scheduler_output.num_scheduled_tokens:
        request = self.requests[req_id]
        if request.is_prefill_chunk:
            continue

        scheduler_output.pending_structured_output_tokens |= (
            request.use_structured_output and request.num_output_placeholders > 0
        )
        # The request will generate a new token plus num_spec_tokens
        # in this scheduling step.
        cur_num_spec_tokens = len(spec_decode_tokens.get(req_id, ()))
        request.num_output_placeholders += 1 + cur_num_spec_tokens
        # Add placeholders for the new draft/spec tokens.
        # We will update the actual spec token ids in the worker process.
        request.spec_token_ids = self._spec_token_placeholders
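
Two details in the loop above are easy to miss: scheduled_spec_decode_tokens only holds entries for requests that actually received draft tokens, so the .get(req_id, ()) default contributes zero extra placeholders for everyone else, and pending_structured_output_tokens is an OR-accumulated flag across all scheduled requests. A short sketch with made-up request ids and values:

scheduled_spec_decode_tokens = {"req-0": [901, 902]}   # req-1 has no draft tokens

for req_id in ("req-0", "req-1"):
    cur_num_spec_tokens = len(scheduled_spec_decode_tokens.get(req_id, ()))
    print(req_id, "adds", 1 + cur_num_spec_tokens, "placeholder(s)")
    # req-0 adds 3 placeholder(s); req-1 adds 1 placeholder(s)

pending = False
for use_structured_output, num_output_placeholders in [(False, 2), (True, 3)]:
    pending |= use_structured_output and num_output_placeholders > 0
print(pending)   # True: some structured-output request still has unresolved tokens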

_update_request_with_output

_update_request_with_output(
    request: Request, new_token_ids: list[int]
) -> tuple[list[int], bool]
Source code in vllm/v1/core/sched/async_scheduler.py
def _update_request_with_output(
    self, request: Request, new_token_ids: list[int]
) -> tuple[list[int], bool]:
    if request.discard_latest_async_tokens:
        # If the request was force-preempted in reset_prefix_cache, we
        # should discard the latest async tokens.
        request.discard_latest_async_tokens = False
        return [], False

    status_before_update = request.status
    new_token_ids, stopped = super()._update_request_with_output(
        request, new_token_ids
    )

    # Update the number of output placeholders.
    request.num_output_placeholders -= len(new_token_ids)
    assert request.num_output_placeholders >= 0

    # Cache the new tokens. Preempted requests should be skipped.
    if status_before_update == RequestStatus.RUNNING:
        self.kv_cache_manager.cache_blocks(
            request, request.num_computed_tokens - request.num_output_placeholders
        )
    return new_token_ids, stopped
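
The second argument to cache_blocks is the number of tokens whose ids are actually known: num_computed_tokens may already count tokens that were scheduled ahead asynchronously, so the still-unresolved placeholders are subtracted before caching. A small numeric sketch (values are illustrative, not from vLLM):

num_computed_tokens = 100       # includes tokens scheduled ahead of their output
num_output_placeholders = 3     # of those, 3 ids have not come back from the worker yet
num_cacheable_tokens = num_computed_tokens - num_output_placeholders
print(num_cacheable_tokens)     # 97 tokens have known ids and can be cached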