[LLM] vLLM: Add Preempt for scheduler (#9568)
Implement Preempt_by_recompute method for vllm.
This commit is contained in:
parent
f7e596d85a
commit
5c03651309
3 changed files with 166 additions and 9 deletions
|
|
@ -34,6 +34,7 @@
|
|||
# bigdl-llm Intel specified code change
|
||||
#
|
||||
|
||||
import enum
|
||||
import time
|
||||
from typing import Dict, Iterable, List, Optional, Tuple, Union
|
||||
|
||||
|
|
@ -48,6 +49,21 @@ from bigdl.llm.utils.common import invalidInputError
|
|||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class PreemptionMode(enum.Enum):
|
||||
"""Preemption modes.
|
||||
|
||||
1. Swapping: Swap out the blocks of the preempted sequences to CPU memory
|
||||
and swap them back in when the sequences are resumed.
|
||||
2. Recomputation: Discard the blocks of the preempted sequences and
|
||||
recompute them when the sequences are resumed, treating the sequences as
|
||||
new prompts.
|
||||
|
||||
bigdl: currently only support RECOMPUTE
|
||||
"""
|
||||
SWAP = enum.auto()
|
||||
RECOMPUTE = enum.auto()
|
||||
|
||||
|
||||
class SchedulerOutputs:
|
||||
|
||||
def __init__(
|
||||
|
|
@ -101,6 +117,7 @@ class FixedWindowScheduler:
|
|||
self.cleaned: List[int] = []
|
||||
self.kv_cache = kv_cache
|
||||
# Co(gc): We no longer have the swapped space as we are not deciding which to swap
|
||||
self.swapped: List[SequenceGroup] = []
|
||||
# bigdl-llm change end
|
||||
|
||||
def add_seq_group(self, seq_group: SequenceGroup) -> None:
|
||||
|
|
@ -147,8 +164,9 @@ class FixedWindowScheduler:
|
|||
num_curr_seqs = sum(seq_group.get_max_num_running_seqs()
|
||||
for seq_group in self.running)
|
||||
num_batched_tokens = 0
|
||||
# logger.info(f"swap: {self.swapped}, wait: {self.waiting}, run: {self.running}")
|
||||
|
||||
if self.waiting:
|
||||
if not self.swapped:
|
||||
# We restrict how many requests that can be run using these three arguments
|
||||
# Co(gc): If there are waiting requests, we will just try to add it into the
|
||||
# running state if not exceeds the stage
|
||||
|
|
@ -209,6 +227,7 @@ class FixedWindowScheduler:
|
|||
num_curr_seqs += num_new_seqs
|
||||
scheduled.append(seq_group)
|
||||
|
||||
if scheduled or ignored_seq_groups:
|
||||
scheduler_outputs = SchedulerOutputs(
|
||||
scheduled_seq_groups=scheduled,
|
||||
prompt_run=True,
|
||||
|
|
@ -221,6 +240,57 @@ class FixedWindowScheduler:
|
|||
# Now consider all the requests in decoding stage
|
||||
self.running = self.policy.sort_by_priority(now, self.running)
|
||||
|
||||
# Reserve new token slots for the running sequence groups.
|
||||
running: List[SequenceGroup] = []
|
||||
preempted: List[SequenceGroup] = []
|
||||
while self.running:
|
||||
seq_group = self.running.pop(0)
|
||||
# while self.seq_ability < 0:
|
||||
# if self.running:
|
||||
# # Preempt the lowest-priority sequence groups.
|
||||
# victim_seq_group = self.running.pop(-1)
|
||||
# self._preempt(victim_seq_group)
|
||||
# preempted.append(victim_seq_group)
|
||||
# else:
|
||||
# # No other sequence groups can be preempted.
|
||||
# # Preempt the current sequence group.
|
||||
# self._preempt(seq_group)
|
||||
# preempted.append(seq_group)
|
||||
# break
|
||||
# else:
|
||||
# # Append new slots to the sequence group.
|
||||
# # self._append_slot(seq_group, blocks_to_copy)
|
||||
running.append(seq_group)
|
||||
self.running = running
|
||||
|
||||
# TODO (txy): inplement below methods
|
||||
# # Swap in the sequence groups in the SWAPPED state if possible.
|
||||
# self.swapped = self.policy.sort_by_priority(now, self.swapped)
|
||||
# if not preempted:
|
||||
# num_curr_seqs = sum(seq_group.get_max_num_running_seqs()
|
||||
# for seq_group in self.running)
|
||||
|
||||
# while self.swapped:
|
||||
# seq_group = self.swapped[0]
|
||||
# # If the sequence group cannot be swapped in, stop.
|
||||
# # if not self.block_manager.can_swap_in(seq_group):
|
||||
# # break
|
||||
# if self.seq_ability <= 0:
|
||||
# break
|
||||
|
||||
# # The total number of sequences in the RUNNING state should not
|
||||
# # exceed the maximum number of sequences.
|
||||
# num_new_seqs = seq_group.get_max_num_running_seqs()
|
||||
# if (num_curr_seqs + num_new_seqs >
|
||||
# self.scheduler_config.max_num_seqs):
|
||||
# break
|
||||
|
||||
# seq_group = self.swapped.pop(0)
|
||||
# # self._swap_in(seq_group, blocks_to_swap_in)
|
||||
# # self._append_slot(seq_group, blocks_to_copy)
|
||||
# num_curr_seqs += num_new_seqs
|
||||
# self.running.append(seq_group)
|
||||
|
||||
# Each sequence in the generation phase only takes one token slot.
|
||||
# Therefore, the number of batched tokens is equal to the number of
|
||||
# sequences in the RUNNING state.
|
||||
|
|
@ -280,3 +350,85 @@ class FixedWindowScheduler:
|
|||
seq_group for seq_group in self.running
|
||||
if not seq_group.is_finished()
|
||||
]
|
||||
|
||||
def _preempt(
|
||||
self,
|
||||
seq_group: SequenceGroup,
|
||||
blocks_to_swap_out: Optional[Dict[int, int]]=None,
|
||||
preemption_mode: Optional[PreemptionMode]=None,
|
||||
) -> None:
|
||||
# If preemption mode is not specified, we determine the mode as follows:
|
||||
# We use recomputation by default since it incurs lower overhead than
|
||||
# swapping. However, when the sequence group has multiple sequences
|
||||
# (e.g., beam search), recomputation is not currently supported. In
|
||||
# such a case, we use swapping instead.
|
||||
# FIXME(woosuk): This makes our scheduling policy a bit bizarre.
|
||||
# As swapped sequences are prioritized over waiting sequences,
|
||||
# sequence groups with multiple sequences are implicitly prioritized
|
||||
# over sequence groups with a single sequence.
|
||||
# TODO(woosuk): Support recomputation for sequence groups with multiple
|
||||
# sequences. This may require a more sophisticated CUDA kernel.
|
||||
if preemption_mode is None:
|
||||
if seq_group.get_max_num_running_seqs() == 1:
|
||||
preemption_mode = PreemptionMode.RECOMPUTE
|
||||
else:
|
||||
preemption_mode = PreemptionMode.SWAP
|
||||
if preemption_mode == PreemptionMode.RECOMPUTE:
|
||||
self._preempt_by_recompute(seq_group)
|
||||
elif preemption_mode == PreemptionMode.SWAP:
|
||||
self._preempt_by_swap(seq_group, blocks_to_swap_out)
|
||||
else:
|
||||
raise AssertionError("Invalid preemption mode.") # noqa
|
||||
|
||||
def _preempt_by_recompute(
|
||||
self,
|
||||
seq_group: SequenceGroup,
|
||||
) -> None:
|
||||
seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
|
||||
# len(seqs) should be 1
|
||||
for seq in seqs:
|
||||
seq.status = SequenceStatus.WAITING
|
||||
# self.block_manager.free(seq)
|
||||
if not self.kv_cache[0][0].get(seq.seq_id) is None:
|
||||
for i in range(len(self.kv_cache)):
|
||||
for j in range(2):
|
||||
del self.kv_cache[i][j][seq.seq_id]
|
||||
|
||||
# NOTE: For FCFS, we insert the preempted sequence group to the front
|
||||
# of the waiting queue.
|
||||
self.waiting.insert(0, seq_group)
|
||||
|
||||
# TODO (txy): inplement below methods
|
||||
def _preempt_by_swap(
|
||||
self,
|
||||
seq_group: SequenceGroup,
|
||||
blocks_to_swap_out: Dict[int, int],
|
||||
) -> None:
|
||||
self._swap_out(seq_group, blocks_to_swap_out)
|
||||
self.swapped.append(seq_group)
|
||||
|
||||
def _swap_in(
|
||||
self,
|
||||
seq_group: SequenceGroup,
|
||||
blocks_to_swap_in: Dict[int, int],
|
||||
) -> None:
|
||||
mapping = self.block_manager.swap_in(seq_group)
|
||||
blocks_to_swap_in.update(mapping)
|
||||
for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED):
|
||||
seq.status = SequenceStatus.RUNNING
|
||||
|
||||
def _swap_out(
|
||||
self,
|
||||
seq_group: SequenceGroup,
|
||||
blocks_to_swap_out: Dict[int, int],
|
||||
) -> None:
|
||||
if not self.block_manager.can_swap_out(seq_group):
|
||||
# FIXME(woosuk): Abort the sequence group instead of aborting the
|
||||
# entire engine.
|
||||
raise RuntimeError( # noqa
|
||||
"Aborted due to the lack of CPU swap space. Please increase "
|
||||
"the swap space to avoid this error.")
|
||||
mapping = self.block_manager.swap_out(seq_group)
|
||||
blocks_to_swap_out.update(mapping)
|
||||
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
|
||||
seq.status = SequenceStatus.SWAPPED
|
||||
|
|
|
|||
|
|
@ -172,6 +172,7 @@ class BigDLLlamaForCausalLM(BigDLModelForCausalLM):
|
|||
# "return_dict": True,
|
||||
}
|
||||
# pdb.set_trace()
|
||||
|
||||
if self.device.type == 'xpu':
|
||||
torch.xpu.empty_cache()
|
||||
st_timestamp = time.perf_counter()
|
||||
|
|
|
|||
|
|
@ -21,6 +21,10 @@ from transformers import LlamaConfig
|
|||
|
||||
from bigdl.llm.vllm.sequence import SequenceOutputs, SequenceGroupMetadata
|
||||
from bigdl.llm.transformers.models.utils import extend_kv_cache
|
||||
from bigdl.llm.vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
zero_cache_dict = {}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue