Add FastChat bigdl_worker (#10493)
* done
* fix format
* add licence
* done
* fix doc
* refactor folder
* add license
parent 8d0ea1b9b3
commit 3a3756b51d
6 changed files with 386 additions and 8 deletions
@@ -11,7 +11,8 @@ BigDL-LLM can be easily integrated into FastChat so that user can use `BigDL-LLM`
 - [Start the service](#start-the-service)
 - [Launch controller](#launch-controller)
 - [Launch model worker(s) and load models](#launch-model-workers-and-load-models)
-- [BigDL model worker](#bigdl-model-worker)
+- [BigDL model worker (deprecated)](#bigdl-model-worker-deprecated)
+- [BigDL worker](#bigdl-llm-worker)
 - [BigDL vLLM model worker](#vllm-model-worker)
 - [Launch Gradio web server](#launch-gradio-web-server)
 - [Launch RESTful API server](#launch-restful-api-server)
@@ -48,10 +49,14 @@ python3 -m fastchat.serve.controller
 ### Launch model worker(s) and load models

-#### BigDL model worker
-
 Using BigDL-LLM in FastChat does not impose any new limitations on model usage. Therefore, all Hugging Face Transformer models can be utilized in FastChat.

+#### BigDL model worker (deprecated)
+<details>
+<summary>details</summary>
+
+> Warning: This method has been deprecated, please switch to the `BigDL-LLM` [worker](#bigdl-llm-worker) instead.
+
 FastChat determines the model adapter to use through path matching. Therefore, in order to load models using BigDL-LLM, you need to make some modifications to the model's name.

 For instance, suppose you have downloaded `llama-7b-hf` from [HuggingFace](https://huggingface.co/decapoda-research/llama-7b-hf). To use `BigDL-LLM` as the backend, you need to change the name from `llama-7b-hf` to `bigdl-7b`. The key point here is that the model's path should include "bigdl" and **should not include paths matched by other model adapters**.

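To make the path-matching requirement in the hunk above concrete, here is a minimal sketch (not part of this commit) of staging a downloaded checkpoint under a name that contains "bigdl"; the directory names are hypothetical:

```python
# Illustrative only: give the downloaded checkpoint a path containing "bigdl"
# so FastChat's path matching selects the BigDL-LLM adapter (deprecated method).
from pathlib import Path

src = Path("models/llama-7b-hf")   # hypothetical download location
dst = Path("models/bigdl-7b")      # the path must contain "bigdl"

if src.exists() and not dst.exists():
    src.rename(dst)                # a symlink to the original folder also works

print(f"Launch the worker with: --model-path {dst}")
```
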
@@ -66,10 +71,10 @@ Then we can run model workers
 ```bash
 # On CPU
-python3 -m bigdl.llm.serving.model_worker --model-path PATH/TO/bigdl-7b --device cpu
+python3 -m bigdl.llm.serving.fastchat.model_worker --model-path PATH/TO/bigdl-7b --device cpu

 # On GPU
-python3 -m bigdl.llm.serving.model_worker --model-path PATH/TO/bigdl-7b --device xpu
+python3 -m bigdl.llm.serving.fastchat.model_worker --model-path PATH/TO/bigdl-7b --device xpu
 ```

 If you run successfully using the `BigDL` backend, you can see output in the log like this:
@@ -78,7 +83,28 @@ If you run successfully using the `BigDL` backend, you can see output in the log like this:
 INFO - Converting the current model to sym_int4 format......
 ```

-> note: We currently only support int4 quantization.
+> note: We currently only support int4 quantization for this method.
+</details>
+
+#### BigDL-LLM worker
+To integrate BigDL-LLM with `FastChat` efficiently, we have provided a new model worker implementation named `bigdl_worker.py`.
+
+To run the `bigdl_worker` on CPU, use the following commands:
+```bash
+source bigdl-llm-init -t
+
+# Available low_bit formats include sym_int4, sym_int8, bf16, etc.
+python3 -m bigdl.llm.serving.fastchat.bigdl_worker --model-path lmsys/vicuna-7b-v1.5 --low-bit "sym_int4" --trust-remote-code --device "cpu"
+```
+
+To run on GPU:
+```bash
+# Available low_bit formats include sym_int4, sym_int8, fp16, etc.
+python3 -m bigdl.llm.serving.fastchat.bigdl_worker --model-path lmsys/vicuna-7b-v1.5 --low-bit "sym_int4" --trust-remote-code --device "xpu"
+```
+
+For a full list of accepted arguments, you can refer to the `main` method of `bigdl_worker.py`.

 #### BigDL vLLM model worker

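Once a `bigdl_worker` launched as in the hunk above is running, a minimal way to sanity-check it is to hit its status endpoints. The sketch below is illustrative, not part of this commit; it assumes the worker's default `--port 21002` from the argparse defaults in `bigdl_worker.py` later in this commit, the third-party `requests` package, and that the exact status fields follow FastChat's `BaseModelWorker.get_status`:

```python
# Sketch: query a running bigdl_worker's status endpoints.
import requests

WORKER = "http://localhost:21002"  # default --port in bigdl_worker.py

# /worker_get_status and /model_details are POST endpoints defined in bigdl_worker.py.
status = requests.post(f"{WORKER}/worker_get_status", timeout=30).json()
details = requests.post(f"{WORKER}/model_details", timeout=30).json()

print("model names :", status.get("model_names"))
print("queue length:", status.get("queue_length"))
print("context len :", details.get("context_length"))
```
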
@@ -88,10 +114,10 @@ To run using the `vLLM_worker`, we don't need to change model name, just simply
 ```bash
 # On CPU
-python3 -m bigdl.llm.serving.vllm_worker --model-path REPO_ID_OR_YOUR_MODEL_PATH --device cpu
+python3 -m bigdl.llm.serving.fastchat.vllm_worker --model-path REPO_ID_OR_YOUR_MODEL_PATH --device cpu

 # On GPU
-python3 -m bigdl.llm.serving.vllm_worker --model-path REPO_ID_OR_YOUR_MODEL_PATH --device xpu
+python3 -m bigdl.llm.serving.fastchat.vllm_worker --model-path REPO_ID_OR_YOUR_MODEL_PATH --device xpu
 ```

 ### Launch Gradio web server
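For completeness, here is a client-side sketch of calling the new worker's `/worker_generate` endpoint directly (illustrative, not part of this commit; in practice requests usually go through the FastChat controller or API server). The generation parameters mirror those read in `generate_stream_gate` in `bigdl_worker.py` below; the prompt string, worker address, and use of `requests` are assumptions for the example:

```python
# Sketch: one-shot (non-streaming) generation against a running bigdl_worker.
import requests

WORKER = "http://localhost:21002"  # default --port in bigdl_worker.py

payload = {
    "prompt": "USER: What is BigDL-LLM?\nASSISTANT:",  # example prompt
    "temperature": 0.7,
    "top_p": 1.0,
    "max_new_tokens": 128,
    "echo": False,    # do not repeat the prompt in the output
    "stop": "USER:",  # optional stop string
}

resp = requests.post(f"{WORKER}/worker_generate", json=payload, timeout=300)
result = resp.json()
print(result["text"])
print(result.get("usage"))
```
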
python/llm/src/bigdl/llm/serving/fastchat/__init__.py (new file, 15 lines)
@@ -0,0 +1,15 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
python/llm/src/bigdl/llm/serving/fastchat/bigdl_worker.py (new file, 337 lines)
@@ -0,0 +1,337 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""
A model worker that executes the model based on BigDL-LLM.

Relies on the load_model method.
"""

import argparse
import asyncio
import atexit
import json
from typing import List
import uuid
from threading import Thread

from fastapi import FastAPI, Request, BackgroundTasks
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import StreamingResponse, JSONResponse
import uvicorn

from fastchat.serve.base_model_worker import BaseModelWorker
from fastchat.serve.model_worker import (
    logger,
    worker_id,
)
from fastchat.serve.base_model_worker import (
    create_background_tasks,
    acquire_worker_semaphore,
    release_worker_semaphore,
)
from fastchat.utils import get_context_length, is_partial_stop

from bigdl.llm.transformers.loader import load_model
from transformers import TextIteratorStreamer

app = FastAPI()


class BigDLLLMWorker(BaseModelWorker):
    def __init__(
        self,
        controller_addr: str,
        worker_addr: str,
        worker_id: str,
        model_path: str,
        model_names: List[str],
        limit_worker_concurrency: int,
        conv_template: str = None,
        load_in_low_bit: str = "sym_int4",
        device: str = "cpu",
        no_register: bool = False,
        trust_remote_code: bool = False,
        stream_interval: int = 4,
    ):
        super().__init__(
            controller_addr,
            worker_addr,
            worker_id,
            model_path,
            model_names,
            limit_worker_concurrency,
            conv_template,
        )

        self.load_in_low_bit = load_in_low_bit
        logger.info(
            f"Loading the model {self.model_names} on worker {worker_id},"
            f" worker type: BigDLLLM worker..."
        )

        logger.info(f"Using low bit format: {self.load_in_low_bit}, device: {device}")

        self.device = device

        self.model, self.tokenizer = load_model(
            model_path, device, self.load_in_low_bit, trust_remote_code
        )
        self.stream_interval = stream_interval
        self.context_len = get_context_length(self.model.config)
        if not no_register:
            self.init_heart_beat()

    def generate_stream_gate(self, params):
        self.call_ct += 1
        # The context length is self.context_len
        prompt = params["prompt"]
        temperature = float(params.get("temperature", 1.0))
        repetition_penalty = float(params.get("repetition_penalty", 1.0))
        top_p = float(params.get("top_p", 1.0))
        top_k = int(params.get("top_k", 0))  # 0 means disable
        max_new_tokens = int(params.get("max_new_tokens", 256))
        echo = bool(params.get("echo", True))
        stop_str = params.get("stop", None)
        stop_token_ids = params.get("stop_token_ids", None) or []
        if self.tokenizer.eos_token_id not in stop_token_ids:
            stop_token_ids.append(self.tokenizer.eos_token_id)

        # Handle stop_str: collect all stop strings, including decoded stop tokens
        stop = set()
        if isinstance(stop_str, str) and stop_str != "":
            stop.add(stop_str)
        elif isinstance(stop_str, list) and stop_str != []:
            stop.update(stop_str)

        for tid in stop_token_ids:
            if tid is not None:
                s = self.tokenizer.decode(tid)
                if s != "":
                    stop.add(s)

        input_ids = self.tokenizer.encode(prompt, return_tensors="pt")
        if self.device == "xpu":
            input_ids = input_ids.to("xpu")

        input_echo_len = input_ids.shape[1]

        if self.model.config.is_encoder_decoder:
            max_src_len = self.context_len
            input_ids = input_ids[:max_src_len]
            input_echo_len = len(input_ids)
            prompt = self.tokenizer.decode(
                input_ids,
                skip_special_tokens=True,
                spaces_between_special_tokens=False,
                clean_up_tokenization_spaces=True,
            )
        else:
            # Truncate max_new_tokens if input_ids is too long
            new_max_new_tokens = min(self.context_len - input_echo_len, max_new_tokens)
            if new_max_new_tokens < max_new_tokens:
                logger.info(
                    f"Warning: max_new_tokens[{max_new_tokens}] + prompt[{input_echo_len}] greater "
                    f"than context_length[{self.context_len}]"
                )
                logger.info(f"Reset max_new_tokens to {new_max_new_tokens}")
                max_new_tokens = new_max_new_tokens

        # Use TextIteratorStreamer for streaming output
        streamer = TextIteratorStreamer(
            tokenizer=self.tokenizer,
            timeout=60,
            skip_prompt=True,
            skip_special_tokens=True,
        )

        # Generation config:
        # https://huggingface.co/docs/transformers/main/en/main_classes/text_generation#transformers.GenerationConfig
        generated_kwargs = dict(
            max_new_tokens=max_new_tokens,
            streamer=streamer,
            temperature=temperature,
            repetition_penalty=repetition_penalty,
            top_p=top_p,
            top_k=top_k,
        )

        # Run generation in a background thread; the streamer feeds tokens back here.
        def model_generate():
            self.model.generate(input_ids, **generated_kwargs)

        t1 = Thread(target=model_generate)
        t1.start()

        stopped = False
        finish_reason = None
        if echo:
            partial_output = prompt
            rfind_start = len(prompt)
        else:
            partial_output = ""
            rfind_start = 0

        # Stream partial outputs back as null-delimited JSON chunks.
        for i in range(max_new_tokens):
            try:
                output_token = next(streamer)
            except StopIteration:
                # Stop early
                stopped = True
                break
            partial_output += output_token

            if i % self.stream_interval == 0 or i == max_new_tokens - 1 or stopped:
                # Initialize so the emission check below is safe when no stop string matches.
                partially_stopped = False
                for each_stop in stop:
                    pos = partial_output.rfind(each_stop, rfind_start)
                    if pos != -1:
                        partial_output = partial_output[:pos]
                        stopped = True
                        break
                    else:
                        partially_stopped = is_partial_stop(partial_output, each_stop)
                        if partially_stopped:
                            break
                if not partially_stopped:
                    json_output = {
                        "text": partial_output,
                        "usage": {
                            "prompt_tokens": input_echo_len,
                            "completion_tokens": i,
                            "total_tokens": input_echo_len + i,
                        },
                        "finish_reason": None,
                    }
                    ret = {
                        "text": json_output["text"],
                        "error_code": 0,
                    }
                    ret["usage"] = json_output["usage"]
                    ret["finish_reason"] = json_output["finish_reason"]
                    yield json.dumps(ret).encode() + b"\0"

            if stopped:
                break
        else:
            finish_reason = "length"

        if stopped:
            finish_reason = "stop"
        json_output = {
            "text": partial_output,
            "error_code": 0,
            "usage": {
                "prompt_tokens": input_echo_len,
                "completion_tokens": i,
                "total_tokens": input_echo_len + i,
            },
            "finish_reason": finish_reason,
        }
        yield json.dumps(json_output).encode() + b"\0"

    def generate_gate(self, params):
        # Drain the stream and return only the final chunk.
        for x in self.generate_stream_gate(params):
            pass
        return json.loads(x[:-1].decode())


# Below are api interfaces
@app.post("/worker_generate_stream")
async def api_generate_stream(request: Request):
    params = await request.json()
    await acquire_worker_semaphore()
    generator = worker.generate_stream_gate(params)
    background_tasks = create_background_tasks()
    return StreamingResponse(generator, background=background_tasks)


@app.post("/worker_generate")
async def api_generate(request: Request):
    params = await request.json()
    await acquire_worker_semaphore()
    output = await asyncio.to_thread(worker.generate_gate, params)
    release_worker_semaphore()
    return JSONResponse(output)


@app.post("/worker_get_status")
async def api_get_status(request: Request):
    return worker.get_status()


@app.post("/count_token")
async def api_count_token(request: Request):
    params = await request.json()
    return worker.count_token(params)


@app.post("/worker_get_conv_template")
async def api_get_conv(request: Request):
    return worker.get_conv_template()


@app.post("/model_details")
async def api_model_details(request: Request):
    return {"context_length": worker.context_len}


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="localhost")
    parser.add_argument("--port", type=int, default=21002)
    parser.add_argument("--worker-address", type=str, default="http://localhost:21002")
    parser.add_argument(
        "--controller-address", type=str, default="http://localhost:21001"
    )
    parser.add_argument("--model-path", type=str, default="lmsys/vicuna-7b-v1.5")
    parser.add_argument(
        "--model-names",
        type=lambda s: s.split(","),
        help="Optional display comma separated names",
    )
    parser.add_argument("--limit-worker-concurrency", type=int, default=1024)
    parser.add_argument("--no-register", action="store_true")
    parser.add_argument(
        "--conv-template", type=str, default=None, help="Conversation prompt template."
    )
    parser.add_argument("--stream-interval", type=int, default=4)
    parser.add_argument(
        "--low-bit", type=str, default="sym_int4", help="Low bit format."
    )
    parser.add_argument(
        "--device", type=str, default="cpu", help="Device for executing model, cpu/xpu"
    )
    parser.add_argument(
        "--trust-remote-code",
        action="store_true",
        default=False,
        help="Trust remote code (e.g., from HuggingFace) when "
        "downloading the model and tokenizer.",
    )

    args = parser.parse_args()
    worker = BigDLLLMWorker(
        args.controller_address,
        args.worker_address,
        worker_id,
        args.model_path,
        args.model_names,
        args.limit_worker_concurrency,
        args.conv_template,
        args.low_bit,
        args.device,
        args.no_register,
        args.trust_remote_code,
    )
    uvicorn.run(app, host=args.host, port=args.port, log_level="info")
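To make the streaming contract above concrete: `generate_stream_gate` yields JSON chunks separated by null bytes, which `/worker_generate_stream` forwards to the client. Below is a minimal client-side sketch (not part of this commit) of consuming that stream; it assumes a worker on the default port 21002 and the third-party `requests` package:

```python
# Sketch: consume /worker_generate_stream, which yields b"\0"-delimited JSON chunks
# as produced by BigDLLLMWorker.generate_stream_gate above.
import json
import requests

WORKER = "http://localhost:21002"  # default --port above

payload = {
    "prompt": "USER: Tell me a joke.\nASSISTANT:",  # example prompt
    "max_new_tokens": 64,
    "echo": False,
}

with requests.post(
    f"{WORKER}/worker_generate_stream", json=payload, stream=True, timeout=300
) as resp:
    for chunk in resp.iter_lines(delimiter=b"\0"):
        if not chunk:
            continue
        data = json.loads(chunk.decode())
        # Each chunk carries the full text generated so far plus usage counters.
        print(data["text"], end="\r")
print()
```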