diff --git a/python/llm/src/bigdl/llm/vllm/core/scheduler.py b/python/llm/src/bigdl/llm/vllm/core/scheduler.py
index e5f51b94..b41ea166 100644
--- a/python/llm/src/bigdl/llm/vllm/core/scheduler.py
+++ b/python/llm/src/bigdl/llm/vllm/core/scheduler.py
@@ -34,6 +34,7 @@
 # bigdl-llm Intel specified code change
 #
+import enum
 import time
 from typing import Dict, Iterable, List, Optional, Tuple, Union
@@ -48,6 +49,21 @@ from bigdl.llm.utils.common import invalidInputError
 logger = init_logger(__name__)
+class PreemptionMode(enum.Enum):
+    """Preemption modes.
+
+    1. Swapping: Swap out the blocks of the preempted sequences to CPU memory
+    and swap them back in when the sequences are resumed.
+    2. Recomputation: Discard the blocks of the preempted sequences and
+    recompute them when the sequences are resumed, treating the sequences as
+    new prompts.
+
+    bigdl: currently only supports RECOMPUTE
+    """
+    SWAP = enum.auto()
+    RECOMPUTE = enum.auto()
+
+
 class SchedulerOutputs:
     def __init__(
@@ -101,6 +117,7 @@ class FixedWindowScheduler:
         self.cleaned: List[int] = []
         self.kv_cache = kv_cache
         # Co(gc): We no longer have the swapped space as we are not deciding which to swap
+        self.swapped: List[SequenceGroup] = []
         # bigdl-llm change end
 
     def add_seq_group(self, seq_group: SequenceGroup) -> None:
@@ -147,8 +164,9 @@ class FixedWindowScheduler:
         num_curr_seqs = sum(seq_group.get_max_num_running_seqs()
                             for seq_group in self.running)
         num_batched_tokens = 0
+        # logger.info(f"swap: {self.swapped}, wait: {self.waiting}, run: {self.running}")
-        if self.waiting:
+        if not self.swapped:
             # We restrict how many requests that can be run using these three arguments
             # Co(gc): If there are waiting requests, we will just try to add it into the
             # running state if not exceeds the stage
@@ -209,18 +227,70 @@ class FixedWindowScheduler:
                 num_curr_seqs += num_new_seqs
                 scheduled.append(seq_group)
-            scheduler_outputs = SchedulerOutputs(
-                scheduled_seq_groups=scheduled,
-                prompt_run=True,
-                num_batched_tokens=len(seq_lens) * max(seq_lens) if seq_lens else 0,
-                ignored_seq_groups=ignored_seq_groups,
-                finished_seqs=finished_seqs,
-            )
-            return scheduler_outputs
+            if scheduled or ignored_seq_groups:
+                scheduler_outputs = SchedulerOutputs(
+                    scheduled_seq_groups=scheduled,
+                    prompt_run=True,
+                    num_batched_tokens=len(seq_lens) * max(seq_lens) if seq_lens else 0,
+                    ignored_seq_groups=ignored_seq_groups,
+                    finished_seqs=finished_seqs,
+                )
+                return scheduler_outputs
         # Now consider all the requests in decoding stage
         self.running = self.policy.sort_by_priority(now, self.running)
+        # Reserve new token slots for the running sequence groups.
+        running: List[SequenceGroup] = []
+        preempted: List[SequenceGroup] = []
+        while self.running:
+            seq_group = self.running.pop(0)
+            # while self.seq_ability < 0:
+            #     if self.running:
+            #         # Preempt the lowest-priority sequence groups.
+            #         victim_seq_group = self.running.pop(-1)
+            #         self._preempt(victim_seq_group)
+            #         preempted.append(victim_seq_group)
+            #     else:
+            #         # No other sequence groups can be preempted.
+            #         # Preempt the current sequence group.
+            #         self._preempt(seq_group)
+            #         preempted.append(seq_group)
+            #         break
+            # else:
+            #     # Append new slots to the sequence group.
+            #     # self._append_slot(seq_group, blocks_to_copy)
+            running.append(seq_group)
+        self.running = running
+
+        # TODO (txy): implement below methods
+        # # Swap in the sequence groups in the SWAPPED state if possible.
+        # self.swapped = self.policy.sort_by_priority(now, self.swapped)
+        # if not preempted:
+        #     num_curr_seqs = sum(seq_group.get_max_num_running_seqs()
+        #                         for seq_group in self.running)
+
+        #     while self.swapped:
+        #         seq_group = self.swapped[0]
+        #         # If the sequence group cannot be swapped in, stop.
+        #         # if not self.block_manager.can_swap_in(seq_group):
+        #         #     break
+        #         if self.seq_ability <= 0:
+        #             break
+
+        #         # The total number of sequences in the RUNNING state should not
+        #         # exceed the maximum number of sequences.
+        #         num_new_seqs = seq_group.get_max_num_running_seqs()
+        #         if (num_curr_seqs + num_new_seqs >
+        #                 self.scheduler_config.max_num_seqs):
+        #             break
+
+        #         seq_group = self.swapped.pop(0)
+        #         # self._swap_in(seq_group, blocks_to_swap_in)
+        #         # self._append_slot(seq_group, blocks_to_copy)
+        #         num_curr_seqs += num_new_seqs
+        #         self.running.append(seq_group)
         # Each sequence in the generation phase only takes one token slot.
         # Therefore, the number of batched tokens is equal to the number of
         # sequences in the RUNNING state.
@@ -280,3 +350,85 @@ class FixedWindowScheduler:
             seq_group for seq_group in self.running
             if not seq_group.is_finished()
         ]
+
+    def _preempt(
+        self,
+        seq_group: SequenceGroup,
+        blocks_to_swap_out: Optional[Dict[int, int]]=None,
+        preemption_mode: Optional[PreemptionMode]=None,
+    ) -> None:
+        # If preemption mode is not specified, we determine the mode as follows:
+        # We use recomputation by default since it incurs lower overhead than
+        # swapping. However, when the sequence group has multiple sequences
+        # (e.g., beam search), recomputation is not currently supported. In
+        # such a case, we use swapping instead.
+        # FIXME(woosuk): This makes our scheduling policy a bit bizarre.
+        # As swapped sequences are prioritized over waiting sequences,
+        # sequence groups with multiple sequences are implicitly prioritized
+        # over sequence groups with a single sequence.
+        # TODO(woosuk): Support recomputation for sequence groups with multiple
+        # sequences. This may require a more sophisticated CUDA kernel.
+        if preemption_mode is None:
+            if seq_group.get_max_num_running_seqs() == 1:
+                preemption_mode = PreemptionMode.RECOMPUTE
+            else:
+                preemption_mode = PreemptionMode.SWAP
+        if preemption_mode == PreemptionMode.RECOMPUTE:
+            self._preempt_by_recompute(seq_group)
+        elif preemption_mode == PreemptionMode.SWAP:
+            self._preempt_by_swap(seq_group, blocks_to_swap_out)
+        else:
+            raise AssertionError("Invalid preemption mode.")  # noqa
+
+    def _preempt_by_recompute(
+        self,
+        seq_group: SequenceGroup,
+    ) -> None:
+        seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
+        # len(seqs) should be 1
+        for seq in seqs:
+            seq.status = SequenceStatus.WAITING
+            # self.block_manager.free(seq)
+            if self.kv_cache[0][0].get(seq.seq_id) is not None:
+                for i in range(len(self.kv_cache)):
+                    for j in range(2):
+                        del self.kv_cache[i][j][seq.seq_id]
+
+        # NOTE: For FCFS, we insert the preempted sequence group to the front
+        # of the waiting queue.
+        self.waiting.insert(0, seq_group)
+
+    # TODO (txy): implement below methods
+    def _preempt_by_swap(
+        self,
+        seq_group: SequenceGroup,
+        blocks_to_swap_out: Dict[int, int],
+    ) -> None:
+        self._swap_out(seq_group, blocks_to_swap_out)
+        self.swapped.append(seq_group)
+
+    def _swap_in(
+        self,
+        seq_group: SequenceGroup,
+        blocks_to_swap_in: Dict[int, int],
+    ) -> None:
+        mapping = self.block_manager.swap_in(seq_group)
+        blocks_to_swap_in.update(mapping)
+        for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED):
+            seq.status = SequenceStatus.RUNNING
+
+    def _swap_out(
+        self,
+        seq_group: SequenceGroup,
+        blocks_to_swap_out: Dict[int, int],
+    ) -> None:
+        if not self.block_manager.can_swap_out(seq_group):
+            # FIXME(woosuk): Abort the sequence group instead of aborting the
+            # entire engine.
+            raise RuntimeError(  # noqa
+                "Aborted due to the lack of CPU swap space. Please increase "
+                "the swap space to avoid this error.")
+        mapping = self.block_manager.swap_out(seq_group)
+        blocks_to_swap_out.update(mapping)
+        for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
+            seq.status = SequenceStatus.SWAPPED
diff --git a/python/llm/src/bigdl/llm/vllm/model_executor/models/bigdl_llama.py b/python/llm/src/bigdl/llm/vllm/model_executor/models/bigdl_llama.py
index b53bfc8f..6268d96d 100644
--- a/python/llm/src/bigdl/llm/vllm/model_executor/models/bigdl_llama.py
+++ b/python/llm/src/bigdl/llm/vllm/model_executor/models/bigdl_llama.py
@@ -172,6 +172,7 @@
             # "return_dict": True,
         }
         # pdb.set_trace()
+
         if self.device.type == 'xpu':
             torch.xpu.empty_cache()
         st_timestamp = time.perf_counter()
diff --git a/python/llm/src/bigdl/llm/vllm/model_executor/models/bigdl_model.py b/python/llm/src/bigdl/llm/vllm/model_executor/models/bigdl_model.py
index 6ab9e109..88e94728 100644
--- a/python/llm/src/bigdl/llm/vllm/model_executor/models/bigdl_model.py
+++ b/python/llm/src/bigdl/llm/vllm/model_executor/models/bigdl_model.py
@@ -21,6 +21,10 @@
 from transformers import LlamaConfig
 from bigdl.llm.vllm.sequence import SequenceOutputs, SequenceGroupMetadata
 from bigdl.llm.transformers.models.utils import extend_kv_cache
+from bigdl.llm.vllm.logger import init_logger
+
+logger = init_logger(__name__)
+
 zero_cache_dict = {}
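
Note: the sketch below is not part of the patch. It is a minimal, self-contained illustration of the recompute-style preemption that `_preempt_by_recompute` performs, using hypothetical stand-in types (`ToySequence`, `ToyScheduler`) rather than the real bigdl-llm/vLLM classes: when capacity runs out, the victim sequence is reset to WAITING, its per-layer key/value cache entries are dropped, and it is re-inserted at the front of the waiting queue so a later step re-prefills it as a new prompt.

# Illustrative only; not the bigdl-llm / vLLM implementation.
from enum import Enum, auto
from typing import Dict, List


class Status(Enum):
    WAITING = auto()
    RUNNING = auto()


class ToySequence:
    def __init__(self, seq_id: int) -> None:
        self.seq_id = seq_id
        self.status = Status.RUNNING


class ToyScheduler:
    def __init__(self, num_layers: int) -> None:
        self.waiting: List[ToySequence] = []
        self.running: List[ToySequence] = []
        # One (key_cache, value_cache) pair per layer, keyed by seq_id.
        self.kv_cache: List[List[Dict[int, object]]] = [
            [{}, {}] for _ in range(num_layers)
        ]

    def preempt_by_recompute(self, seq: ToySequence) -> None:
        # Reset the sequence so it will be re-prefilled as a new prompt.
        seq.status = Status.WAITING
        # Drop its KV-cache entries in every layer (both keys and values).
        for layer in self.kv_cache:
            for cache in layer:
                cache.pop(seq.seq_id, None)
        # Re-insert at the front of the waiting queue (FCFS keeps its priority).
        self.running.remove(seq)
        self.waiting.insert(0, seq)


scheduler = ToyScheduler(num_layers=2)
victim = ToySequence(seq_id=7)
scheduler.running.append(victim)
scheduler.kv_cache[0][0][7] = "k0"
scheduler.kv_cache[0][1][7] = "v0"

scheduler.preempt_by_recompute(victim)
assert victim.status is Status.WAITING
assert scheduler.waiting[0] is victim
assert 7 not in scheduler.kv_cache[0][0]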