Add vLLM-XPU version's README/examples (#9536)

* test

* test

* fix last kv cache

* add xpu readme

* remove numactl for xpu example

* fix link error

* update max_num_batched_tokens logic

* add explaination

* add xpu environement version requirement

* refine gpu memory

* fix

* fix style
This commit is contained in:
Guancheng Fu 2023-11-28 09:44:03 +08:00 committed by GitHub
parent b6c3520748
commit 963a5c8d79
11 changed files with 287 additions and 60 deletions

View file

@ -31,7 +31,7 @@ pip3 install "pydantic<2" # Required for OpenAI server.
### 2. Configure recommended environment variables
```bash
source bigdl-llm-init
source bigdl-llm-init -t
```
### 3. Offline inference/Service
@ -55,9 +55,12 @@ To fully utilize the continuous batching feature of the `vLLM`, you can send req
```bash
#!/bin/bash
numactl -C 48-95 -m 1 python -m bigdl.llm.vllm.examples.api_server \
# You may also want to adjust the `--max-num-batched-tokens` argument, it indicates the hard limit
# of batched prompt length the server will accept
numactl -C 48-95 -m 1 python -m bigdl.llm.vllm.entrypoints.openai.api_server \
--model /MODEL_PATH/Llama-2-7b-chat-hf-bigdl/ --port 8000 \
--load-format 'auto' --device cpu --dtype bfloat16
--load-format 'auto' --device cpu --dtype bfloat16 \
--max-num-batched-tokens 4096
```
Then you can access the api server as follows:
@ -80,12 +83,12 @@ Currently we have only supported LLaMA family model (including `llama`, `vicuna`
#### 4.1 Add model code
Create or clone the Pytorch model code to `./models`.
Create or clone the Pytorch model code to `BigDL/python/llm/src/bigdl/llm/vllm/model_executor/models`.
#### 4.2 Rewrite the forward methods
Refering to `./models/bigdl_llama.py`, it's necessary to maintain a `kv_cache`, which is a nested list of dictionary that maps `req_id` to a three-dimensional tensor **(the structure may vary from models)**. Before the model's actual `forward` method, you could prepare a `past_key_values` according to current `req_id`, and after that you need to update the `kv_cache` with `output.past_key_values`. The clearence will be executed when the request is finished.
Refering to `BigDL/python/llm/src/bigdl/llm/vllm/model_executor/models/bigdl_llama.py`, it's necessary to maintain a `kv_cache`, which is a nested list of dictionary that maps `req_id` to a three-dimensional tensor **(the structure may vary from models)**. Before the model's actual `forward` method, you could prepare a `past_key_values` according to current `req_id`, and after that you need to update the `kv_cache` with `output.past_key_values`. The clearence will be executed when the request is finished.
#### 4.3 Register new model
Finally, register your `*ForCausalLM` class to the _MODEL_REGISTRY in `./models/model_loader.py`.
Finally, register your `*ForCausalLM` class to the _MODEL_REGISTRY in `BigDL/python/llm/src/bigdl/llm/vllm/model_executor/model_loader.py`.

View file

@ -14,7 +14,7 @@
# limitations under the License.
#
# Some parts of this file is adapted from
# https://github.com/vllm-project/vllm/blob/main/examples/offline_inference.py
# https://github.com/vllm-project/vllm/blob/v0.2.1.post1/examples/offline_inference.py
# which is licensed under Apache License 2.0
#
# Copyright 2023 The vLLM team. All rights reserved.
@ -31,8 +31,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from bigdl.llm.vllm.examples.llm import LLM
from bigdl.llm.vllm.structure.sampling_params import SamplingParams
from bigdl.llm.vllm.entrypoints.llm import LLM
from bigdl.llm.vllm.sampling_params import SamplingParams
# Sample prompts.
prompts = [

View file

@ -4,6 +4,7 @@ This folder contains examples of running BigDL-LLM on Intel GPU:
- [HF-Transformers-AutoModels](HF-Transformers-AutoModels): running any ***Hugging Face Transformers*** model on BigDL-LLM (using the standard AutoModel APIs)
- [QLoRA-FineTuning](QLoRA-FineTuning): running ***QLoRA finetuning*** using BigDL-LLM on Intel GPUs
- [vLLM-Serving](vLLM-Serving): running ***vLLM*** serving framework on intel GPUs (with BigDL-LLM low-bit optimized models)
- [Deepspeed-AutoTP](Deepspeed-AutoTP): running distributed inference using ***DeepSpeed AutoTP*** (with BigDL-LLM low-bit optimized models) on Intel GPUs
- [PyTorch-Models](PyTorch-Models): running any PyTorch model on BigDL-LLM (with "one-line code change")

View file

@ -0,0 +1,109 @@
# vLLM continuous batching on Intel GPUs (experimental support)
This example demonstrates how to serve a LLaMA2-7B model using vLLM continuous batching on Intel GPU (with BigDL-LLM low-bits optimizations).
The code shown in the following example is ported from [vLLM](https://github.com/vllm-project/vllm/tree/v0.2.1.post1).
## Example: Serving LLaMA2-7B using Intel GPU
In this example, we will run Llama2-7b model using Arc A770 and provide `OpenAI-compatible` interface for users.
### 0. Environment
To use Intel GPUs for deep-learning tasks, you should install the XPU driver and the oneAPI Base Toolkit. Please check the requirements at [here](https://github.com/intel-analytics/BigDL/tree/main/python/llm/example/GPU#requirements).
After install the toolkit, run the following commands in your environment before starting vLLM GPU:
```bash
source /opt/intel/oneapi/setvars.sh
# sycl-ls will list all the compatible Intel GPUs in your environment
sycl-ls
# Example output with one Arc A770:
[opencl:acc:0] Intel(R) FPGA Emulation Platform for OpenCL(TM), Intel(R) FPGA Emulation Device 1.2 [2023.16.7.0.21_160000]
[opencl:cpu:1] Intel(R) OpenCL, 13th Gen Intel(R) Core(TM) i9-13900K 3.0 [2023.16.7.0.21_160000]
[opencl:gpu:2] Intel(R) OpenCL Graphics, Intel(R) Arc(TM) A770 Graphics 3.0 [23.17.26241.33]
[ext_oneapi_level_zero:gpu:0] Intel(R) Level-Zero, Intel(R) Arc(TM) A770 Graphics 1.3 [1.3.26241]
```
### 1. Install
To run vLLM continuous batching on Intel GPUs, install the dependencies as follows:
```bash
# First create an conda environment
conda create -n bigdl-vllm python==3.9
conda activate bigdl-vllm
# Install dependencies
pip3 install psutil
pip3 install sentencepiece # Required for LLaMA tokenizer.
pip3 install numpy
pip3 install "transformers>=4.33.1" # Required for Code Llama.
pip install --pre --upgrade bigdl-llm[xpu] -f https://developer.intel.com/ipex-whl-stable-xpu
pip3 install fastapi
pip3 install "uvicorn[standard]"
pip3 install "pydantic<2" # Required for OpenAI server.
```
### 2. Configure recommended environment variables
```bash
export USE_XETLA=OFF
export SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=1
```
### 3. Offline inference/Service
#### Offline inference
To run offline inference using vLLM for a quick impression, use the following example:
```bash
#!/bin/bash
# Please first modify the MODEL_PATH in offline_inference.py
python offline_inference.py
```
#### Service
To fully utilize the continuous batching feature of the `vLLM`, you can send requests to the service using curl or other similar methods. The requests sent to the engine will be batched at token level. Queries will be executed in the same `forward` step of the LLM and be removed when they are finished instead of waiting for all sequences to be finished.
```bash
#!/bin/bash
# You may also want to adjust the `--max-num-batched-tokens` argument, it indicates the hard limit
# of batched prompt length the server will accept
python -m bigdl.llm.vllm.entrypoints.openai.api_server \
--model /MODEL_PATH/Llama-2-7b-chat-hf/ --port 8000 \
--load-format 'auto' --device xpu --dtype bfloat16 \
--max-num-batched-tokens 4096
```
Then you can access the api server as follows:
```bash
curl http://localhost:8000/v1/completions \
-H "Content-Type: application/json" \
-d '{
"model": "/MODEL_PATH/Llama-2-7b-chat-hf-bigdl/",
"prompt": "San Francisco is a",
"max_tokens": 128,
"temperature": 0
}' &
```
### 4. (Optional) Add a new model
Currently we have only supported LLaMA family model (including `llama`, `vicuna`, `llama-2`, etc.). To use aother model, you may need add some adaptions.
#### 4.1 Add model code
Create or clone the Pytorch model code to `BigDL/python/llm/src/bigdl/llm/vllm/model_executor/models`.
#### 4.2 Rewrite the forward methods
Refering to `BigDL/python/llm/src/bigdl/llm/vllm/model_executor/models/bigdl_llama.py`, it's necessary to maintain a `kv_cache`, which is a nested list of dictionary that maps `req_id` to a three-dimensional tensor **(the structure may vary from models)**. Before the model's actual `forward` method, you could prepare a `past_key_values` according to current `req_id`, and after that you need to update the `kv_cache` with `output.past_key_values`. The clearence will be executed when the request is finished.
#### 4.3 Register new model
Finally, register your `*ForCausalLM` class to the _MODEL_REGISTRY in `BigDL/python/llm/src/bigdl/llm/vllm/model_executor/model_loader.py`.

View file

@ -0,0 +1,57 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Some parts of this file is adapted from
# https://github.com/vllm-project/vllm/blob/v0.2.1.post1/examples/offline_inference.py
# which is licensed under Apache License 2.0
#
# Copyright 2023 The vLLM team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from bigdl.llm.vllm.entrypoints.llm import LLM
from bigdl.llm.vllm.sampling_params import SamplingParams
# Sample prompts.
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
# Create an LLM.
# llm = LLM(model="facebook/opt-125m")
llm = LLM(model="YOUR_MODEL_PATH", dtype="bfloat16", device="xpu")
# Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
outputs = llm.generate(prompts, sampling_params)
# Print the outputs.
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

View file

@ -79,6 +79,7 @@ class FixedWindowScheduler:
def __init__(
self,
scheduler_config: SchedulerConfig,
kv_cache: Optional,
) -> None:
self.scheduler_config = scheduler_config
self.prompt_limit = min(self.scheduler_config.max_model_len,
@ -98,6 +99,7 @@ class FixedWindowScheduler:
# Sequence groups in the RUNNING state.
self.running: List[SequenceGroup] = []
self.cleaned: List[int] = []
self.kv_cache = kv_cache
# Co(gc): We no longer have the swapped space as we are not deciding which to swap
# bigdl-llm change end
@ -150,6 +152,8 @@ class FixedWindowScheduler:
# We restrict how many requests that can be run using these three arguments
# Co(gc): If there are waiting requests, we will just try to add it into the
# running state if not exceeds the stage
# Co(gc): Record seq_len for prefill requests
seq_lens = []
# Co(gc): prefilled requests are prioritized over decoding stage requests
while self.waiting:
seq_group = self.waiting[0]
@ -178,7 +182,9 @@ class FixedWindowScheduler:
# bigdl-llm change end
# If the number of batched tokens exceeds the limit, stop.
if (num_batched_tokens + num_prompt_tokens >
new_seq_lens = seq_lens + [num_prompt_tokens]
num_batched_tokens = len(new_seq_lens) * max(new_seq_lens)
if (num_batched_tokens >
self.scheduler_config.max_num_batched_tokens):
break
@ -192,6 +198,8 @@ class FixedWindowScheduler:
seq_group = self.waiting.pop(0)
for seq in seq_group.get_seqs():
seq.status = SequenceStatus.RUNNING
# Co(gc): Only updated the seq_lens when all check passes
seq_lens = new_seq_lens
# bigdl-llm change start
# summary: removing block_manager related logic.
# self._allocate(seq_group)
@ -204,7 +212,7 @@ class FixedWindowScheduler:
scheduler_outputs = SchedulerOutputs(
scheduled_seq_groups=scheduled,
prompt_run=True,
num_batched_tokens=num_batched_tokens,
num_batched_tokens=len(seq_lens) * max(seq_lens) if seq_lens else 0,
ignored_seq_groups=ignored_seq_groups,
finished_seqs=finished_seqs,
)
@ -258,6 +266,13 @@ class FixedWindowScheduler:
# summary: The original code free the block in block_manager.
# now, we added it into a list to pass to worker in the next model_execute stage.
self.cleaned.append(seq.seq_id)
for i in range(len(self.kv_cache)):
for j in range(2):
if not self.kv_cache[i][j].get(seq.seq_id) is None:
del self.kv_cache[i][j][seq.seq_id]
# del self.kv_cache[seq.seq_id]
# logger.info(f"freed seqs: {seq.seq_id} .
# now kv cache is: {list(self.kv_cache[0][0].keys())} ")
# bigdl-llm change end
def free_finished_seq_groups(self) -> None:

View file

@ -243,15 +243,16 @@ class _AsyncLLMEngine(LLMEngine):
) -> Any:
"""Runs the given method on all workers."""
# bigdl-llm change start
all_outputs = []
coros = []
for worker in self.workers:
# if self.parallel_config.worker_use_ray:
# executor = partial(worker.execute_method.remote, method)
# else:
executor = getattr(worker, method)
coros.append(asyncio.get_event_loop().run_in_executor(
None, partial(executor, *args, **kwargs)))
output = executor(*args, **kwargs)
all_outputs.append(output)
all_outputs = await asyncio.gather(*coros)
# if self.parallel_config.worker_use_ray:
# all_outputs = await asyncio.gather(*all_outputs)

View file

@ -36,7 +36,7 @@
#
import time
from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Tuple, Union, Dict
from bigdl.llm.vllm.config import ModelConfig, SchedulerConfig
from bigdl.llm.vllm.core.scheduler import SchedulerOutputs, FixedWindowScheduler
@ -127,6 +127,7 @@ class LLMEngine:
# self.parallel_config = parallel_config
self.scheduler_config = scheduler_config
self.log_stats = log_stats
self.kv_cache = [[dict() for _ in range(2)] for _ in range(32)]
# self._verify_args()
self.tokenizer = get_tokenizer(
@ -142,7 +143,7 @@ class LLMEngine:
self._init_workers()
# Co(gc): we create a fixed scheduler
self.scheduler = FixedWindowScheduler(scheduler_config)
self.scheduler = FixedWindowScheduler(scheduler_config, kv_cache=self.kv_cache)
# Logging.
self.last_logging_time = 0.0
@ -170,6 +171,7 @@ class LLMEngine:
self.scheduler_config,
0,
# distributed_init_method,
kv_cache=self.kv_cache
)
self.workers.append(worker)
self._run_workers(

View file

@ -23,9 +23,9 @@ from typing import Optional, Tuple, List, Type, Dict
from bigdl.llm.vllm.sequence import SequenceOutputs, SequenceGroupMetadata
from bigdl.llm.vllm.model_executor.layers.bigdl_sampler import BigDLSampler
from bigdl.llm.vllm.model_executor.models.bigdl_model import BigDLModelForCausalLM
from bigdl.llm.vllm.logger import init_logger
import math
import time
from transformers.generation.logits_process import (
LogitsProcessorList,
RepetitionPenaltyLogitsProcessor,
@ -35,6 +35,9 @@ from transformers.generation.logits_process import (
)
logger = init_logger(__name__)
def _pad_to_max(x: List[int], max_len: int, padding_id: int = 0) -> List[int]:
return x + [padding_id] * (max_len - len(x))
@ -87,7 +90,6 @@ class BigDLLlamaForCausalLM(BigDLModelForCausalLM):
else:
self.device = torch.device(device)
self.dtype = self.model.dtype
self.kv_cache_size = [0]
self.last_seq_ids = []
self.tmp_kv_cache = None
self.pad_token_id = config.pad_token_id
@ -170,15 +172,25 @@ class BigDLLlamaForCausalLM(BigDLModelForCausalLM):
# "return_dict": True,
}
# pdb.set_trace()
if self.device.type == 'xpu':
torch.xpu.empty_cache()
st_timestamp = time.perf_counter()
outputs = self.model.forward(**kwargs)
# tmp = torch.xpu.memory_stats()
# logger.info(f"0: {tmp['allocated_bytes.all.current']}")
# self.last_seq_ids = cur_seq_ids[:]
# self.last_kv_cache = outputs.past_key_values
self._set_last_seq_ids(cur_seq_ids[:])
self._set_last_kv_cache(outputs.past_key_values)
self.last_seq_ids = cur_seq_ids[:]
self.tmp_kv_cache = outputs.past_key_values
logits = outputs.logits[:, -1, :]
bigdl_output = self.sampler(logits, input_metadata, st_timestamp)
# tmp = torch.xpu.memory_stats()
# logger.info(f"before: {tmp['allocated_bytes.all.current']}")
self.update_kv_cache(cur_seq_ids, outputs.past_key_values,
self.update_kv_cache(cur_seq_ids,
kv_cache, kv_cache_size_0, kv_cache_size_1)
# tmp = torch.xpu.memory_stats()
# logger.info(f"after: {tmp['allocated_bytes.all.current']}")
return bigdl_output

View file

@ -20,21 +20,31 @@ from typing import Optional, Tuple, List, Type, Dict
from transformers import LlamaConfig
from bigdl.llm.vllm.sequence import SequenceOutputs, SequenceGroupMetadata
from bigdl.llm.transformers.models.utils import extend_kv_cache
zero_cache_dict = {}
def get_zero_tensor(length, cur_size, device, pos):
if length not in zero_cache_dict:
tmp_size = cur_size[:]
tmp_size[pos] = length
zero_cache_dict[length] = torch.zeros(tmp_size, device=device)
return zero_cache_dict[length].narrow(pos, 0, length - cur_size[pos])
def _pad_kv_cache_view(t: torch.Tensor, len: int,
device: torch.device, pos: int = 2) -> torch.Tensor:
cur_size = list(t.size())
if cur_size[pos] < len:
tmp_size = cur_size[:]
tmp_size[pos] = len - cur_size[pos]
zeros = torch.zeros(tmp_size, device=device)
zeros = get_zero_tensor(len, cur_size, device, pos)
padded_view = torch.cat((zeros, t), dim=pos)
return padded_view
if cur_size[pos] > len:
elif cur_size[pos] > len:
padded_view = t.narrow(pos, cur_size[pos] - len, len)
return padded_view
return t
else:
return t
class BigDLModelForCausalLM(nn.Module):
@ -52,10 +62,23 @@ class BigDLModelForCausalLM(nn.Module):
"cuda" if torch.cuda.is_available() else "cpu")
else:
self.device = torch.device(device)
if device == 'xpu':
try:
import intel_extension_for_pytorch as ipex
except ImportError:
print("Intel Extension for PyTorch is not installed, \
but is required for xpu inference.")
self.max_seq_limit = max_model_len
self.last_kv_cache = None
self.last_seq_ids = None
def _set_last_kv_cache(self, last_kv_cache):
self.last_kv_cache = last_kv_cache
def _set_last_seq_ids(self, last_seq_ids):
self.last_seq_ids = last_seq_ids
# This is an implementation for models that KV Cache shape in (batch_size, num_heads,
# sequence_length, embed_size_per_head).
def prepare_kv_cache(
@ -69,39 +92,41 @@ class BigDLModelForCausalLM(nn.Module):
max_seq_limit = self.max_seq_limit
if (self.last_kv_cache is not None) and cur_seq_ids == self.last_seq_ids:
if self.last_kv_cache[0][0].size(2) < max_seq_limit:
bigdl_kv_cache = self.tmp_kv_cache
bigdl_kv_cache = self.last_kv_cache
else:
bigdl_kv_cache = [[tmp.narrow(2, self.last_kv_cache[0][0].size(2)
- max_seq_limit, max_seq_limit)
for tmp in tmp_list] for tmp_list in self.last_kv_cache]
del self.last_kv_cache
else:
del self.last_kv_cache
bigdl_kv_cache = []
for i in range(kv_cache_size_0):
cur_list = []
for j in range(kv_cache_size_1):
cur_view = None
views = []
max_len = 0
for seq_group_meta_data in seq_group_meta_data_lists:
seq_ids = list(seq_group_meta_data.seq_data.keys())
seq_id = seq_ids[0]
seq_data = seq_group_meta_data.seq_data[seq_id]
view_size = [1] + list(kv_cache[seq_id][i][j].shape)
if cur_view is None:
cur_view = kv_cache[seq_id][i][j].view(view_size)
else:
if cur_view.size(2) != view_size[2]:
max_len = max(cur_view.size(2), view_size[2])
cur_view = _pad_kv_cache_view(cur_view, max_len, self.device)
tmp_view = _pad_kv_cache_view(
kv_cache[seq_id][i][j].view(view_size),
max_len, self.device)
cur_view = torch.cat((cur_view, tmp_view), dim=0)
else:
cur_view = torch.cat(
(cur_view, kv_cache[seq_id][i][j].view(view_size)), dim=0)
if cur_view.size(2) > max_seq_limit:
view_size = [1] + list(kv_cache[i][j][seq_id].shape)
views.append(kv_cache[i][j][seq_id].view(view_size))
max_len = max(max_len, view_size[2])
views = [_pad_kv_cache_view(v, max_len, self.device) for v in views]
cur_view = torch.cat(views, dim=0)
if cur_view.size(2) > max_seq_limit * 1.5:
cur_view = _pad_kv_cache_view(cur_view, max_seq_limit, self.device)
cur_list.append(cur_view)
for seq_group_meta_data in seq_group_meta_data_lists:
seq_ids = list(seq_group_meta_data.seq_data.keys())
seq_id = seq_ids[0]
del kv_cache[i][j][seq_id]
bigdl_kv_cache.append(cur_list)
return bigdl_kv_cache
# This is an implementation for models that KV Cache shape in (batch_size, num_heads,
@ -109,20 +134,16 @@ class BigDLModelForCausalLM(nn.Module):
def update_kv_cache(
self,
cur_seq_ids: List[int],
past_key_values: List[List[torch.Tensor]],
kv_cache: Dict,
kv_cache,
kv_cache_size_0: int,
kv_cache_size_1: int,
) -> None:
index = 0
for seq_id in cur_seq_ids:
if kv_cache.get(seq_id) is None:
kv_cache[seq_id] = [[[] for _ in range(kv_cache_size_1)]
for _ in range(kv_cache_size_0)]
for i in range(kv_cache_size_0):
for j in range(kv_cache_size_1):
kv_cache[seq_id][i][j] = past_key_values[i][j][index]
index = index + 1
for i in range(kv_cache_size_0):
for j in range(kv_cache_size_1):
index = 0
for seq_id in cur_seq_ids:
kv_cache[i][j][seq_id] = self.last_kv_cache[i][j][index]
index = index + 1
def forward(
self,

View file

@ -68,6 +68,7 @@ class Worker:
scheduler_config: SchedulerConfig,
rank: Optional[int] = None,
# distributed_init_method: Optional[str] = None,
kv_cache: Optional[Dict] = None,
) -> None:
self.model_config = model_config
# self.parallel_config = parallel_config
@ -84,17 +85,18 @@ class Worker:
self.cache_events = None
self.gpu_cache = None
self.kv_cache = dict()
self.kv_cache = kv_cache
def clean_finished_seqs(self, finished_seqs: List[int]):
"""
This function cleans the finished sequences and their KVCache in self.kv_cache
"""
for seq_id in finished_seqs:
if seq_id not in self.kv_cache.keys():
warnings.warn(f"Duplicate key {seq_id} received during clean worker's KVCache")
continue
del self.kv_cache[seq_id]
pass
# for seq_id in finished_seqs:
# if seq_id not in self.kv_cache.keys():
# # warnings.warn(f"Duplicate key {seq_id} received during clean worker's KVCache")
# continue
# del self.kv_cache[seq_id]
def init_model(self):
if self.model_config.device == 'gpu':
@ -282,6 +284,10 @@ class Worker:
if finished_seqs:
self.clean_finished_seqs(finished_seqs)
# if self.model_config.device == 'xpu':
# import intel_extension_for_pytorch as ipex
# torch.xpu.empty_cache()
cache_events = None
# If there is no input, we don't need to execute the model.
if not seq_group_metadata_list: