diff --git a/python/llm/src/bigdl/llm/serving/README.md b/python/llm/src/bigdl/llm/serving/fastchat/README.md
similarity index 70%
rename from python/llm/src/bigdl/llm/serving/README.md
rename to python/llm/src/bigdl/llm/serving/fastchat/README.md
index e96e946a..c78b1179 100644
--- a/python/llm/src/bigdl/llm/serving/README.md
+++ b/python/llm/src/bigdl/llm/serving/fastchat/README.md
@@ -11,7 +11,8 @@ BigDL-LLM can be easily integrated into FastChat so that user can use `BigDL-LLM
 - [Start the service](#start-the-service)
   - [Launch controller](#launch-controller)
   - [Launch model worker(s) and load models](#launch-model-workers-and-load-models)
-    - [BigDL model worker](#bigdl-model-worker)
+    - [BigDL model worker (deprecated)](#bigdl-model-worker-deprecated)
+    - [BigDL-LLM worker](#bigdl-llm-worker)
     - [BigDL vLLM model worker](#vllm-model-worker)
   - [Launch Gradio web server](#launch-gradio-web-server)
   - [Launch RESTful API server](#launch-restful-api-server)
@@ -48,10 +49,14 @@ python3 -m fastchat.serve.controller
 
 ### Launch model worker(s) and load models
 
-#### BigDL model worker
-
 Using BigDL-LLM in FastChat does not impose any new limitations on model usage. Therefore, all Hugging Face Transformer models can be utilized in FastChat.
 
+#### BigDL model worker (deprecated)
+<details>
+<summary>details</summary>
+
+> Warning: This method has been deprecated. Please use the `BigDL-LLM` [worker](#bigdl-llm-worker) instead.
+
 FastChat determines the Model adapter to use through path matching. Therefore, in order to load models using BigDL-LLM, you need to make some modifications to the model's name.
 
 For instance, assuming you have downloaded the `llama-7b-hf` from [HuggingFace](https://huggingface.co/decapoda-research/llama-7b-hf). Then, to use the `BigDL-LLM` as backend, you need to change name from `llama-7b-hf` to `bigdl-7b`.The key point here is that the model's path should include "bigdl" and **should not include paths matched by other model adapters**.
@@ -66,10 +71,10 @@ Then we can run model workers
 
 ```bash
 # On CPU
-python3 -m bigdl.llm.serving.model_worker --model-path PATH/TO/bigdl-7b --device cpu
+python3 -m bigdl.llm.serving.fastchat.model_worker --model-path PATH/TO/bigdl-7b --device cpu
 
 # On GPU
-python3 -m bigdl.llm.serving.model_worker --model-path PATH/TO/bigdl-7b --device xpu
+python3 -m bigdl.llm.serving.fastchat.model_worker --model-path PATH/TO/bigdl-7b --device xpu
 ```
 
 If you run successfully using `BigDL` backend, you can see the output in log like this:
@@ -78,7 +83,28 @@ If you run successfully using `BigDL` backend, you can see the output in log lik
 INFO - Converting the current model to sym_int4 format......
 ```
 
-> note: We currently only support int4 quantization.
+> note: We currently only support int4 quantization for this method.
+</details>
+
+#### BigDL-LLM worker
+To integrate BigDL-LLM with `FastChat` efficiently, we have provided a new model_worker implementation named `bigdl_worker.py`.
+
+To run the `bigdl_worker` on CPU, use the following commands:
+```bash
+source bigdl-llm-init -t
+
+# Available low-bit formats include sym_int4, sym_int8, bf16, etc.
+python3 -m bigdl.llm.serving.fastchat.bigdl_worker --model-path lmsys/vicuna-7b-v1.5 --low-bit "sym_int4" --trust-remote-code --device "cpu"
+```
+
+
+To run on GPU, use the following command:
+```bash
+# Available low-bit formats include sym_int4, sym_int8, fp16, etc.
+python3 -m bigdl.llm.serving.fastchat.bigdl_worker --model-path lmsys/vicuna-7b-v1.5 --low-bit "sym_int4" --trust-remote-code --device "xpu"
+```
+
+For a full list of accepted arguments, refer to the main method of `bigdl_worker.py`.
 
 #### BigDL vLLM model worker
 
@@ -88,10 +114,10 @@ To run using the `vLLM_worker`, we don't need to change model name, just simply
 
 ```bash
 # On CPU
-python3 -m bigdl.llm.serving.vllm_worker --model-path REPO_ID_OR_YOUR_MODEL_PATH --device cpu
+python3 -m bigdl.llm.serving.fastchat.vllm_worker --model-path REPO_ID_OR_YOUR_MODEL_PATH --device cpu
 
 # On GPU
-python3 -m bigdl.llm.serving.vllm_worker --model-path REPO_ID_OR_YOUR_MODEL_PATH --device xpu
+python3 -m bigdl.llm.serving.fastchat.vllm_worker --model-path REPO_ID_OR_YOUR_MODEL_PATH --device xpu
 ```
 
 ### Launch Gradio web server
diff --git a/python/llm/src/bigdl/llm/serving/fastchat/__init__.py b/python/llm/src/bigdl/llm/serving/fastchat/__init__.py
new file mode 100644
index 00000000..2151a805
--- /dev/null
+++ b/python/llm/src/bigdl/llm/serving/fastchat/__init__.py
@@ -0,0 +1,15 @@
+#
+# Copyright 2016 The BigDL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
diff --git a/python/llm/src/bigdl/llm/serving/bigdl_llm_model.py b/python/llm/src/bigdl/llm/serving/fastchat/bigdl_llm_model.py
similarity index 100%
rename from python/llm/src/bigdl/llm/serving/bigdl_llm_model.py
rename to python/llm/src/bigdl/llm/serving/fastchat/bigdl_llm_model.py
diff --git a/python/llm/src/bigdl/llm/serving/fastchat/bigdl_worker.py b/python/llm/src/bigdl/llm/serving/fastchat/bigdl_worker.py
new file mode 100644
index 00000000..35c009f7
--- /dev/null
+++ b/python/llm/src/bigdl/llm/serving/fastchat/bigdl_worker.py
@@ -0,0 +1,337 @@
+#
+# Copyright 2016 The BigDL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""
+A model worker that executes the model based on BigDL-LLM.
+Relies on the load_model method from bigdl.llm.transformers.loader.
+"""
+
+import argparse
+import asyncio
+import atexit
+import json
+from typing import List
+import uuid
+from threading import Thread
+from fastapi import FastAPI, Request, BackgroundTasks
+from fastapi.concurrency import run_in_threadpool
+from fastapi.responses import StreamingResponse, JSONResponse
+import uvicorn
+
+from fastchat.serve.base_model_worker import BaseModelWorker
+from fastchat.serve.model_worker import (
+    logger,
+    worker_id,
+)
+from fastchat.serve.base_model_worker import (
+    create_background_tasks,
+    acquire_worker_semaphore,
+    release_worker_semaphore,
+)
+from fastchat.utils import get_context_length, is_partial_stop
+
+from bigdl.llm.transformers.loader import load_model
+from transformers import TextIteratorStreamer
+
+app = FastAPI()
+
+
+class BigDLLLMWorker(BaseModelWorker):
+    def __init__(
+        self,
+        controller_addr: str,
+        worker_addr: str,
+        worker_id: str,
+        model_path: str,
+        model_names: List[str],
+        limit_worker_concurrency: int,
+        conv_template: str = None,
+        load_in_low_bit: str = "sym_int4",
+        device: str = "cpu",
+        no_register: bool = False,
+        trust_remote_code: bool = False,
+        stream_interval: int = 4,
+    ):
+        super().__init__(
+            controller_addr,
+            worker_addr,
+            worker_id,
+            model_path,
+            model_names,
+            limit_worker_concurrency,
+            conv_template,
+        )
+
+        self.load_in_low_bit = load_in_low_bit
+        logger.info(
+            f"Loading the model {self.model_names} on worker {worker_id},"
+            f" worker type: BigDLLLM worker..."
+        )
+
+        logger.info(f"Using low bit format: {self.load_in_low_bit}, device: {device}")
+
+        self.device = device
+
+        self.model, self.tokenizer = load_model(
+            model_path, device, self.load_in_low_bit, trust_remote_code
+        )
+        self.stream_interval = stream_interval
+        self.context_len = get_context_length(self.model.config)
+        if not no_register:
+            self.init_heart_beat()
+
+    def generate_stream_gate(self, params):
+        self.call_ct += 1
+        # The maximum context length is stored in self.context_len
+        prompt = params["prompt"]
+        temperature = float(params.get("temperature", 1.0))
+        repetition_penalty = float(params.get("repetition_penalty", 1.0))
+        top_p = float(params.get("top_p", 1.0))
+        top_k = int(params.get("top_k", 0))  # 0 means disable
+        max_new_tokens = int(params.get("max_new_tokens", 256))
+        echo = bool(params.get("echo", True))
+        stop_str = params.get("stop", None)
+        stop_token_ids = params.get("stop_token_ids", None) or []
+        if self.tokenizer.eos_token_id not in stop_token_ids:
+            stop_token_ids.append(self.tokenizer.eos_token_id)
+
+        # Handle stop_str
+        stop = set()
+        if isinstance(stop_str, str) and stop_str != "":
+            stop.add(stop_str)
+        elif isinstance(stop_str, list) and stop_str != []:
+            stop.update(stop_str)
+
+        for tid in stop_token_ids:
+            if tid is not None:
+                s = self.tokenizer.decode(tid)
+                if s != "":
+                    stop.add(s)
+
+        input_ids = self.tokenizer.encode(prompt, return_tensors="pt")
+        if self.device == "xpu":
+            input_ids = input_ids.to("xpu")
+
+        input_echo_len = input_ids.shape[1]
+
+        if self.model.config.is_encoder_decoder:
+            max_src_len = self.context_len
+            input_ids = input_ids[:max_src_len]
+            input_echo_len = len(input_ids)
+            prompt = self.tokenizer.decode(
+                input_ids,
+                skip_special_tokens=True,
+                spaces_between_special_tokens=False,
+                clean_up_tokenization_spaces=True,
+            )
+        else:
+            # Truncate max_new_tokens if input_ids is too long
+            new_max_new_tokens = min(self.context_len - input_echo_len, max_new_tokens)
+            if new_max_new_tokens < max_new_tokens:
+                logger.info(
+                    f"Warning: max_new_tokens[{max_new_tokens}] + prompt[{input_echo_len}] greater "
+                    f"than context_length[{self.context_len}]"
+                )
+                logger.info(f"Reset max_new_tokens to {new_max_new_tokens}")
+                max_new_tokens = new_max_new_tokens
+
+        # Use TextIteratorStreamer for streaming output
+        streamer = TextIteratorStreamer(
+            tokenizer=self.tokenizer,
+            timeout=60,
+            skip_prompt=True,
+            skip_special_tokens=True,
+        )
+
+        # Generation config:
+        # https://huggingface.co/docs/transformers/main/en/main_classes/text_generation#transformers.GenerationConfig
+        generated_kwargs = dict(
+            max_new_tokens=max_new_tokens,
+            streamer=streamer,
+            temperature=temperature,
+            repetition_penalty=repetition_penalty,
+            top_p=top_p,
+            top_k=top_k,
+        )
+
+        def model_generate():
+            self.model.generate(input_ids, **generated_kwargs)
+
+        t1 = Thread(target=model_generate)
+        t1.start()
+
+        stopped = False
+        finish_reason = None
+        if echo:
+            partial_output = prompt
+            rfind_start = len(prompt)
+        else:
+            partial_output = ""
+            rfind_start = 0
+
+        for i in range(max_new_tokens):
+            try:
+                output_token = next(streamer)
+            except StopIteration:
+                # Stop early
+                stopped = True
+                break
+            partial_output += output_token
+
+            if i % self.stream_interval == 0 or i == max_new_tokens - 1 or stopped:
+                # Whether the current output ends with a partial stop string
+                partially_stopped = False
+                for each_stop in stop:
+                    pos = partial_output.rfind(each_stop, rfind_start)
+                    if pos != -1:
+                        partial_output = partial_output[:pos]
+                        stopped = True
+                        break
+                    else:
+                        partially_stopped = is_partial_stop(partial_output, each_stop)
+                        if partially_stopped:
+                            break
+                if not partially_stopped:
+                    json_output = {
+                        "text": partial_output,
+                        "usage": {
+                            "prompt_tokens": input_echo_len,
+                            "completion_tokens": i,
+                            "total_tokens": input_echo_len + i,
+                        },
+                        "finish_reason": None,
+                    }
+                    ret = {
+                        "text": json_output["text"],
+                        "error_code": 0,
+                    }
+                    ret["usage"] = json_output["usage"]
+                    ret["finish_reason"] = json_output["finish_reason"]
+                    yield json.dumps(ret).encode() + b"\0"
+
+            if stopped:
+                break
+        else:
+            finish_reason = "length"
+
+        if stopped:
+            finish_reason = "stop"
+        json_output = {
+            "text": partial_output,
+            "error_code": 0,
+            "usage": {
+                "prompt_tokens": input_echo_len,
+                "completion_tokens": i,
+                "total_tokens": input_echo_len + i,
+            },
+            "finish_reason": finish_reason,
+        }
+        yield json.dumps(json_output).encode() + b"\0"
+
+    def generate_gate(self, params):
+        for x in self.generate_stream_gate(params):
+            pass
+        return json.loads(x[:-1].decode())
+
+
+# Below are the API interfaces
+@app.post("/worker_generate_stream")
+async def api_generate_stream(request: Request):
+    params = await request.json()
+    await acquire_worker_semaphore()
+    generator = worker.generate_stream_gate(params)
+    background_tasks = create_background_tasks()
+    return StreamingResponse(generator, background=background_tasks)
+
+
+@app.post("/worker_generate")
+async def api_generate(request: Request):
+    params = await request.json()
+    await acquire_worker_semaphore()
+    output = await asyncio.to_thread(worker.generate_gate, params)
+    release_worker_semaphore()
+    return JSONResponse(output)
+
+
+@app.post("/worker_get_status")
+async def api_get_status(request: Request):
+    return worker.get_status()
+
+
+@app.post("/count_token")
+async def api_count_token(request: Request):
+    params = await request.json()
+    return worker.count_token(params)
+
+
+@app.post("/worker_get_conv_template")
+async def api_get_conv(request: Request):
+    return worker.get_conv_template()
+
+
+@app.post("/model_details")
+async def api_model_details(request: Request):
+    return {"context_length": worker.context_len}
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--host", type=str, default="localhost")
+    parser.add_argument("--port", type=int, default=21002)
+    parser.add_argument("--worker-address", type=str, default="http://localhost:21002")
+    parser.add_argument(
+        "--controller-address", type=str, default="http://localhost:21001"
+    )
+    parser.add_argument("--model-path", type=str, default="lmsys/vicuna-7b-v1.5")
+    parser.add_argument(
+        "--model-names",
+        type=lambda s: s.split(","),
+        help="Optional comma-separated list of display names.",
+    )
+    parser.add_argument("--limit-worker-concurrency", type=int, default=1024)
+    parser.add_argument("--no-register", action="store_true")
+    parser.add_argument(
+        "--conv-template", type=str, default=None, help="Conversation prompt template."
+    )
+    parser.add_argument("--stream-interval", type=int, default=4)
+    parser.add_argument(
+        "--low-bit", type=str, default="sym_int4", help="Low bit format."
+    )
+    parser.add_argument(
+        "--device", type=str, default="cpu", help="Device for executing model, cpu/xpu"
+    )
+    parser.add_argument(
+        "--trust-remote-code",
+        action="store_true",
+        default=False,
+        help="Trust remote code (e.g., from HuggingFace) when "
+        "downloading the model and tokenizer.",
+    )
+
+    args = parser.parse_args()
+    worker = BigDLLLMWorker(
+        args.controller_address,
+        args.worker_address,
+        worker_id,
+        args.model_path,
+        args.model_names,
+        args.limit_worker_concurrency,
+        args.conv_template,
+        args.low_bit,
+        args.device,
+        args.no_register,
+        args.trust_remote_code,
+        args.stream_interval,
+    )
+    uvicorn.run(app, host=args.host, port=args.port, log_level="info")
diff --git a/python/llm/src/bigdl/llm/serving/model_worker.py b/python/llm/src/bigdl/llm/serving/fastchat/model_worker.py
similarity index 100%
rename from python/llm/src/bigdl/llm/serving/model_worker.py
rename to python/llm/src/bigdl/llm/serving/fastchat/model_worker.py
diff --git a/python/llm/src/bigdl/llm/serving/vllm_worker.py b/python/llm/src/bigdl/llm/serving/fastchat/vllm_worker.py
similarity index 100%
rename from python/llm/src/bigdl/llm/serving/vllm_worker.py
rename to python/llm/src/bigdl/llm/serving/fastchat/vllm_worker.py
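For reference, below is a minimal client sketch for the HTTP endpoints that `bigdl_worker.py` defines above. It assumes the worker is running with the argparse defaults (`http://localhost:21002`); the prompt and sampling values are illustrative only, and the field names mirror the parameters read in `generate_stream_gate`.

```python
# Hypothetical client for the bigdl_worker endpoints shown above.
# Assumes the worker runs at the default address http://localhost:21002.
import json

import requests

WORKER_ADDR = "http://localhost:21002"

params = {
    "prompt": "What is AI?",
    "temperature": 0.7,
    "top_p": 1.0,
    "max_new_tokens": 64,
    "echo": False,
}

# Non-streaming: /worker_generate returns a single JSON payload.
resp = requests.post(f"{WORKER_ADDR}/worker_generate", json=params, timeout=120)
print(resp.json()["text"])

# Streaming: /worker_generate_stream yields NUL-delimited JSON chunks,
# each carrying the cumulative text generated so far.
with requests.post(
    f"{WORKER_ADDR}/worker_generate_stream", json=params, stream=True, timeout=120
) as resp:
    text = ""
    for chunk in resp.iter_lines(delimiter=b"\0"):
        if chunk:
            text = json.loads(chunk.decode())["text"]
    print(text)
```

In normal use these endpoints are called by the FastChat controller and API servers rather than directly, so the sketch is only meant to show the request and response shape of the new worker.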