[LLM] vLLM: Add Preempt for scheduler (#9568)
Implement the preempt_by_recompute method for vLLM.
parent f7e596d85a
commit 5c03651309

3 changed files with 166 additions and 9 deletions
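The patch teaches the FixedWindowScheduler to preempt running sequence groups by recomputation: a preempted group is reset to the WAITING state, its cached key/value entries are dropped, and it is re-inserted at the front of the waiting queue so it is later re-prefilled as a fresh prompt. The snippet below is a minimal, self-contained sketch of that flow, not the patched class; the ToySequence/ToyScheduler names and the dict-based cache are illustrative assumptions only.

from collections import deque

class ToySequence:
    def __init__(self, seq_id):
        self.seq_id = seq_id
        self.status = "RUNNING"

class ToyScheduler:
    def __init__(self):
        self.waiting = deque()   # sequences waiting to be (re-)prefilled
        self.running = []        # sequences currently decoding
        self.kv_cache = {}       # seq_id -> cached key/value tensors

    def preempt_by_recompute(self, seq):
        # Drop the cached KV so the memory can serve other requests.
        self.kv_cache.pop(seq.seq_id, None)
        # Treat the sequence as a brand-new prompt: it will be re-prefilled.
        seq.status = "WAITING"
        self.running.remove(seq)
        # FCFS: the victim goes back to the front of the waiting queue.
        self.waiting.appendleft(seq)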
@@ -34,6 +34,7 @@
 # bigdl-llm Intel specified code change
 #

+import enum
 import time
 from typing import Dict, Iterable, List, Optional, Tuple, Union

@@ -48,6 +49,21 @@ from bigdl.llm.utils.common import invalidInputError
 logger = init_logger(__name__)


+class PreemptionMode(enum.Enum):
+    """Preemption modes.
+
+    1. Swapping: Swap out the blocks of the preempted sequences to CPU memory
+    and swap them back in when the sequences are resumed.
+    2. Recomputation: Discard the blocks of the preempted sequences and
+    recompute them when the sequences are resumed, treating the sequences as
+    new prompts.
+
+    bigdl: currently only supports RECOMPUTE
+    """
+    SWAP = enum.auto()
+    RECOMPUTE = enum.auto()
+
+
 class SchedulerOutputs:

     def __init__(
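As the docstring notes, this port only wires up RECOMPUTE; the SWAP path is added further down but is still marked TODO and appears to depend on a block manager this scheduler does not use. A hedged sketch of how a caller might pick a mode, mirroring the rule used later in _preempt and assuming the PreemptionMode enum above is in scope:

# Sketch only: mirrors the selection rule added later in _preempt().
def choose_mode(seq_group) -> PreemptionMode:
    # Recompute is the cheap path for a single sequence; groups with several
    # sequences (e.g. beam search) would have to fall back to swapping.
    if seq_group.get_max_num_running_seqs() == 1:
        return PreemptionMode.RECOMPUTE
    return PreemptionMode.SWAP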
@@ -101,6 +117,7 @@ class FixedWindowScheduler:
         self.cleaned: List[int] = []
         self.kv_cache = kv_cache
         # Co(gc): We no longer have the swapped space as we are not deciding which to swap
+        self.swapped: List[SequenceGroup] = []
         # bigdl-llm change end

     def add_seq_group(self, seq_group: SequenceGroup) -> None:
@@ -147,8 +164,9 @@ class FixedWindowScheduler:
         num_curr_seqs = sum(seq_group.get_max_num_running_seqs()
                             for seq_group in self.running)
         num_batched_tokens = 0
+        # logger.info(f"swap: {self.swapped}, wait: {self.waiting}, run: {self.running}")

-        if self.waiting:
+        if not self.swapped:
             # We restrict how many requests can be run using these three arguments
             # Co(gc): If there are waiting requests, we will just try to add them into the
             # running state if it does not exceed the stage
@@ -209,6 +227,7 @@ class FixedWindowScheduler:
                 num_curr_seqs += num_new_seqs
                 scheduled.append(seq_group)

+            if scheduled or ignored_seq_groups:
                 scheduler_outputs = SchedulerOutputs(
                     scheduled_seq_groups=scheduled,
                     prompt_run=True,
|
@ -221,6 +240,57 @@ class FixedWindowScheduler:
|
||||||
# Now consider all the requests in decoding stage
|
# Now consider all the requests in decoding stage
|
||||||
self.running = self.policy.sort_by_priority(now, self.running)
|
self.running = self.policy.sort_by_priority(now, self.running)
|
||||||
|
|
||||||
|
# Reserve new token slots for the running sequence groups.
|
||||||
|
running: List[SequenceGroup] = []
|
||||||
|
preempted: List[SequenceGroup] = []
|
||||||
|
while self.running:
|
||||||
|
seq_group = self.running.pop(0)
|
||||||
|
# while self.seq_ability < 0:
|
||||||
|
# if self.running:
|
||||||
|
# # Preempt the lowest-priority sequence groups.
|
||||||
|
# victim_seq_group = self.running.pop(-1)
|
||||||
|
# self._preempt(victim_seq_group)
|
||||||
|
# preempted.append(victim_seq_group)
|
||||||
|
# else:
|
||||||
|
# # No other sequence groups can be preempted.
|
||||||
|
# # Preempt the current sequence group.
|
||||||
|
# self._preempt(seq_group)
|
||||||
|
# preempted.append(seq_group)
|
||||||
|
# break
|
||||||
|
# else:
|
||||||
|
# # Append new slots to the sequence group.
|
||||||
|
# # self._append_slot(seq_group, blocks_to_copy)
|
||||||
|
running.append(seq_group)
|
||||||
|
self.running = running
|
||||||
|
|
||||||
|
# TODO (txy): inplement below methods
|
||||||
|
# # Swap in the sequence groups in the SWAPPED state if possible.
|
||||||
|
# self.swapped = self.policy.sort_by_priority(now, self.swapped)
|
||||||
|
# if not preempted:
|
||||||
|
# num_curr_seqs = sum(seq_group.get_max_num_running_seqs()
|
||||||
|
# for seq_group in self.running)
|
||||||
|
|
||||||
|
# while self.swapped:
|
||||||
|
# seq_group = self.swapped[0]
|
||||||
|
# # If the sequence group cannot be swapped in, stop.
|
||||||
|
# # if not self.block_manager.can_swap_in(seq_group):
|
||||||
|
# # break
|
||||||
|
# if self.seq_ability <= 0:
|
||||||
|
# break
|
||||||
|
|
||||||
|
# # The total number of sequences in the RUNNING state should not
|
||||||
|
# # exceed the maximum number of sequences.
|
||||||
|
# num_new_seqs = seq_group.get_max_num_running_seqs()
|
||||||
|
# if (num_curr_seqs + num_new_seqs >
|
||||||
|
# self.scheduler_config.max_num_seqs):
|
||||||
|
# break
|
||||||
|
|
||||||
|
# seq_group = self.swapped.pop(0)
|
||||||
|
# # self._swap_in(seq_group, blocks_to_swap_in)
|
||||||
|
# # self._append_slot(seq_group, blocks_to_copy)
|
||||||
|
# num_curr_seqs += num_new_seqs
|
||||||
|
# self.running.append(seq_group)
|
||||||
|
|
||||||
# Each sequence in the generation phase only takes one token slot.
|
# Each sequence in the generation phase only takes one token slot.
|
||||||
# Therefore, the number of batched tokens is equal to the number of
|
# Therefore, the number of batched tokens is equal to the number of
|
||||||
# sequences in the RUNNING state.
|
# sequences in the RUNNING state.
|
||||||
|
|
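The commented-out block above marks where preemption is meant to hook into the decode phase: when capacity runs out, the lowest-priority running group is preempted first, and the current group is preempted only if nothing else can be evicted. The toy function below renders that policy in self-contained form; the per-group block counts, the free_blocks budget, and the "one new block per decode step" rule are illustrative assumptions standing in for the undefined self.seq_ability, which the diff leaves as a TODO.

from typing import Dict, List, Tuple

def schedule_decode(
    running: List[str],
    blocks: Dict[str, int],
    free_blocks: int,
) -> Tuple[List[str], List[str]]:
    # Toy rendering of the commented-out loop: each running group needs one
    # more block to decode its next token; when none are free, the back-most
    # (lowest-priority) group is preempted and its blocks are reclaimed.
    kept: List[str] = []
    preempted: List[str] = []
    while running:
        group = running.pop(0)
        while free_blocks < 1:
            if running:
                # Preempt the lowest-priority (back-most) group first.
                victim = running.pop(-1)
                preempted.append(victim)
                free_blocks += blocks.pop(victim)   # recompute: its blocks are freed
            else:
                # Nothing left to evict: preempt the current group itself.
                preempted.append(group)
                free_blocks += blocks.pop(group)
                group = None
                break
        if group is not None:
            free_blocks -= 1
            blocks[group] += 1
            kept.append(group)
    return kept, preempted

# Example: three groups holding 2 blocks each, with only 1 block still free.
# schedule_decode(["a", "b", "c"], {"a": 2, "b": 2, "c": 2}, free_blocks=1)
# -> keeps "a" and "b", preempts "c" (its 2 blocks are reclaimed for "b").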
@@ -280,3 +350,85 @@ class FixedWindowScheduler:
             seq_group for seq_group in self.running
             if not seq_group.is_finished()
         ]
+
+    def _preempt(
+        self,
+        seq_group: SequenceGroup,
+        blocks_to_swap_out: Optional[Dict[int, int]] = None,
+        preemption_mode: Optional[PreemptionMode] = None,
+    ) -> None:
+        # If preemption mode is not specified, we determine the mode as follows:
+        # We use recomputation by default since it incurs lower overhead than
+        # swapping. However, when the sequence group has multiple sequences
+        # (e.g., beam search), recomputation is not currently supported. In
+        # such a case, we use swapping instead.
+        # FIXME(woosuk): This makes our scheduling policy a bit bizarre.
+        # As swapped sequences are prioritized over waiting sequences,
+        # sequence groups with multiple sequences are implicitly prioritized
+        # over sequence groups with a single sequence.
+        # TODO(woosuk): Support recomputation for sequence groups with multiple
+        # sequences. This may require a more sophisticated CUDA kernel.
+        if preemption_mode is None:
+            if seq_group.get_max_num_running_seqs() == 1:
+                preemption_mode = PreemptionMode.RECOMPUTE
+            else:
+                preemption_mode = PreemptionMode.SWAP
+        if preemption_mode == PreemptionMode.RECOMPUTE:
+            self._preempt_by_recompute(seq_group)
+        elif preemption_mode == PreemptionMode.SWAP:
+            self._preempt_by_swap(seq_group, blocks_to_swap_out)
+        else:
+            raise AssertionError("Invalid preemption mode.")  # noqa
+
+    def _preempt_by_recompute(
+        self,
+        seq_group: SequenceGroup,
+    ) -> None:
+        seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
+        # len(seqs) should be 1
+        for seq in seqs:
+            seq.status = SequenceStatus.WAITING
+            # self.block_manager.free(seq)
+            if self.kv_cache[0][0].get(seq.seq_id) is not None:
+                for i in range(len(self.kv_cache)):
+                    for j in range(2):
+                        del self.kv_cache[i][j][seq.seq_id]
+
+        # NOTE: For FCFS, we insert the preempted sequence group to the front
+        # of the waiting queue.
+        self.waiting.insert(0, seq_group)
+
+    # TODO (txy): implement the methods below
+    def _preempt_by_swap(
+        self,
+        seq_group: SequenceGroup,
+        blocks_to_swap_out: Dict[int, int],
+    ) -> None:
+        self._swap_out(seq_group, blocks_to_swap_out)
+        self.swapped.append(seq_group)
+
+    def _swap_in(
+        self,
+        seq_group: SequenceGroup,
+        blocks_to_swap_in: Dict[int, int],
+    ) -> None:
+        mapping = self.block_manager.swap_in(seq_group)
+        blocks_to_swap_in.update(mapping)
+        for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED):
+            seq.status = SequenceStatus.RUNNING
+
+    def _swap_out(
+        self,
+        seq_group: SequenceGroup,
+        blocks_to_swap_out: Dict[int, int],
+    ) -> None:
+        if not self.block_manager.can_swap_out(seq_group):
+            # FIXME(woosuk): Abort the sequence group instead of aborting the
+            # entire engine.
+            raise RuntimeError(  # noqa
+                "Aborted due to the lack of CPU swap space. Please increase "
+                "the swap space to avoid this error.")
+        mapping = self.block_manager.swap_out(seq_group)
+        blocks_to_swap_out.update(mapping)
+        for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
+            seq.status = SequenceStatus.SWAPPED
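The recompute path above also clears the preempted sequence's entries from self.kv_cache, indexing it as kv_cache[layer][0 or 1][seq_id]. That suggests, although the diff does not state it, a layout of one [key_dict, value_dict] pair per layer keyed by sequence id. A small sketch of the same cleanup under that assumed layout:

from typing import Any, Dict, List

def drop_seq_from_kv_cache(kv_cache: List[List[Dict[int, Any]]], seq_id: int) -> None:
    # Remove one sequence's cached keys and values from every layer.
    for layer in kv_cache:              # layer == [key_dict, value_dict], assumed
        for kv_dict in layer:
            kv_dict.pop(seq_id, None)   # tolerate sequences that were never cached

# Example with two layers and two cached sequences:
cache = [[{1: "k1", 2: "k2"}, {1: "v1", 2: "v2"}] for _ in range(2)]
drop_seq_from_kv_cache(cache, seq_id=2)
# cache is now [[{1: "k1"}, {1: "v1"}], [{1: "k1"}, {1: "v1"}]]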
@@ -172,6 +172,7 @@ class BigDLLlamaForCausalLM(BigDLModelForCausalLM):
             # "return_dict": True,
         }
         # pdb.set_trace()
+
         if self.device.type == 'xpu':
             torch.xpu.empty_cache()
         st_timestamp = time.perf_counter()
@@ -21,6 +21,10 @@ from transformers import LlamaConfig

 from bigdl.llm.vllm.sequence import SequenceOutputs, SequenceGroupMetadata
 from bigdl.llm.transformers.models.utils import extend_kv_cache
+from bigdl.llm.vllm.logger import init_logger
+
+logger = init_logger(__name__)
+

 zero_cache_dict = {}
