diff --git a/python/llm/src/bigdl/llm/serving/README.md b/python/llm/src/bigdl/llm/serving/README.md
index e967f3da..4d0b3540 100644
--- a/python/llm/src/bigdl/llm/serving/README.md
+++ b/python/llm/src/bigdl/llm/serving/README.md
@@ -66,6 +66,13 @@ Wait until the process finishes loading the model and you see "Uvicorn running o
 
 > To run model worker using Intel GPU, simple change the --device cpu option to --device xpu
 
+We also provide the `vllm_worker`, which uses the [vLLM](https://github.com/intel-analytics/BigDL/tree/main/python/llm/example/CPU/vLLM-Serving) engine for better hardware utilization.
+
+To run using the `vllm_worker`, simply use the following command:
+```bash
+python3 -m bigdl.llm.serving.vllm_worker --model-path meta-llama/Llama-2-7b-chat-hf --device cpu/xpu # based on your device
+```
+
 ###### Launch the Gradio web server
 
 ```bash
diff --git a/python/llm/src/bigdl/llm/serving/vllm_worker.py b/python/llm/src/bigdl/llm/serving/vllm_worker.py
new file mode 100644
index 00000000..780ca2ae
--- /dev/null
+++ b/python/llm/src/bigdl/llm/serving/vllm_worker.py
@@ -0,0 +1,263 @@
+#
+# Copyright 2016 The BigDL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Modified from vllm_worker
+# https://github.com/lm-sys/FastChat/blob/v0.2.28/fastchat/serve/vllm_worker.py
+
+"""
+A model worker that executes the model based on vLLM.
+
+See documentation at docs/vllm_integration.md
+"""
+
+import argparse
+import asyncio
+import json
+from typing import List
+
+from fastapi import FastAPI, Request, BackgroundTasks
+from fastapi.responses import StreamingResponse, JSONResponse
+import torch
+import uvicorn
+# from vllm import AsyncLLMEngine
+# from vllm.engine.arg_utils import AsyncEngineArgs
+# from vllm.sampling_params import SamplingParams
+# from vllm.utils import random_uuid
+from bigdl.llm.vllm.engine.async_llm_engine import AsyncLLMEngine
+from bigdl.llm.vllm.engine.arg_utils import AsyncEngineArgs
+from bigdl.llm.vllm.sampling_params import SamplingParams
+from bigdl.llm.vllm.utils import random_uuid
+
+import numpy as np
+
+from fastchat.serve.model_worker import (
+    BaseModelWorker,
+    logger,
+    worker_id,
+)
+from fastchat.utils import get_context_length
+
+
+app = FastAPI()
+
+
+class VLLMWorker(BaseModelWorker):
+    def __init__(
+        self,
+        controller_addr: str,
+        worker_addr: str,
+        worker_id: str,
+        model_path: str,
+        model_names: List[str],
+        limit_worker_concurrency: int,
+        no_register: bool,
+        llm_engine: AsyncLLMEngine,
+        conv_template: str,
+    ):
+        super().__init__(
+            controller_addr,
+            worker_addr,
+            worker_id,
+            model_path,
+            model_names,
+            limit_worker_concurrency,
+            conv_template,
+        )
+
+        logger.info(
+            f"Loading the model {self.model_names} on worker {worker_id}, worker type: vLLM worker."
+        )
+        self.tokenizer = llm_engine.engine.tokenizer
+        self.context_len = get_context_length(llm_engine.engine.model_config.hf_config)
+
+        if not no_register:
+            self.init_heart_beat()
+
+    async def generate_stream(self, params):
+        self.call_ct += 1
+
+        context = params.pop("prompt")
+        request_id = params.pop("request_id")
+        temperature = float(params.get("temperature", 1.0))
+        top_p = float(params.get("top_p", 1.0))
+        max_new_tokens = params.get("max_new_tokens", 256)
+        stop_str = params.get("stop", None)
+        stop_token_ids = params.get("stop_token_ids", None) or []
+        if self.tokenizer.eos_token_id is not None:
+            stop_token_ids.append(self.tokenizer.eos_token_id)
+        echo = params.get("echo", True)
+        ignore_eos = params.get('ignore_eos', False)
+
+        # Handle stop_str
+        stop = set()
+        if isinstance(stop_str, str) and stop_str != "":
+            stop.add(stop_str)
+        elif isinstance(stop_str, list) and stop_str != []:
+            stop.update(stop_str)
+
+        for tid in stop_token_ids:
+            if tid is not None:
+                stop.add(self.tokenizer.decode(tid))
+
+        # make sampling params in vllm
+        top_p = max(top_p, 1e-5)
+        if temperature <= 1e-5:
+            top_p = 1.0
+        sampling_params = SamplingParams(
+            n=1,
+            temperature=temperature,
+            top_p=top_p,
+            use_beam_search=False,
+            stop=list(stop),
+            max_tokens=max_new_tokens,
+            ignore_eos=ignore_eos,
+        )
+        results_generator = engine.generate(context, sampling_params, request_id)
+
+        async for request_output in results_generator:
+            prompt = request_output.prompt
+            if echo:
+                text_outputs = [
+                    prompt + output.text for output in request_output.outputs
+                ]
+            else:
+                text_outputs = [output.text for output in request_output.outputs]
+            text_outputs = " ".join(text_outputs)
+            finish_reason = request_output.outputs[0].finish_reason
+            output_token_latency = request_output.outputs[0].output_token_latency
+            first_token_latency = output_token_latency[0]
+            if len(output_token_latency) > 1:
+                rest_token_time = np.mean(output_token_latency[1:])
+            else:
+                rest_token_time = None
+            # Note: usage is not supported yet
+            ret = {"text": text_outputs, "error_code": 0, "usage": {},
+                   "finish_reason": finish_reason, "first_token_time": first_token_latency,
+                   "rest_token_time": rest_token_time}
+            yield (json.dumps(ret) + "\0").encode()
+
+    async def generate(self, params):
+        async for x in self.generate_stream(params):
+            pass
+        return json.loads(x[:-1].decode())
+
+
+def release_worker_semaphore():
+    worker.semaphore.release()
+
+
+def acquire_worker_semaphore():
+    if worker.semaphore is None:
+        worker.semaphore = asyncio.Semaphore(worker.limit_worker_concurrency)
+    return worker.semaphore.acquire()
+
+
+def create_background_tasks(request_id):
+    async def abort_request() -> None:
+        await engine.abort(request_id)
+
+    background_tasks = BackgroundTasks()
+    background_tasks.add_task(release_worker_semaphore)
+    background_tasks.add_task(abort_request)
+    return background_tasks
+
+
+@app.post("/worker_generate_stream")
+async def api_generate_stream(request: Request):
+    params = await request.json()
+    await acquire_worker_semaphore()
+    request_id = random_uuid()
+    params["request_id"] = request_id
+    generator = worker.generate_stream(params)
+    background_tasks = create_background_tasks(request_id)
+    return StreamingResponse(generator, background=background_tasks)
+
+
+@app.post("/worker_generate")
+async def api_generate(request: Request):
+    params = await request.json()
+    await acquire_worker_semaphore()
+    request_id = random_uuid()
+    params["request_id"] = request_id
+    output = await worker.generate(params)
+    release_worker_semaphore()
+    await engine.abort(request_id)
+    return JSONResponse(output)
+
+
+@app.post("/worker_get_status")
+async def api_get_status(request: Request):
+    return worker.get_status()
+
+
+@app.post("/count_token")
+async def api_count_token(request: Request):
+    params = await request.json()
+    return worker.count_token(params)
+
+
+@app.post("/worker_get_conv_template")
+async def api_get_conv(request: Request):
+    return worker.get_conv_template()
+
+
+@app.post("/model_details")
+async def api_model_details(request: Request):
+    return {"context_length": worker.context_len}
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--host", type=str, default="localhost")
+    parser.add_argument("--port", type=int, default=21002)
+    parser.add_argument("--worker-address", type=str, default="http://localhost:21002")
+    parser.add_argument(
+        "--controller-address", type=str, default="http://localhost:21001"
+    )
+    parser.add_argument("--model-path", type=str, default="lmsys/vicuna-7b-v1.3")
+    parser.add_argument(
+        "--model-names",
+        type=lambda s: s.split(","),
+        help="Optional display names, comma separated",
+    )
+    parser.add_argument("--limit-worker-concurrency", type=int, default=1024)
+    parser.add_argument("--no-register", action="store_true")
+    parser.add_argument("--num-gpus", type=int, default=1)
+    parser.add_argument(
+        "--conv-template", type=str, default=None, help="Conversation prompt template."
+    )
+
+    parser = AsyncEngineArgs.add_cli_args(parser)
+    args = parser.parse_args()
+    if args.model_path:
+        args.model = args.model_path
+    # if args.num_gpus > 1:
+    #     args.tensor_parallel_size = args.num_gpus
+
+    # By default, we are creating a CPU AsyncEngineArgs.
+    engine_args = AsyncEngineArgs.from_cli_args(args)
+    engine = AsyncLLMEngine.from_engine_args(engine_args)
+    worker = VLLMWorker(
+        args.controller_address,
+        args.worker_address,
+        worker_id,
+        args.model_path,
+        args.model_names,
+        args.limit_worker_concurrency,
+        args.no_register,
+        engine,
+        args.conv_template,
+    )
+    uvicorn.run(app, host=args.host, port=args.port, log_level="info")
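
For a quick sanity check after applying the patch, the worker's streaming endpoint can be exercised directly. The snippet below is an illustrative sketch, not part of the change: it assumes a `vllm_worker` is already running on the default address (`localhost:21002`, per the argparse defaults above) and that the `requests` package is installed; the payload keys mirror the parameters read in `generate_stream`, and the prompt string is a made-up example.

```python
# Illustrative client sketch (not part of the patch). Assumes the vllm_worker
# is listening on the default host/port from the argparse defaults above.
import json

import requests

payload = {
    "prompt": "What is machine learning?",  # hypothetical prompt, for demonstration only
    "temperature": 0.7,
    "top_p": 1.0,
    "max_new_tokens": 128,
    "stop": None,
    "echo": False,
}

# generate_stream() yields JSON messages terminated by a null byte ("\0"),
# so the streamed response can be consumed by splitting on that delimiter.
response = requests.post(
    "http://localhost:21002/worker_generate_stream",
    json=payload,
    stream=True,
)
for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
    if chunk:
        data = json.loads(chunk.decode("utf-8"))
        print(data["text"])
```

In a full FastChat deployment, requests would normally reach the worker through the controller and the Gradio web server or the OpenAI-compatible API server; calling the endpoint directly like this is only a convenient way to confirm the `vllm_worker` is serving.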