Add vLLM[xpu] related code (#10779)

* Add ipex-llm side change

* add runnable offline_inference

* refactor to call vllm2

* Verified async server

* add new v2 example

* add README

* fix

* change dir

* refactor readme.md

* add experimental

* fix
Guancheng Fu 2024-04-18 15:29:20 +08:00 committed by GitHub
parent 053ec30737
commit cbe7b5753f
9 changed files with 860 additions and 17 deletions


@@ -2,7 +2,9 @@
This example demonstrates how to serve a LLaMA2-7B model using vLLM continuous batching on Intel GPU (with IPEX-LLM low-bit optimizations).
The code shown in the following example is ported from [vLLM](https://github.com/vllm-project/vllm/tree/v0.2.1.post1).
Currently, we provide two different versions of vLLM: vLLM-v1 and vLLM-v2. vLLM-v1 will be deprecated soon, so we recommend trying vLLM-v2 directly.
The code shown in the following example is ported from [vLLM](https://github.com/vllm-project/vllm/tree/v0.3.3).
## Example: Serving LLaMA2-7B using Intel GPU
@@ -24,9 +26,11 @@ sycl-ls
[opencl:gpu:2] Intel(R) OpenCL Graphics, Intel(R) Arc(TM) A770 Graphics 3.0 [23.17.26241.33]
[ext_oneapi_level_zero:gpu:0] Intel(R) Level-Zero, Intel(R) Arc(TM) A770 Graphics 1.3 [1.3.26241]
```
### vLLM-v1 (Deprecated)
<details>
<summary>Details</summary>
### 1. Install
#### 1. Install
To run vLLM continuous batching on Intel GPUs, install the dependencies as follows:
```bash
@@ -43,17 +47,16 @@ pip3 install fastapi
pip3 install "uvicorn[standard]"
pip3 install "pydantic<2" # Required for OpenAI server.
```
### 2. Configure recommended environment variables
#### 2. Configure recommended environment variables
```bash
export USE_XETLA=OFF
export SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=1
```
### 3. Offline inference/Service
#### 3. Offline inference/Service
#### Offline inference
##### Offline inference
To run offline inference using vLLM for a quick impression, use the following example:
@@ -65,21 +68,24 @@ To run offline inference using vLLM for a quick impression, use the following ex
python offline_inference.py
```
#### Service
##### Service
To fully utilize the continuous batching feature of `vLLM`, you can send requests to the service using curl or other similar methods. The requests sent to the engine will be batched at the token level. Queries will be executed in the same `forward` step of the LLM and removed as soon as they finish, instead of waiting for all sequences to finish.
For vLLM-v1, you can start the service using the following command:
```bash
#!/bin/bash
# You may also want to adjust the `--max-num-batched-tokens` argument; it sets the hard limit
# on the total batched prompt length the server will accept
python -m ipex_llm.vllm.entrypoints.openai.api_server \
--model /MODEL_PATH/Llama-2-7b-chat-hf/ --port 8000 \
--load-format 'auto' --device xpu --dtype bfloat16 \
--load-format 'auto' --device xpu --dtype float16 \
--load-in-low-bit sym_int4 \
--max-num-batched-tokens 4096
```
Then you can access the api server as follows:
```bash
@@ -94,18 +100,102 @@ Then you can access the api server as follows:
}' &
```
### 4. (Optional) Add a new model
#### 4. (Optional) Add a new model for vLLM-v1
Currently, we only support LLaMA-family models (including `llama`, `vicuna`, `llama-2`, etc.). To use another model, you may need to add some adaptations.
#### 4.1 Add model code
##### 4.1 Add model code
Create or clone the PyTorch model code in `IPEX/python/llm/src/ipex/llm/vllm/model_executor/models`.
#### 4.2 Rewrite the forward methods
##### 4.2 Rewrite the forward methods
Referring to `IPEX/python/llm/src/ipex/llm/vllm/model_executor/models/ipex_llama.py`, it is necessary to maintain a `kv_cache`, which is a nested list of dictionaries that maps `req_id` to a three-dimensional tensor **(the structure may vary across models)**. Before the model's actual `forward` method is called, prepare a `past_key_values` according to the current `req_id`; afterwards, update the `kv_cache` with `output.past_key_values`. The cached entries are cleared when the request finishes.
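A minimal sketch of this bookkeeping, assuming a hypothetical helper `forward_with_cache` and a simplified cache layout (the real structure follows `ipex_llama.py` and varies across models):
```python
# Hypothetical sketch only -- the exact cache layout differs across models;
# see ipex_llama.py for the real implementation.
from typing import Dict, List

import torch

# One dict per decoder layer, mapping req_id -> that request's cached tensor.
kv_cache: List[Dict[int, torch.Tensor]] = []

def forward_with_cache(model, input_ids: torch.Tensor, req_id: int, finished: bool):
    # Prepare past_key_values for this request before the actual forward call.
    past_key_values = (
        tuple(layer[req_id] for layer in kv_cache)
        if kv_cache and all(req_id in layer for layer in kv_cache)
        else None
    )
    output = model(input_ids, past_key_values=past_key_values, use_cache=True)
    # Update the kv_cache with output.past_key_values under the same req_id.
    while len(kv_cache) < len(output.past_key_values):
        kv_cache.append({})
    for layer, new_kv in zip(kv_cache, output.past_key_values):
        layer[req_id] = new_kv
    # Clear this request's entries once it is finished.
    if finished:
        for layer in kv_cache:
            layer.pop(req_id, None)
    return output
```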
#### 4.3 Register new model
##### 4.3 Register new model
Finally, register your `*ForCausalLM` class in the `_MODEL_REGISTRY` in `IPEX/python/llm/src/ipex/llm/vllm/model_executor/model_loader.py`.
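A hypothetical illustration of that registration (the class name below is a placeholder; edit the real `_MODEL_REGISTRY` in `model_loader.py`):
```python
# Hypothetical example only -- the actual registry lives in
# IPEX/python/llm/src/ipex/llm/vllm/model_executor/model_loader.py.
import torch.nn as nn

class MyNewModelForCausalLM(nn.Module):  # placeholder for your ported model class
    pass

_MODEL_REGISTRY = {
    # ... existing entries such as "LlamaForCausalLM" ...
    "MyNewModelForCausalLM": MyNewModelForCausalLM,
}
```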
</details>
### vLLM-v2 (experimental support)
Currently, for vLLM-v2, we support the following models:
- Qwen series models
- Llama series models
- ChatGLM series models
- Baichuan series models
#### 1. Install
Install the dependencies for vLLM-v2 as follows:
```bash
# First create a conda environment
conda create -n ipex-vllm python=3.11
conda activate ipex-vllm
# Install dependencies
pip install --pre --upgrade "ipex-llm[xpu]" --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
# cd to your workdir
git clone -b sycl_xpu https://github.com/analytics-zoo/vllm.git
cd vllm
pip install -r requirements-xpu.txt
pip install --no-deps xformers
VLLM_BUILD_XPU_OPS=1 pip install --no-build-isolation -v -e .
pip install outlines==0.0.34 --no-deps
pip install interegular cloudpickle diskcache joblib lark nest-asyncio numba scipy
# For Qwen model support
pip install transformers_stream_generator einops tiktoken
```
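After installation, you can optionally run a quick sanity check to confirm that PyTorch can see the Intel GPU. This assumes `intel_extension_for_pytorch` was pulled in by `ipex-llm[xpu]`:
```python
# Optional sanity check: verify that the XPU backend is visible.
import torch
import intel_extension_for_pytorch as ipex  # noqa: F401  (registers torch.xpu)

print(f"XPU available: {torch.xpu.is_available()}")
if torch.xpu.is_available():
    print(f"Device name: {torch.xpu.get_device_name(0)}")
```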
#### 2. Configure recommended environment variables
```bash
export USE_XETLA=OFF
export SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=1
```
#### 3. Offline inference/Service
##### Offline inference
To run offline inference using vLLM for a quick impression, use the following example:
```bash
#!/bin/bash
# Please first modify the MODEL_PATH in offline_inference_v2.py
# Modify load_in_low_bit to use a different quantization dtype
python offline_inference_v2.py
```
##### Service
To fully utilize the continuous batching feature of `vLLM`, you can send requests to the service using curl or other similar methods. The requests sent to the engine will be batched at the token level. Queries will be executed in the same `forward` step of the LLM and removed as soon as they finish, instead of waiting for all sequences to finish.
For vLLM-v2, you can start the service using the following command:
```bash
python -m ipex_llm.vllm2.entrypoints.openai.api_server \
--model /MODEL_PATH/Llama-2-7b-chat-hf/ --port 8000 \
--device xpu --dtype float16 \
--load-in-low-bit sym_int4 \
--max-num-batched-tokens 4096
```
Then you can access the api server as follows:
```bash
curl http://localhost:8000/v1/completions \
-H "Content-Type: application/json" \
-d '{
"model": "/MODEL_PATH/Llama-2-7b-chat-hf/",
"prompt": "San Francisco is a",
"max_tokens": 128,
"temperature": 0
}' &
```


@@ -0,0 +1,60 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Some parts of this file is adapted from
# https://github.com/vllm-project/vllm/blob/v0.2.1.post1/examples/offline_inference.py
# which is licensed under Apache License 2.0
#
# Copyright 2023 The vLLM team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from vllm import SamplingParams
from ipex_llm.vllm2.engine import IPEXLLMClass as LLM
# Sample prompts.
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
# Create an LLM.
llm = LLM(model="YOUR_MODEL",
device="xpu",
dtype="float16",
enforce_eager=True,
load_in_low_bit="sym_int4")
# Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
outputs = llm.generate(prompts, sampling_params)
# Print the outputs.
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")


@@ -230,9 +230,11 @@ def optimize_model(model, low_bit='sym_int4', optimize_llm=True, modules_to_not_
invalidInputError(isinstance(model, torch.nn.Module),
"model should be an instance of "
f"`torch.nn.Module`, but got {type(model)} at last.")
invalidInputError(model.device.type in ('cpu', 'meta'),
"Expect model on device `cpu` or `meta`, "
f"but got device type {model.device.type}")
# To adapt vLLM models
if hasattr(model, 'device'):
invalidInputError(model.device.type in ('cpu', 'meta'),
"Expect model on device `cpu` or `meta`, "
f"but got device type {model.device.type}")
if kwargs.pop("replace_embedding", False):
warnings.warn("replace_embedding is deprecated and will be removed in a future version,"
" please use cpu_embedding instead.", FutureWarning)


@@ -59,6 +59,18 @@ def is_auto_awq_available():
return importlib.util.find_spec("awq") is not None
def is_vllm_available():
return importlib.util.find_spec("vllm") is not None
def is_torch_distributed_initialized():
return torch.distributed.is_initialized()
def is_module_in_classes(module, classes):
return any(isinstance(module, cls) for cls in classes)
def is_deepspeed_available():
spec = importlib.util.find_spec("deepspeed")
if spec is not None:
@@ -88,7 +100,22 @@ def is_linear_module(module):
is_awq = is_auto_awq_available() and isinstance(module, WQLinear_GEMM)
if is_auto_gptq_available() and isinstance(module, QuantLinearCudaOld):
if is_vllm_available():
# TODO: add tensor parallel feature later
from vllm.model_executor.layers.linear import (
ColumnParallelLinear, RowParallelLinear, QKVParallelLinear, MergedColumnParallelLinear
)
VLLM_LINEAR_LIST = [
ColumnParallelLinear, RowParallelLinear, QKVParallelLinear, MergedColumnParallelLinear
]
if is_module_in_classes(module, VLLM_LINEAR_LIST):
in_features = module.input_size
out_features = module.output_size
result = True
mp_group = None
else:
result = False
elif is_auto_gptq_available() and isinstance(module, QuantLinearCudaOld):
in_features = module.infeatures
out_features = module.outfeatures
mp_group = None


@@ -0,0 +1,15 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#


@@ -0,0 +1,21 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from .engine import IPEXLLMAsyncLLMEngine, IPEXLLMLLMEngine, IPEXLLMClass
__all__ = [
"IPEXLLMAsyncLLMEngine",
"IPEXLLMLLMEngine",
"IPEXLLMClass",
]


@@ -0,0 +1,145 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import List, Optional, Union
from vllm.engine.llm_engine import LLMEngine
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
from vllm.engine.ray_utils import initialize_ray_cluster
from vllm.entrypoints.llm import LLM
from vllm.utils import Counter
from ipex_llm.vllm2.model_convert import _ipex_llm_convert
from ipex_llm.utils.common import invalidInputError
class IPEXLLMAsyncLLMEngine(AsyncLLMEngine):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@classmethod
def from_engine_args(
cls,
engine_args: AsyncEngineArgs,
start_engine_loop: bool = True,
load_in_low_bit: str = "sym_int4",
# ipex_llm_optimize_mode: str = 'NATIVE',
) -> "AsyncLLMEngine":
"""Creates an async LLM engine from the engine arguments."""
# Enable ipex-llm optimizations
_ipex_llm_convert(load_in_low_bit)
engine_configs = engine_args.create_engine_configs()
parallel_config = engine_configs[2]
if parallel_config.worker_use_ray or engine_args.engine_use_ray:
initialize_ray_cluster(parallel_config)
from vllm.executor.ray_gpu_executor import RayGPUExecutorAsync
executor_class = RayGPUExecutorAsync
else:
invalidInputError(parallel_config.world_size == 1, (
"Ray is required if parallel_config.world_size > 1."))
from vllm.executor.gpu_executor import GPUExecutorAsync
executor_class = GPUExecutorAsync
# Create the async LLM engine.
engine = cls(parallel_config.worker_use_ray,
engine_args.engine_use_ray,
*engine_configs,
executor_class,
log_requests=not engine_args.disable_log_requests,
log_stats=not engine_args.disable_log_stats,
max_log_len=engine_args.max_log_len,
start_engine_loop=start_engine_loop)
return engine
class IPEXLLMClass(LLM):
def __init__(
self,
model: str,
tokenizer: Optional[str] = None,
tokenizer_mode: str = "auto",
trust_remote_code: bool = False,
tensor_parallel_size: int = 1,
dtype: str = "auto",
quantization: Optional[str] = None,
revision: Optional[str] = None,
tokenizer_revision: Optional[str] = None,
seed: int = 0,
gpu_memory_utilization: float = 0.9,
swap_space: int = 4,
enforce_eager: bool = False,
max_context_len_to_capture: int = 8192,
disable_custom_all_reduce: bool = False,
load_in_low_bit: str = "sym_int4",
**kwargs,
) -> None:
if "disable_log_stats" not in kwargs:
kwargs["disable_log_stats"] = True
engine_args = EngineArgs(
model=model,
tokenizer=tokenizer,
tokenizer_mode=tokenizer_mode,
trust_remote_code=trust_remote_code,
tensor_parallel_size=tensor_parallel_size,
dtype=dtype,
quantization=quantization,
revision=revision,
tokenizer_revision=tokenizer_revision,
seed=seed,
gpu_memory_utilization=gpu_memory_utilization,
swap_space=swap_space,
enforce_eager=enforce_eager,
max_context_len_to_capture=max_context_len_to_capture,
disable_custom_all_reduce=disable_custom_all_reduce,
**kwargs,
)
self.llm_engine = IPEXLLMLLMEngine.from_engine_args(engine_args,
load_in_low_bit=load_in_low_bit)
self.request_counter = Counter()
class IPEXLLMLLMEngine(LLMEngine):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@classmethod
def from_engine_args(
cls,
engine_args: EngineArgs,
load_in_low_bit: str = "sym_int4",
# ipex_llm_optimize_mode: str = 'NATIVE',
) -> "LLMEngine":
"""Creates an LLM engine from the engine arguments."""
# Create the engine configs.
_ipex_llm_convert(load_in_low_bit)
engine_configs = engine_args.create_engine_configs()
parallel_config = engine_configs[2]
# Initialize the cluster and specify the executor class.
if parallel_config.worker_use_ray:
initialize_ray_cluster(parallel_config)
from vllm.executor.ray_gpu_executor import RayGPUExecutor
executor_class = RayGPUExecutor
else:
invalidInputError(parallel_config.world_size == 1,
"Ray is required if parallel_config.world_size > 1.")
from vllm.executor.gpu_executor import GPUExecutor
executor_class = GPUExecutor
# Create the LLM engine.
engine = cls(*engine_configs,
executor_class=executor_class,
log_stats=not engine_args.disable_log_stats)
return engine


@@ -0,0 +1,284 @@
import argparse
import asyncio
import json
from contextlib import asynccontextmanager
import os
import importlib
import inspect
import ssl
from prometheus_client import make_asgi_app
import fastapi
import uvicorn
from http import HTTPStatus
from fastapi import Request
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse, Response
import vllm
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.entrypoints.openai.protocol import (CompletionRequest,
ChatCompletionRequest,
ErrorResponse)
from vllm.logger import init_logger
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
from vllm.entrypoints.openai.serving_engine import LoRA
from ipex_llm.vllm2.engine import IPEXLLMAsyncLLMEngine
from ipex_llm.utils.common import invalidInputError
TIMEOUT_KEEP_ALIVE = 5 # seconds
openai_serving_chat: OpenAIServingChat = None
openai_serving_completion: OpenAIServingCompletion = None
logger = init_logger(__name__)
@asynccontextmanager
async def lifespan(app: fastapi.FastAPI):
async def _force_log():
while True:
await asyncio.sleep(10)
await engine.do_log_stats()
if not engine_args.disable_log_stats:
asyncio.create_task(_force_log())
yield
app = fastapi.FastAPI(lifespan=lifespan)
class LoRAParserAction(argparse.Action):
def __call__(self, parser, namespace, values, option_string=None):
lora_list = []
for item in values:
name, path = item.split('=')
lora_list.append(LoRA(name, path))
setattr(namespace, self.dest, lora_list)
def parse_args():
parser = argparse.ArgumentParser(
description="vLLM OpenAI-Compatible RESTful API server.")
parser.add_argument("--host", type=str, default=None, help="host name")
parser.add_argument("--port", type=int, default=8000, help="port number")
parser.add_argument(
"--uvicorn-log-level",
type=str,
default="info",
choices=['debug', 'info', 'warning', 'error', 'critical', 'trace'],
help="log level for uvicorn")
parser.add_argument("--allow-credentials",
action="store_true",
help="allow credentials")
parser.add_argument("--allowed-origins",
type=json.loads,
default=["*"],
help="allowed origins")
parser.add_argument("--allowed-methods",
type=json.loads,
default=["*"],
help="allowed methods")
parser.add_argument("--allowed-headers",
type=json.loads,
default=["*"],
help="allowed headers")
parser.add_argument("--api-key",
type=str,
default=None,
help="If provided, the server will require this key "
"to be presented in the header.")
parser.add_argument("--served-model-name",
type=str,
default=None,
help="The model name used in the API. If not "
"specified, the model name will be the same as "
"the huggingface name.")
parser.add_argument(
"--lora-modules",
type=str,
default=None,
nargs='+',
action=LoRAParserAction,
help="LoRA module configurations in the format name=path. "
"Multiple modules can be specified.")
parser.add_argument("--chat-template",
type=str,
default=None,
help="The file path to the chat template, "
"or the template in single-line form "
"for the specified model")
parser.add_argument("--response-role",
type=str,
default="assistant",
help="The role name to return if "
"`request.add_generation_prompt=true`.")
parser.add_argument("--ssl-keyfile",
type=str,
default=None,
help="The file path to the SSL key file")
parser.add_argument("--ssl-certfile",
type=str,
default=None,
help="The file path to the SSL cert file")
parser.add_argument("--ssl-ca-certs",
type=str,
default=None,
help="The CA certificates file")
parser.add_argument(
"--ssl-cert-reqs",
type=int,
default=int(ssl.CERT_NONE),
help="Whether client certificate is required (see stdlib ssl module's)"
)
parser.add_argument(
"--root-path",
type=str,
default=None,
help="FastAPI root_path when app is behind a path based routing proxy")
parser.add_argument(
"--middleware",
type=str,
action="append",
default=[],
help="Additional ASGI middleware to apply to the app. "
"We accept multiple --middleware arguments. "
"The value should be an import path. "
"If a function is provided, vLLM will add it to the server "
"using @app.middleware('http'). "
"If a class is provided, vLLM will add it to the server "
"using app.add_middleware(). ")
parser.add_argument(
"--load-in-low-bit",
type=str,
default="sym_int4",
help="Low-bit quantization for IPEX-LLM models")
parser = AsyncEngineArgs.add_cli_args(parser)
return parser.parse_args()
# Add prometheus asgi middleware to route /metrics requests
metrics_app = make_asgi_app()
app.mount("/metrics", metrics_app)
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(_, exc):
err = openai_serving_chat.create_error_response(message=str(exc))
return JSONResponse(err.model_dump(), status_code=HTTPStatus.BAD_REQUEST)
@app.get("/health")
async def health() -> Response:
"""Health check."""
await openai_serving_chat.engine.check_health()
return Response(status_code=200)
@app.get("/v1/models")
async def show_available_models():
models = await openai_serving_chat.show_available_models()
return JSONResponse(content=models.model_dump())
@app.get("/version")
async def show_version():
ver = {"version": vllm.__version__}
return JSONResponse(content=ver)
@app.post("/v1/chat/completions")
async def create_chat_completion(request: ChatCompletionRequest,
raw_request: Request):
generator = await openai_serving_chat.create_chat_completion(
request, raw_request)
if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(),
status_code=generator.code)
if request.stream:
return StreamingResponse(content=generator,
media_type="text/event-stream")
else:
return JSONResponse(content=generator.model_dump())
@app.post("/v1/completions")
async def create_completion(request: CompletionRequest, raw_request: Request):
generator = await openai_serving_completion.create_completion(
request, raw_request)
if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(),
status_code=generator.code)
if request.stream:
return StreamingResponse(content=generator,
media_type="text/event-stream")
else:
return JSONResponse(content=generator.model_dump())
if __name__ == "__main__":
args = parse_args()
app.add_middleware(
CORSMiddleware,
allow_origins=args.allowed_origins,
allow_credentials=args.allow_credentials,
allow_methods=args.allowed_methods,
allow_headers=args.allowed_headers,
)
token = os.environ.get("VLLM_API_KEY") or args.api_key
if token:
@app.middleware("http")
async def authentication(request: Request, call_next):
if not request.url.path.startswith("/v1"):
return await call_next(request)
if request.headers.get("Authorization") != "Bearer " + token:
return JSONResponse(content={"error": "Unauthorized"},
status_code=401)
return await call_next(request)
for middleware in args.middleware:
module_path, object_name = middleware.rsplit(".", 1)
imported = getattr(importlib.import_module(module_path), object_name)
if inspect.isclass(imported):
app.add_middleware(imported)
elif inspect.iscoroutinefunction(imported):
app.middleware("http")(imported)
else:
invalidInputError(False, (f"Invalid middleware {middleware}. "
f"Must be a function or a class."))
logger.info(f"vLLM API server version {vllm.__version__}")
logger.info(f"args: {args}")
if args.served_model_name is not None:
served_model = args.served_model_name
else:
served_model = args.model
engine_args = AsyncEngineArgs.from_cli_args(args)
engine = IPEXLLMAsyncLLMEngine.from_engine_args(engine_args)
openai_serving_chat = OpenAIServingChat(engine, served_model,
args.response_role,
args.lora_modules,
args.chat_template)
openai_serving_completion = OpenAIServingCompletion(
engine, served_model, args.lora_modules)
app.root_path = args.root_path
uvicorn.run(app,
host=args.host,
port=args.port,
log_level=args.uvicorn_log_level,
timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
ssl_keyfile=args.ssl_keyfile,
ssl_certfile=args.ssl_certfile,
ssl_ca_certs=args.ssl_ca_certs,
ssl_cert_reqs=args.ssl_cert_reqs)


@@ -0,0 +1,199 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import torch
from vllm.logger import init_logger
from vllm.model_executor.models.llama import LlamaMLP, LlamaAttention
from vllm.model_executor.models.qwen2 import Qwen2MLP, Qwen2Attention
from vllm.model_executor.models.qwen import QWenMLP, QWenAttention
from vllm.model_executor.models.baichuan import BaiChuanMLP, BaiChuanAttention
from vllm.model_executor.models.chatglm import GLMMLP, GLMAttention
from vllm.model_executor.model_loader import get_model
from vllm.utils import measure_device_memory
from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
from vllm.model_executor.input_metadata import InputMetadata
from vllm.config import DeviceConfig
from typing import Tuple
from ipex_llm.utils.common import invalidInputError
def _MLP_forward(self, x):
gate_up = self.gate_up_proj(x)
x = self.act_fn(gate_up)
x = self.down_proj(x)
return x
def _Attention_forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
input_metadata: InputMetadata,
) -> torch.Tensor:
qkv = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
k_cache, v_cache = kv_cache
attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
output = self.o_proj(attn_output)
return output
def _QWen_Attention_forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: Tuple[torch.Tensor, torch.Tensor],
input_metadata: InputMetadata,
) -> torch.Tensor:
qkv = self.c_attn(hidden_states)
q, k, v = qkv.chunk(chunks=3, dim=-1)
q, k = self.rotary_emb(positions, q, k)
k_cache, v_cache = kv_cache
attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
output = self.c_proj(attn_output)
return output
def _QWen_MLP_forward(self, x):
gate_up = self.gate_up_proj(x)
x = self.act_fn(gate_up)
x = self.c_proj(x)
return x
def _ChatGLM_MLP_forward(self, hidden_states):
# [s, b, 4hp]
intermediate_parallel = self.dense_h_to_4h(hidden_states)
intermediate_parallel = self.activation_func(intermediate_parallel)
# [s, b, h]
output = self.dense_4h_to_h(intermediate_parallel)
return output
def _Baichuan_Attention_forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: Tuple[torch.Tensor, torch.Tensor],
input_metadata: InputMetadata,
) -> torch.Tensor:
qkv = self.W_pack(hidden_states)
q, k, v = qkv.chunk(chunks=3, dim=-1)
if self.postion_embedding != "ALIBI":
q, k = self.rotary_emb(positions, q, k)
k_cache, v_cache = kv_cache
attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
output = self.o_proj(attn_output)
return output
def _ChatGLM_Attention_forward(
self,
hidden_states: torch.Tensor,
position_ids: torch.Tensor,
kv_cache: Tuple[torch.Tensor, torch.Tensor],
input_metadata: InputMetadata,
) -> torch.Tensor:
qkv = self.query_key_value(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(position_ids, q, k)
key_cache, value_cache = kv_cache
context_layer = self.attn(
q,
k,
v,
key_cache,
value_cache,
input_metadata,
)
attn_output = self.dense(context_layer)
return attn_output
_REPLACED_MLP_LAYERS = {
LlamaMLP: _MLP_forward,
Qwen2MLP: _MLP_forward,
BaiChuanMLP: _MLP_forward,
QWenMLP: _QWen_MLP_forward,
GLMMLP: _ChatGLM_MLP_forward
}
_REPLACED_ATTENTION_LAYERS = {
LlamaAttention: _Attention_forward,
Qwen2Attention: _Attention_forward,
QWenAttention: _QWen_Attention_forward,
BaiChuanAttention: _Baichuan_Attention_forward,
GLMAttention: _ChatGLM_Attention_forward
}
def _model_mlp_convert():
for module, replaced_func in _REPLACED_MLP_LAYERS.items():
setattr(module, "forward", replaced_func)
def _model_attention_convert():
for module, replaced_func in _REPLACED_ATTENTION_LAYERS.items():
setattr(module, "forward", replaced_func)
def _ipex_llm_convert(load_in_low_bit):
from vllm.worker.model_runner import ModelRunner
import vllm.model_executor.model_loader as model_loader
setattr(ModelRunner, "load_model", get_load_function(load_in_low_bit))
def get_load_function(low_bit):
def _ipex_llm_load_model(self) -> None:
_model_mlp_convert()
_model_attention_convert()
with measure_device_memory() as m:
# only support xpu for now
# We have to create a new DeviceConfig.
# Otherwise, we will get the wrong xpu memory usage
self.model = get_model(self.model_config,
DeviceConfig("cpu"),
lora_config=self.lora_config,
parallel_config=self.parallel_config,
scheduler_config=self.scheduler_config)
from ipex_llm import optimize_model
optimize_model(self.model, low_bit=low_bit, torch_dtype=self.model_config.dtype)
self.model = self.model.to(device=self.device_config.device,
dtype=self.model_config.dtype)
self.model_memory_usage = m.consumed_memory
logger = init_logger(__name__)
logger.info(f"Loading model weights took "
f"{self.model_memory_usage / float(2**30):.4f} GB")
if self.lora_config:
invalidInputError(hasattr(self.model, "supported_lora_modules")
and self.model.supported_lora_modules,
"Model does not support LoRA")
invalidInputError(hasattr(self.model, "embedding_modules"),
"Model does not have embedding_modules")
invalidInputError(hasattr(self.model, "embedding_padding_modules"),
"Model does not have embedding_padding_modules")
self.lora_manager = LRUCacheWorkerLoRAManager(
self.scheduler_config.max_num_seqs,
self.scheduler_config.max_num_batched_tokens +
self.scheduler_config.max_paddings, self.vocab_size,
self.lora_config, self.device, self.model.embedding_modules,
self.model.embedding_padding_modules)
self.model = self.lora_manager.create_lora_manager(self.model)
return _ipex_llm_load_model