Add FastChat bigdl_worker (#10493)

* done

* fix format

* add licence

* done

* fix doc

* refactor folder

* add license
Guancheng Fu 2024-03-21 18:35:05 +08:00 committed by GitHub
parent 8d0ea1b9b3
commit 3a3756b51d
6 changed files with 386 additions and 8 deletions


@@ -11,7 +11,8 @@ BigDL-LLM can be easily integrated into FastChat so that user can use `BigDL-LLM`
- [Start the service](#start-the-service)
- [Launch controller](#launch-controller)
- [Launch model worker(s) and load models](#launch-model-workers-and-load-models)
- [BigDL model worker (deprecated)](#bigdl-model-worker-deprecated)
- [BigDL-LLM worker](#bigdl-llm-worker)
- [BigDL vLLM model worker](#vllm-model-worker)
- [Launch Gradio web server](#launch-gradio-web-server)
- [Launch RESTful API server](#launch-restful-api-server)
@@ -48,10 +49,14 @@ python3 -m fastchat.serve.controller

### Launch model worker(s) and load models

Using BigDL-LLM in FastChat does not impose any new limitations on model usage. Therefore, all Hugging Face Transformer models can be utilized in FastChat.

#### BigDL model worker (deprecated)
<details>
<summary>details</summary>

> Warning: This method has been deprecated. Please use the `BigDL-LLM` [worker](#bigdl-llm-worker) instead.

FastChat determines the model adapter to use through path matching, so to load models with BigDL-LLM you need to make some modifications to the model's name.

For instance, assuming you have downloaded `llama-7b-hf` from [HuggingFace](https://huggingface.co/decapoda-research/llama-7b-hf), to use `BigDL-LLM` as the backend you need to rename `llama-7b-hf` to `bigdl-7b`. The key point is that the model's path should include "bigdl" and **should not include paths matched by other model adapters**.
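
As a minimal sketch of that rename (the local directory path below is hypothetical; adjust it to wherever the checkpoint was downloaded):

```python
from pathlib import Path

# Hypothetical download location of the llama-7b-hf checkpoint.
model_dir = Path("~/models/llama-7b-hf").expanduser()

# Renaming the folder so its path contains "bigdl" lets FastChat's path matching
# pick the BigDL-LLM adapter for this model.
model_dir.rename(model_dir.with_name("bigdl-7b"))
```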
@@ -66,10 +71,10 @@ Then we can run model workers

```bash
# On CPU
python3 -m bigdl.llm.serving.fastchat.model_worker --model-path PATH/TO/bigdl-7b --device cpu

# On GPU
python3 -m bigdl.llm.serving.fastchat.model_worker --model-path PATH/TO/bigdl-7b --device xpu
```

If you run successfully using the `BigDL` backend, you will see output in the log like this:
@@ -78,7 +83,28 @@ If you run successfully using `BigDL` backend, you can see the output in log like this:

```
INFO - Converting the current model to sym_int4 format......
```

> note: We currently only support int4 quantization for this method.

</details>
#### BigDL-LLM worker
To integrate BigDL-LLM with `FastChat` efficiently, we provide a new model worker implementation named `bigdl_worker.py`.

To run the `bigdl_worker` on CPU, use the following commands:
```bash
source bigdl-llm-init -t
# Available low-bit formats include sym_int4, sym_int8, bf16, etc.
python3 -m bigdl.llm.serving.fastchat.bigdl_worker --model-path lmsys/vicuna-7b-v1.5 --low-bit "sym_int4" --trust-remote-code --device "cpu"
```

To run the `bigdl_worker` on GPU, use the following command:
```bash
# Available low-bit formats include sym_int4, sym_int8, fp16, etc.
python3 -m bigdl.llm.serving.fastchat.bigdl_worker --model-path lmsys/vicuna-7b-v1.5 --low-bit "sym_int4" --trust-remote-code --device "xpu"
```

For a full list of accepted arguments, refer to the main method of `bigdl_worker.py`.
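
Once a worker is running, you can sanity-check it directly against its streaming endpoint. The sketch below assumes the worker is listening on the default `localhost:21002` and that the `requests` package is installed:

```python
import json

import requests

# Ask the locally running bigdl_worker for a short streamed completion.
response = requests.post(
    "http://localhost:21002/worker_generate_stream",
    json={"prompt": "What is AI?", "temperature": 0.7, "max_new_tokens": 64, "echo": False},
    stream=True,
)

# The worker streams JSON chunks separated by null bytes; each chunk carries the text so far.
for chunk in response.iter_lines(delimiter=b"\0"):
    if chunk:
        data = json.loads(chunk.decode("utf-8"))
        print(data["text"])
```
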
#### BigDL vLLM model worker

@@ -88,10 +114,10 @@ To run using the `vLLM_worker`, we don't need to change model name, just simply
```bash
# On CPU
python3 -m bigdl.llm.serving.fastchat.vllm_worker --model-path REPO_ID_OR_YOUR_MODEL_PATH --device cpu

# On GPU
python3 -m bigdl.llm.serving.fastchat.vllm_worker --model-path REPO_ID_OR_YOUR_MODEL_PATH --device xpu
```

### Launch Gradio web server


@@ -0,0 +1,15 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#


@@ -0,0 +1,337 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
A model worker that executes the model based on BigDL-LLM.
Relies on the load_model method from bigdl.llm.transformers.loader.
"""
import argparse
import asyncio
import atexit
import json
from typing import List
import uuid
from threading import Thread
from fastapi import FastAPI, Request, BackgroundTasks
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import StreamingResponse, JSONResponse
import uvicorn
from fastchat.serve.base_model_worker import BaseModelWorker
from fastchat.serve.model_worker import (
logger,
worker_id,
)
from fastchat.serve.base_model_worker import (
create_background_tasks,
acquire_worker_semaphore,
release_worker_semaphore,
)
from fastchat.utils import get_context_length, is_partial_stop
from bigdl.llm.transformers.loader import load_model
from transformers import TextIteratorStreamer

app = FastAPI()


class BigDLLLMWorker(BaseModelWorker):
def __init__(
self,
controller_addr: str,
worker_addr: str,
worker_id: str,
model_path: str,
model_names: List[str],
limit_worker_concurrency: int,
conv_template: str = None,
load_in_low_bit: str = "sym_int4",
device: str = "cpu",
no_register: bool = False,
trust_remote_code: bool = False,
stream_interval: int = 4,
):
super().__init__(
controller_addr,
worker_addr,
worker_id,
model_path,
model_names,
limit_worker_concurrency,
conv_template,
)
self.load_in_low_bit = load_in_low_bit
logger.info(
f"Loading the model {self.model_names} on worker {worker_id},"
f" worker type: BigDLLLM worker..."
)
logger.info(f"Using low bit format: {self.load_in_low_bit}, device: {device}")
self.device = device
self.model, self.tokenizer = load_model(
model_path, device, self.load_in_low_bit, trust_remote_code
)
self.stream_interval = stream_interval
self.context_len = get_context_length(self.model.config)
if not no_register:
self.init_heart_beat()
def generate_stream_gate(self, params):
self.call_ct += 1
# context length is self.context_len
prompt = params["prompt"]
temperature = float(params.get("temperature", 1.0))
repetition_penalty = float(params.get("repetition_penalty", 1.0))
top_p = float(params.get("top_p", 1.0))
top_k = int(params.get("top_k", 0)) # 0 means disable
max_new_tokens = int(params.get("max_new_tokens", 256))
echo = bool(params.get("echo", True))
stop_str = params.get("stop", None)
stop_token_ids = params.get("stop_token_ids", None) or []
if self.tokenizer.eos_token_id not in stop_token_ids:
stop_token_ids.append(self.tokenizer.eos_token_id)
# Handle stop_str
stop = set()
if isinstance(stop_str, str) and stop_str != "":
stop.add(stop_str)
elif isinstance(stop_str, list) and stop_str != []:
stop.update(stop_str)
for tid in stop_token_ids:
if tid is not None:
s = self.tokenizer.decode(tid)
if s != "":
stop.add(s)
input_ids = self.tokenizer.encode(prompt, return_tensors="pt")
if self.device == "xpu":
input_ids = input_ids.to("xpu")
input_echo_len = input_ids.shape[1]
if self.model.config.is_encoder_decoder:
max_src_len = self.context_len
input_ids = input_ids[:max_src_len]
input_echo_len = len(input_ids)
prompt = self.tokenizer.decode(
input_ids,
skip_special_tokens=True,
spaces_between_special_tokens=False,
clean_up_tokenization_spaces=True,
)
else:
# Truncate the max_new_tokens if input_ids is too long
new_max_new_tokens = min(self.context_len - input_echo_len, max_new_tokens)
if new_max_new_tokens < max_new_tokens:
logger.info(
f"Warning: max_new_tokens[{max_new_tokens}] + prompt[{input_echo_len}] greater "
f"than context_length[{self.context_len}]"
)
logger.info(f"Reset max_new_tokens to {new_max_new_tokens}")
max_new_tokens = new_max_new_tokens
# Use TextIteratorStreamer for streaming output
streamer = TextIteratorStreamer(
tokenizer=self.tokenizer,
timeout=60,
skip_prompt=True,
skip_special_tokens=True,
)
# Generation config:
# https://huggingface.co/docs/transformers/main/en/main_classes/text_generation#transformers.GenerationConfig
generated_kwargs = dict(
max_new_tokens=max_new_tokens,
streamer=streamer,
temperature=temperature,
repetition_penalty=repetition_penalty,
top_p=top_p,
top_k=top_k,
)
def model_generate():
self.model.generate(input_ids, **generated_kwargs)
t1 = Thread(target=model_generate)
t1.start()
stopped = False
finish_reason = None
if echo:
partial_output = prompt
rfind_start = len(prompt)
else:
partial_output = ""
rfind_start = 0
for i in range(max_new_tokens):
try:
output_token = next(streamer)
except StopIteration:
# Stop early
stopped = True
break
partial_output += output_token
if i % self.stream_interval == 0 or i == max_new_tokens - 1 or stopped:
for each_stop in stop:
pos = partial_output.rfind(each_stop, rfind_start)
if pos != -1:
partial_output = partial_output[:pos]
stopped = True
break
else:
partially_stopped = is_partial_stop(partial_output, each_stop)
if partially_stopped:
break
if not partially_stopped:
json_output = {
"text": partial_output,
"usage": {
"prompt_tokens": input_echo_len,
"completion_tokens": i,
"total_tokens": input_echo_len + i,
},
"finish_reason": None,
}
ret = {
"text": json_output["text"],
"error_code": 0,
}
ret["usage"] = json_output["usage"]
ret["finish_reason"] = json_output["finish_reason"]
yield json.dumps(ret).encode() + b"\0"
if stopped:
break
else:
finish_reason = "length"
if stopped:
finish_reason = "stop"
json_output = {
"text": partial_output,
"error_code": 0,
"usage": {
"prompt_tokens": input_echo_len,
"completion_tokens": i,
"total_tokens": input_echo_len + i,
},
"finish_reason": finish_reason,
}
yield json.dumps(json_output).encode() + b"\0"
def generate_gate(self, params):
for x in self.generate_stream_gate(params):
# for x in self.generate_stream2(params):
pass
return json.loads(x[:-1].decode())


# Below are api interfaces
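# Streamed responses are JSON chunks separated by null bytes (b"\0"), in the format
# produced by BigDLLLMWorker.generate_stream_gate above.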
@app.post("/worker_generate_stream")
async def api_generate_stream(request: Request):
params = await request.json()
await acquire_worker_semaphore()
generator = worker.generate_stream_gate(params)
background_tasks = create_background_tasks()
return StreamingResponse(generator, background=background_tasks)
@app.post("/worker_generate")
async def api_generate(request: Request):
params = await request.json()
await acquire_worker_semaphore()
output = await asyncio.to_thread(worker.generate_gate, params)
release_worker_semaphore()
return JSONResponse(output)
@app.post("/worker_get_status")
async def api_get_status(request: Request):
return worker.get_status()
@app.post("/count_token")
async def api_count_token(request: Request):
params = await request.json()
return worker.count_token(params)
@app.post("/worker_get_conv_template")
async def api_get_conv(request: Request):
return worker.get_conv_template()
@app.post("/model_details")
async def api_model_details(request: Request):
return {"context_length": worker.context_len}


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default="localhost")
parser.add_argument("--port", type=int, default=21002)
parser.add_argument("--worker-address", type=str, default="http://localhost:21002")
parser.add_argument(
"--controller-address", type=str, default="http://localhost:21001"
)
parser.add_argument("--model-path", type=str, default="lmsys/vicuna-7b-v1.5")
parser.add_argument(
"--model-names",
type=lambda s: s.split(","),
help="Optional display comma separated names",
)
parser.add_argument("--limit-worker-concurrency", type=int, default=1024)
parser.add_argument("--no-register", action="store_true")
parser.add_argument(
"--conv-template", type=str, default=None, help="Conversation prompt template."
)
parser.add_argument("--stream-interval", type=int, default=4)
parser.add_argument(
"--low-bit", type=str, default="sym_int4", help="Low bit format."
)
parser.add_argument(
"--device", type=str, default="cpu", help="Device for executing model, cpu/xpu"
)
parser.add_argument(
"--trust-remote-code",
action="store_true",
default=False,
help="Trust remote code (e.g., from HuggingFace) when"
"downloading the model and tokenizer.",
)
args = parser.parse_args()
worker = BigDLLLMWorker(
args.controller_address,
args.worker_address,
worker_id,
args.model_path,
args.model_names,
args.limit_worker_concurrency,
args.conv_template,
args.low_bit,
args.device,
args.no_register,
args.trust_remote_code,
)
uvicorn.run(app, host=args.host, port=args.port, log_level="info")
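
For completeness, the non-streaming endpoint can be exercised the same way. This is a sketch under the same assumption of a worker listening on the default `localhost:21002`:

```python
import requests

# /worker_generate runs the full generation and returns a single JSON object
# containing "text", "error_code", "usage", and "finish_reason".
result = requests.post(
    "http://localhost:21002/worker_generate",
    json={"prompt": "Briefly explain low-bit quantization.", "max_new_tokens": 32, "echo": False},
).json()

print(result["text"])
print(result["usage"], result["finish_reason"])
```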