From 0ea842231e881401a909ae5618c3334a4e59b65d Mon Sep 17 00:00:00 2001
From: Xiangyu Tian <109123695+xiangyuT@users.noreply.github.com>
Date: Tue, 26 Dec 2023 16:03:57 +0800
Subject: [PATCH] [LLM] vLLM: Add api_server entrypoint (#9783)

Add bigdl.llm.vllm.entrypoints.api_server so that vLLM's
benchmark_serving.py can be run against it.
---
 .../bigdl/llm/vllm/entrypoints/api_server.py | 123 ++++++++++++++++++
 1 file changed, 123 insertions(+)
 create mode 100644 python/llm/src/bigdl/llm/vllm/entrypoints/api_server.py

diff --git a/python/llm/src/bigdl/llm/vllm/entrypoints/api_server.py b/python/llm/src/bigdl/llm/vllm/entrypoints/api_server.py
new file mode 100644
index 00000000..6dcdf7ec
--- /dev/null
+++ b/python/llm/src/bigdl/llm/vllm/entrypoints/api_server.py
@@ -0,0 +1,123 @@
+#
+# Copyright 2016 The BigDL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Some parts of this file are adapted from
+# https://github.com/vllm-project/vllm/blob/v0.2.1.post1/vllm/entrypoints/api_server.py
+# which is licensed under Apache License 2.0
+#
+# Copyright 2023 The vLLM team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import json
+from typing import AsyncGenerator
+
+from fastapi import FastAPI, Request
+from fastapi.responses import JSONResponse, Response, StreamingResponse
+import uvicorn
+
+from bigdl.llm.vllm.engine.arg_utils import AsyncEngineArgs
+from bigdl.llm.vllm.engine.async_llm_engine import AsyncLLMEngine
+from bigdl.llm.vllm.sampling_params import SamplingParams
+from bigdl.llm.vllm.utils import random_uuid
+
+TIMEOUT_KEEP_ALIVE = 5  # seconds.
+TIMEOUT_TO_PREVENT_DEADLOCK = 1  # seconds.
+app = FastAPI()
+engine = None
+
+
+@app.get("/health")
+async def health() -> Response:
+    """Health check."""
+    return Response(status_code=200)
+
+
+@app.post("/generate")
+async def generate(request: Request) -> Response:
+    """Generate completion for the request.
+
+    The request should be a JSON object with the following fields:
+    - prompt: the prompt to use for the generation.
+    - stream: whether to stream the results or not.
+    - other fields: the sampling parameters (See `SamplingParams` for details).
+    """
+    request_dict = await request.json()
+    prompt = request_dict.pop("prompt")
+    stream = request_dict.pop("stream", False)
+    sampling_params = SamplingParams(**request_dict)
+    request_id = random_uuid()
+
+    results_generator = engine.generate(prompt, sampling_params, request_id)
+
+    # Streaming case
+    async def stream_results() -> AsyncGenerator[bytes, None]:
+        async for request_output in results_generator:
+            prompt = request_output.prompt
+            text_outputs = [
+                prompt + output.text for output in request_output.outputs
+            ]
+            ret = {"text": text_outputs}
+            yield (json.dumps(ret) + "\0").encode("utf-8")
+
+    if stream:
+        return StreamingResponse(stream_results())
+
+    # Non-streaming case
+    final_output = None
+    async for request_output in results_generator:
+        if await request.is_disconnected():
+            # Abort the request if the client disconnects.
+            await engine.abort(request_id)
+            return Response(status_code=499)
+        final_output = request_output
+
+    assert final_output is not None  # noqa
+    prompt = final_output.prompt
+    text_outputs = [prompt + output.text for output in final_output.outputs]
+    ret = {"text": text_outputs}
+    return JSONResponse(ret)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--host", type=str, default=None)
+    parser.add_argument("--port", type=int, default=8000)
+    parser.add_argument("--ssl-keyfile", type=str, default=None)
+    parser.add_argument("--ssl-certfile", type=str, default=None)
+    parser = AsyncEngineArgs.add_cli_args(parser)
+    args = parser.parse_args()
+
+    engine_args = AsyncEngineArgs.from_cli_args(args)
+    engine = AsyncLLMEngine.from_engine_args(engine_args)
+
+    uvicorn.run(app,
+                host=args.host,
+                port=args.port,
+                log_level="debug",
+                timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
+                ssl_keyfile=args.ssl_keyfile,
+                ssl_certfile=args.ssl_certfile)
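
For reference, a minimal client sketch against the /generate endpoint added by
this patch (not part of the patch itself). It assumes the server has been started
locally, e.g. `python -m bigdl.llm.vllm.entrypoints.api_server --port 8000` plus
whatever engine flags AsyncEngineArgs.add_cli_args exposes (in upstream vLLM this
includes `--model`; the exact flags here may differ), and that the adapted
SamplingParams accepts the usual vLLM fields such as `max_tokens` and `temperature`:

    import json

    import requests  # third-party HTTP client: pip install requests

    payload = {
        "prompt": "What is AI?",
        "stream": False,
        # Remaining fields are forwarded to SamplingParams(**request_dict);
        # these names follow the vLLM convention and are assumptions here.
        "max_tokens": 64,
        "temperature": 0.8,
    }

    # Non-streaming request: the server replies with {"text": ["<prompt + completion>", ...]}.
    response = requests.post("http://localhost:8000/generate", json=payload)
    response.raise_for_status()
    print(json.dumps(response.json(), indent=2))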