[Serving] Add vllm_worker to fastchat serving framework (#9934)
* add worker * finish * finish * add license * add more comments
This commit is contained in:
parent
a8c866c32b
commit
2e1448f08e
2 changed files with 270 additions and 0 deletions
|
|
@ -66,6 +66,13 @@ Wait until the process finishes loading the model and you see "Uvicorn running o
|
|||
|
||||
> To run model worker using Intel GPU, simple change the --device cpu option to --device xpu
|
||||
|
||||
We also provide the `vllm_worker` which uses the [vLLM](https://github.com/intel-analytics/BigDL/tree/main/python/llm/example/CPU/vLLM-Serving) engine for better hardware utilization.
|
||||
|
||||
To run using the `vllm_worker`, just simply uses the following command:
|
||||
```bash
|
||||
python3 -m bigdl.llm.serving.vllm_worker --model-path meta-llama/Llama-2-7b-chat-hf --device cpu/xpu # based on your device
|
||||
```
|
||||
|
||||
###### Launch the Gradio web server
|
||||
|
||||
```bash
|
||||
|
|
|
|||
263
python/llm/src/bigdl/llm/serving/vllm_worker.py
Normal file
263
python/llm/src/bigdl/llm/serving/vllm_worker.py
Normal file
|
|
@ -0,0 +1,263 @@
|
|||
#
|
||||
# Copyright 2016 The BigDL Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
# Modified from vllm_worker
|
||||
# https://github.com/lm-sys/FastChat/blob/v0.2.28/fastchat/serve/vllm_worker.py
|
||||
|
||||
"""
|
||||
A model worker that executes the model based on vLLM.
|
||||
|
||||
See documentations at docs/vllm_integration.md
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
from typing import List
|
||||
|
||||
from fastapi import FastAPI, Request, BackgroundTasks
|
||||
from fastapi.responses import StreamingResponse, JSONResponse
|
||||
import torch
|
||||
import uvicorn
|
||||
# from vllm import AsyncLLMEngine
|
||||
# from vllm.engine.arg_utils import AsyncEngineArgs
|
||||
# from vllm.sampling_params import SamplingParams
|
||||
# from vllm.utils import random_uuid
|
||||
from bigdl.llm.vllm.engine.async_llm_engine import AsyncLLMEngine
|
||||
from bigdl.llm.vllm.engine.arg_utils import AsyncEngineArgs
|
||||
from bigdl.llm.vllm.sampling_params import SamplingParams
|
||||
from bigdl.llm.vllm.utils import random_uuid
|
||||
|
||||
import numpy as np
|
||||
|
||||
from fastchat.serve.model_worker import (
|
||||
BaseModelWorker,
|
||||
logger,
|
||||
worker_id,
|
||||
)
|
||||
from fastchat.utils import get_context_length
|
||||
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
|
||||
class VLLMWorker(BaseModelWorker):
|
||||
def __init__(
|
||||
self,
|
||||
controller_addr: str,
|
||||
worker_addr: str,
|
||||
worker_id: str,
|
||||
model_path: str,
|
||||
model_names: List[str],
|
||||
limit_worker_concurrency: int,
|
||||
no_register: bool,
|
||||
llm_engine: AsyncLLMEngine,
|
||||
conv_template: str,
|
||||
):
|
||||
super().__init__(
|
||||
controller_addr,
|
||||
worker_addr,
|
||||
worker_id,
|
||||
model_path,
|
||||
model_names,
|
||||
limit_worker_concurrency,
|
||||
conv_template,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Loading the model {self.model_names} on worker {worker_id}, worker type: vLLM worker."
|
||||
)
|
||||
self.tokenizer = llm_engine.engine.tokenizer
|
||||
self.context_len = get_context_length(llm_engine.engine.model_config.hf_config)
|
||||
|
||||
if not no_register:
|
||||
self.init_heart_beat()
|
||||
|
||||
async def generate_stream(self, params):
|
||||
self.call_ct += 1
|
||||
|
||||
context = params.pop("prompt")
|
||||
request_id = params.pop("request_id")
|
||||
temperature = float(params.get("temperature", 1.0))
|
||||
top_p = float(params.get("top_p", 1.0))
|
||||
max_new_tokens = params.get("max_new_tokens", 256)
|
||||
stop_str = params.get("stop", None)
|
||||
stop_token_ids = params.get("stop_token_ids", None) or []
|
||||
if self.tokenizer.eos_token_id is not None:
|
||||
stop_token_ids.append(self.tokenizer.eos_token_id)
|
||||
echo = params.get("echo", True)
|
||||
ignore_eos = params.get('ignore_eos', False)
|
||||
|
||||
# Handle stop_str
|
||||
stop = set()
|
||||
if isinstance(stop_str, str) and stop_str != "":
|
||||
stop.add(stop_str)
|
||||
elif isinstance(stop_str, list) and stop_str != []:
|
||||
stop.update(stop_str)
|
||||
|
||||
for tid in stop_token_ids:
|
||||
if tid is not None:
|
||||
stop.add(self.tokenizer.decode(tid))
|
||||
|
||||
# make sampling params in vllm
|
||||
top_p = max(top_p, 1e-5)
|
||||
if temperature <= 1e-5:
|
||||
top_p = 1.0
|
||||
sampling_params = SamplingParams(
|
||||
n=1,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
use_beam_search=False,
|
||||
stop=list(stop),
|
||||
max_tokens=max_new_tokens,
|
||||
ignore_eos=ignore_eos,
|
||||
)
|
||||
results_generator = engine.generate(context, sampling_params, request_id)
|
||||
|
||||
async for request_output in results_generator:
|
||||
prompt = request_output.prompt
|
||||
if echo:
|
||||
text_outputs = [
|
||||
prompt + output.text for output in request_output.outputs
|
||||
]
|
||||
else:
|
||||
text_outputs = [output.text for output in request_output.outputs]
|
||||
text_outputs = " ".join(text_outputs)
|
||||
finish_reason = request_output.outputs[0].finish_reason
|
||||
output_token_latency = request_output.outputs[0].output_token_latency
|
||||
first_token_latency = output_token_latency[0]
|
||||
if len(output_token_latency) > 1:
|
||||
rest_token_time = np.mean(output_token_latency[1:])
|
||||
else:
|
||||
rest_token_time = None
|
||||
# Note: usage is not supported yet
|
||||
ret = {"text": text_outputs, "error_code": 0, "usage": {},
|
||||
"finish_reason": finish_reason, "first_token_time": first_token_latency,
|
||||
"rest_token_time": rest_token_time}
|
||||
yield (json.dumps(ret) + "\0").encode()
|
||||
|
||||
async def generate(self, params):
|
||||
async for x in self.generate_stream(params):
|
||||
pass
|
||||
return json.loads(x[:-1].decode())
|
||||
|
||||
|
||||
def release_worker_semaphore():
|
||||
worker.semaphore.release()
|
||||
|
||||
|
||||
def acquire_worker_semaphore():
|
||||
if worker.semaphore is None:
|
||||
worker.semaphore = asyncio.Semaphore(worker.limit_worker_concurrency)
|
||||
return worker.semaphore.acquire()
|
||||
|
||||
|
||||
def create_background_tasks(request_id):
|
||||
async def abort_request() -> None:
|
||||
await engine.abort(request_id)
|
||||
|
||||
background_tasks = BackgroundTasks()
|
||||
background_tasks.add_task(release_worker_semaphore)
|
||||
background_tasks.add_task(abort_request)
|
||||
return background_tasks
|
||||
|
||||
|
||||
@app.post("/worker_generate_stream")
|
||||
async def api_generate_stream(request: Request):
|
||||
params = await request.json()
|
||||
await acquire_worker_semaphore()
|
||||
request_id = random_uuid()
|
||||
params["request_id"] = request_id
|
||||
generator = worker.generate_stream(params)
|
||||
background_tasks = create_background_tasks(request_id)
|
||||
return StreamingResponse(generator, background=background_tasks)
|
||||
|
||||
|
||||
@app.post("/worker_generate")
|
||||
async def api_generate(request: Request):
|
||||
params = await request.json()
|
||||
await acquire_worker_semaphore()
|
||||
request_id = random_uuid()
|
||||
params["request_id"] = request_id
|
||||
output = await worker.generate(params)
|
||||
release_worker_semaphore()
|
||||
await engine.abort(request_id)
|
||||
return JSONResponse(output)
|
||||
|
||||
|
||||
@app.post("/worker_get_status")
|
||||
async def api_get_status(request: Request):
|
||||
return worker.get_status()
|
||||
|
||||
|
||||
@app.post("/count_token")
|
||||
async def api_count_token(request: Request):
|
||||
params = await request.json()
|
||||
return worker.count_token(params)
|
||||
|
||||
|
||||
@app.post("/worker_get_conv_template")
|
||||
async def api_get_conv(request: Request):
|
||||
return worker.get_conv_template()
|
||||
|
||||
|
||||
@app.post("/model_details")
|
||||
async def api_model_details(request: Request):
|
||||
return {"context_length": worker.context_len}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--host", type=str, default="localhost")
|
||||
parser.add_argument("--port", type=int, default=21002)
|
||||
parser.add_argument("--worker-address", type=str, default="http://localhost:21002")
|
||||
parser.add_argument(
|
||||
"--controller-address", type=str, default="http://localhost:21001"
|
||||
)
|
||||
parser.add_argument("--model-path", type=str, default="lmsys/vicuna-7b-v1.3")
|
||||
parser.add_argument(
|
||||
"--model-names",
|
||||
type=lambda s: s.split(","),
|
||||
help="Optional display comma separated names",
|
||||
)
|
||||
parser.add_argument("--limit-worker-concurrency", type=int, default=1024)
|
||||
parser.add_argument("--no-register", action="store_true")
|
||||
parser.add_argument("--num-gpus", type=int, default=1)
|
||||
parser.add_argument(
|
||||
"--conv-template", type=str, default=None, help="Conversation prompt template."
|
||||
)
|
||||
|
||||
parser = AsyncEngineArgs.add_cli_args(parser)
|
||||
args = parser.parse_args()
|
||||
if args.model_path:
|
||||
args.model = args.model_path
|
||||
# if args.num_gpus > 1:
|
||||
# args.tensor_parallel_size = args.num_gpus
|
||||
|
||||
# By default, we are creating a CPU asyncEngineArgs.
|
||||
engine_args = AsyncEngineArgs.from_cli_args(args)
|
||||
engine = AsyncLLMEngine.from_engine_args(engine_args)
|
||||
worker = VLLMWorker(
|
||||
args.controller_address,
|
||||
args.worker_address,
|
||||
worker_id,
|
||||
args.model_path,
|
||||
args.model_names,
|
||||
args.limit_worker_concurrency,
|
||||
args.no_register,
|
||||
engine,
|
||||
args.conv_template,
|
||||
)
|
||||
uvicorn.run(app, host=args.host, port=args.port, log_level="info")
|
||||
Loading…
Reference in a new issue