[LLM] vLLM: Add Preempt for scheduler (#9568)

Implement the _preempt_by_recompute method for vLLM.
Xiangyu Tian 2023-12-03 20:16:25 +08:00 committed by GitHub
parent f7e596d85a
commit 5c03651309
3 changed files with 166 additions and 9 deletions


@@ -34,6 +34,7 @@
# bigdl-llm Intel specified code change
#
import enum
import time
from typing import Dict, Iterable, List, Optional, Tuple, Union
@@ -48,6 +49,21 @@ from bigdl.llm.utils.common import invalidInputError
logger = init_logger(__name__)
class PreemptionMode(enum.Enum):
    """Preemption modes.
    1. Swapping: Swap out the blocks of the preempted sequences to CPU memory
    and swap them back in when the sequences are resumed.
    2. Recomputation: Discard the blocks of the preempted sequences and
    recompute them when the sequences are resumed, treating the sequences as
    new prompts.
    bigdl: currently only supports RECOMPUTE
    """
    SWAP = enum.auto()
    RECOMPUTE = enum.auto()
class SchedulerOutputs:

    def __init__(
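To make the trade-off between the two modes concrete, here is a minimal standalone sketch (toy names and dict-based caches, not part of this diff) of the bookkeeping each mode implies:

import enum
from typing import Dict, List

class ToyPreemptionMode(enum.Enum):
    SWAP = enum.auto()
    RECOMPUTE = enum.auto()

def toy_preempt(seq_id: int, mode: ToyPreemptionMode,
                gpu_cache: Dict[int, list], cpu_cache: Dict[int, list],
                waiting: List[int]) -> None:
    if mode == ToyPreemptionMode.RECOMPUTE:
        # Drop the cached blocks; the sequence rejoins the waiting
        # queue and is re-prefilled later like a fresh prompt.
        gpu_cache.pop(seq_id, None)
        waiting.insert(0, seq_id)
    else:
        # Keep the blocks, but park them in CPU memory; resuming the
        # sequence swaps them back in instead of recomputing.
        cpu_cache[seq_id] = gpu_cache.pop(seq_id)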
@@ -101,6 +117,7 @@ class FixedWindowScheduler:
        self.cleaned: List[int] = []
        self.kv_cache = kv_cache
        # Co(gc): We no longer have the swapped space as we are not deciding which to swap
        self.swapped: List[SequenceGroup] = []
        # bigdl-llm change end
    def add_seq_group(self, seq_group: SequenceGroup) -> None:
@@ -147,8 +164,9 @@ class FixedWindowScheduler:
        num_curr_seqs = sum(seq_group.get_max_num_running_seqs()
                            for seq_group in self.running)
        num_batched_tokens = 0
        # logger.info(f"swap: {self.swapped}, wait: {self.waiting}, run: {self.running}")
        if not self.swapped:
            # We restrict how many requests can be run using these three arguments
            # Co(gc): If there are waiting requests, we will just try to add them into
            # the running state if they do not exceed the limits
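The admission check behind "these three arguments" is not visible in this hunk; below is a hedged sketch of the kind of test the scheduler applies (max_num_seqs follows the scheduler_config field that the commented-out swap-in loop later in this file also references):

def can_admit(seq_group, num_curr_seqs: int, max_num_seqs: int) -> bool:
    # The total number of sequences in the RUNNING state should not
    # exceed the configured maximum (mirrors the commented-out check
    # in the swap-in loop further down).
    num_new_seqs = seq_group.get_max_num_running_seqs()
    return num_curr_seqs + num_new_seqs <= max_num_seqs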
@@ -209,18 +227,70 @@ class FixedWindowScheduler:
                num_curr_seqs += num_new_seqs
                scheduled.append(seq_group)

            if scheduled or ignored_seq_groups:
                scheduler_outputs = SchedulerOutputs(
                    scheduled_seq_groups=scheduled,
                    prompt_run=True,
                    num_batched_tokens=len(seq_lens) * max(seq_lens) if seq_lens else 0,
                    ignored_seq_groups=ignored_seq_groups,
                    finished_seqs=finished_seqs,
                )
                return scheduler_outputs
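Note that num_batched_tokens counts the padded batch, not the actual prompt tokens: every prompt is charged at the longest prompt length in the batch. A quick worked example:

seq_lens = [5, 9, 7]
num_batched_tokens = len(seq_lens) * max(seq_lens) if seq_lens else 0
assert num_batched_tokens == 27  # 3 prompts, each padded to 9 tokens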
        # Now consider all the requests in the decoding stage
        self.running = self.policy.sort_by_priority(now, self.running)
        # Reserve new token slots for the running sequence groups.
        running: List[SequenceGroup] = []
        preempted: List[SequenceGroup] = []
        while self.running:
            seq_group = self.running.pop(0)
            # while self.seq_ability < 0:
            #     if self.running:
            #         # Preempt the lowest-priority sequence groups.
            #         victim_seq_group = self.running.pop(-1)
            #         self._preempt(victim_seq_group)
            #         preempted.append(victim_seq_group)
            #     else:
            #         # No other sequence groups can be preempted.
            #         # Preempt the current sequence group.
            #         self._preempt(seq_group)
            #         preempted.append(seq_group)
            #         break
            # else:
            #     # Append new slots to the sequence group.
            #     # self._append_slot(seq_group, blocks_to_copy)
            running.append(seq_group)
        self.running = running
        # TODO (txy): implement the methods below
        # # Swap in the sequence groups in the SWAPPED state if possible.
        # self.swapped = self.policy.sort_by_priority(now, self.swapped)
        # if not preempted:
        #     num_curr_seqs = sum(seq_group.get_max_num_running_seqs()
        #                         for seq_group in self.running)
        #     while self.swapped:
        #         seq_group = self.swapped[0]
        #         # If the sequence group cannot be swapped in, stop.
        #         # if not self.block_manager.can_swap_in(seq_group):
        #         #     break
        #         if self.seq_ability <= 0:
        #             break
        #         # The total number of sequences in the RUNNING state should not
        #         # exceed the maximum number of sequences.
        #         num_new_seqs = seq_group.get_max_num_running_seqs()
        #         if (num_curr_seqs + num_new_seqs >
        #                 self.scheduler_config.max_num_seqs):
        #             break
        #         seq_group = self.swapped.pop(0)
        #         # self._swap_in(seq_group, blocks_to_swap_in)
        #         # self._append_slot(seq_group, blocks_to_copy)
        #         num_curr_seqs += num_new_seqs
        #         self.running.append(seq_group)
        # Each sequence in the generation phase only takes one token slot.
        # Therefore, the number of batched tokens is equal to the number of
        # sequences in the RUNNING state.
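In other words, during decoding the token count reduces to a sequence count. A sketch of the computation this comment describes (num_seqs is the SequenceGroup accessor in vLLM; the exact lines are elided from the hunk, so treat this as an assumption):

        num_batched_tokens = sum(
            seq_group.num_seqs(status=SequenceStatus.RUNNING)
            for seq_group in self.running)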
@@ -280,3 +350,85 @@ class FixedWindowScheduler:
            seq_group for seq_group in self.running
            if not seq_group.is_finished()
        ]
    def _preempt(
        self,
        seq_group: SequenceGroup,
        blocks_to_swap_out: Optional[Dict[int, int]] = None,
        preemption_mode: Optional[PreemptionMode] = None,
    ) -> None:
        # If preemption mode is not specified, we determine the mode as follows:
        # We use recomputation by default since it incurs lower overhead than
        # swapping. However, when the sequence group has multiple sequences
        # (e.g., beam search), recomputation is not currently supported. In
        # such a case, we use swapping instead.
        # FIXME(woosuk): This makes our scheduling policy a bit bizarre.
        # As swapped sequences are prioritized over waiting sequences,
        # sequence groups with multiple sequences are implicitly prioritized
        # over sequence groups with a single sequence.
        # TODO(woosuk): Support recomputation for sequence groups with multiple
        # sequences. This may require a more sophisticated CUDA kernel.
        if preemption_mode is None:
            if seq_group.get_max_num_running_seqs() == 1:
                preemption_mode = PreemptionMode.RECOMPUTE
            else:
                preemption_mode = PreemptionMode.SWAP
        if preemption_mode == PreemptionMode.RECOMPUTE:
            self._preempt_by_recompute(seq_group)
        elif preemption_mode == PreemptionMode.SWAP:
            self._preempt_by_swap(seq_group, blocks_to_swap_out)
        else:
            raise AssertionError("Invalid preemption mode.")  # noqa
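For reference, the disabled loop earlier in this diff would drive _preempt roughly as follows (a sketch only, since that loop is still commented out):

            # Evict the lowest-priority running group when capacity runs
            # out; single-sequence groups take the RECOMPUTE path above.
            victim_seq_group = self.running.pop(-1)
            self._preempt(victim_seq_group)
            preempted.append(victim_seq_group)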
    def _preempt_by_recompute(
        self,
        seq_group: SequenceGroup,
    ) -> None:
        seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
        # len(seqs) should be 1
        for seq in seqs:
            seq.status = SequenceStatus.WAITING
            # self.block_manager.free(seq)
            if self.kv_cache[0][0].get(seq.seq_id) is not None:
                for i in range(len(self.kv_cache)):
                    for j in range(2):
                        del self.kv_cache[i][j][seq.seq_id]
        # NOTE: For FCFS, we insert the preempted sequence group to the front
        # of the waiting queue.
        self.waiting.insert(0, seq_group)
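The deletion loop implies a cache laid out as kv_cache[layer][0|1][seq_id], with key tensors at index 0 and value tensors at index 1 (an assumption read off the indices, not stated in the diff). A toy illustration of what _preempt_by_recompute frees:

num_layers = 2
seq_id = 42
# Assumed layout: one [key_dict, value_dict] pair per layer,
# each dict keyed by sequence id.
kv_cache = [[{seq_id: "k%d" % i}, {seq_id: "v%d" % i}]
            for i in range(num_layers)]
if kv_cache[0][0].get(seq_id) is not None:
    for i in range(len(kv_cache)):
        for j in range(2):
            del kv_cache[i][j][seq_id]
assert all(not d for layer in kv_cache for d in layer)  # all entries freed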
    # TODO (txy): implement the methods below
    def _preempt_by_swap(
        self,
        seq_group: SequenceGroup,
        blocks_to_swap_out: Dict[int, int],
    ) -> None:
        self._swap_out(seq_group, blocks_to_swap_out)
        self.swapped.append(seq_group)
    def _swap_in(
        self,
        seq_group: SequenceGroup,
        blocks_to_swap_in: Dict[int, int],
    ) -> None:
        mapping = self.block_manager.swap_in(seq_group)
        blocks_to_swap_in.update(mapping)
        for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED):
            seq.status = SequenceStatus.RUNNING
    def _swap_out(
        self,
        seq_group: SequenceGroup,
        blocks_to_swap_out: Dict[int, int],
    ) -> None:
        if not self.block_manager.can_swap_out(seq_group):
            # FIXME(woosuk): Abort the sequence group instead of aborting the
            # entire engine.
            raise RuntimeError(  # noqa
                "Aborted due to the lack of CPU swap space. Please increase "
                "the swap space to avoid this error.")
        mapping = self.block_manager.swap_out(seq_group)
        blocks_to_swap_out.update(mapping)
        for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
            seq.status = SequenceStatus.SWAPPED
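The swap helpers fill blocks_to_swap_out in place as a GPU-block to CPU-block mapping. Since block_manager is still a TODO in this port, the sketch below only illustrates the dict contract, with made-up block numbers:

from typing import Dict

blocks_to_swap_out: Dict[int, int] = {}
mapping = {3: 0, 7: 1}  # hypothetical result of block_manager.swap_out()
blocks_to_swap_out.update(mapping)
# The cache engine later copies GPU block 3 -> CPU block 0 and 7 -> 1.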


@@ -172,6 +172,7 @@ class BigDLLlamaForCausalLM(BigDLModelForCausalLM):
# "return_dict": True, # "return_dict": True,
} }
# pdb.set_trace() # pdb.set_trace()
if self.device.type == 'xpu': if self.device.type == 'xpu':
torch.xpu.empty_cache() torch.xpu.empty_cache()
st_timestamp = time.perf_counter() st_timestamp = time.perf_counter()


@@ -21,6 +21,10 @@ from transformers import LlamaConfig
from bigdl.llm.vllm.sequence import SequenceOutputs, SequenceGroupMetadata
from bigdl.llm.transformers.models.utils import extend_kv_cache
from bigdl.llm.vllm.logger import init_logger

logger = init_logger(__name__)

zero_cache_dict = {}