Add FastChat bigdl_worker (#10493)
* done * fix format * add licence * done * fix doc * refactor folder * add license
This commit is contained in:
parent
8d0ea1b9b3
commit
3a3756b51d
6 changed files with 386 additions and 8 deletions
|
|
@ -11,7 +11,8 @@ BigDL-LLM can be easily integrated into FastChat so that user can use `BigDL-LLM
|
|||
- [Start the service](#start-the-service)
|
||||
- [Launch controller](#launch-controller)
|
||||
- [Launch model worker(s) and load models](#launch-model-workers-and-load-models)
|
||||
- [BigDL model worker](#bigdl-model-worker)
|
||||
- [BigDL model worker (deprecated)](#bigdl-model-worker-deprecated)
|
||||
- [BigDL worker](#bigdl-llm-worker)
|
||||
- [BigDL vLLM model worker](#vllm-model-worker)
|
||||
- [Launch Gradio web server](#launch-gradio-web-server)
|
||||
- [Launch RESTful API server](#launch-restful-api-server)
|
||||
|
|
@ -48,10 +49,14 @@ python3 -m fastchat.serve.controller
|
|||
|
||||
### Launch model worker(s) and load models
|
||||
|
||||
#### BigDL model worker
|
||||
|
||||
Using BigDL-LLM in FastChat does not impose any new limitations on model usage. Therefore, all Hugging Face Transformer models can be utilized in FastChat.
|
||||
|
||||
#### BigDL model worker (deprecated)
|
||||
<details>
|
||||
<summary>details</summary>
|
||||
|
||||
> Warning: This method has been deprecated, please change to use `BigDL-LLM` [worker](#bigdl-llm-worker) instead.
|
||||
|
||||
FastChat determines the Model adapter to use through path matching. Therefore, in order to load models using BigDL-LLM, you need to make some modifications to the model's name.
|
||||
|
||||
For instance, assuming you have downloaded the `llama-7b-hf` from [HuggingFace](https://huggingface.co/decapoda-research/llama-7b-hf). Then, to use the `BigDL-LLM` as backend, you need to change name from `llama-7b-hf` to `bigdl-7b`.The key point here is that the model's path should include "bigdl" and **should not include paths matched by other model adapters**.
|
||||
|
|
@ -66,10 +71,10 @@ Then we can run model workers
|
|||
|
||||
```bash
|
||||
# On CPU
|
||||
python3 -m bigdl.llm.serving.model_worker --model-path PATH/TO/bigdl-7b --device cpu
|
||||
python3 -m bigdl.llm.serving.fastchat.model_worker --model-path PATH/TO/bigdl-7b --device cpu
|
||||
|
||||
# On GPU
|
||||
python3 -m bigdl.llm.serving.model_worker --model-path PATH/TO/bigdl-7b --device xpu
|
||||
python3 -m bigdl.llm.serving.fastchat.model_worker --model-path PATH/TO/bigdl-7b --device xpu
|
||||
```
|
||||
|
||||
If you run successfully using `BigDL` backend, you can see the output in log like this:
|
||||
|
|
@ -78,7 +83,28 @@ If you run successfully using `BigDL` backend, you can see the output in log lik
|
|||
INFO - Converting the current model to sym_int4 format......
|
||||
```
|
||||
|
||||
> note: We currently only support int4 quantization.
|
||||
> note: We currently only support int4 quantization for this method.
|
||||
</details>
|
||||
|
||||
#### BigDL-LLM worker
|
||||
To integrate BigDL-LLM with `FastChat` efficiently, we have provided a new model_worker implementation named `bigdl_worker.py`.
|
||||
|
||||
To run the `bigdl_worker` on CPU, using the following code:
|
||||
```bash
|
||||
source bigdl-llm-init -t
|
||||
|
||||
# Available low_bit format including sym_int4, sym_int8, bf16 etc.
|
||||
python3 -m bigdl.llm.serving.fastchat.bigdl_worker --model-path lmsys/vicuna-7b-v1.5 --low-bit "sym_int4" --trust-remote-code --device "cpu"
|
||||
```
|
||||
|
||||
|
||||
For GPU example:
|
||||
```bash
|
||||
# Available low_bit format including sym_int4, sym_int8, fp16 etc.
|
||||
python3 -m bigdl.llm.serving.fastcaht.bigdl_worker --model-path lmsys/vicuna-7b-v1.5 --low-bit "sym_int4" --trust-remote-code --device "xpu"
|
||||
```
|
||||
|
||||
For a full list of accepted arguments, you can refer to the main method of the `bigdl_worker.py`
|
||||
|
||||
#### BigDL vLLM model worker
|
||||
|
||||
|
|
@ -88,10 +114,10 @@ To run using the `vLLM_worker`, we don't need to change model name, just simply
|
|||
|
||||
```bash
|
||||
# On CPU
|
||||
python3 -m bigdl.llm.serving.vllm_worker --model-path REPO_ID_OR_YOUR_MODEL_PATH --device cpu
|
||||
python3 -m bigdl.llm.serving.fastchat.vllm_worker --model-path REPO_ID_OR_YOUR_MODEL_PATH --device cpu
|
||||
|
||||
# On GPU
|
||||
python3 -m bigdl.llm.serving.vllm_worker --model-path REPO_ID_OR_YOUR_MODEL_PATH --device xpu
|
||||
python3 -m bigdl.llm.serving.fastchat.vllm_worker --model-path REPO_ID_OR_YOUR_MODEL_PATH --device xpu
|
||||
```
|
||||
|
||||
### Launch Gradio web server
|
||||
15
python/llm/src/bigdl/llm/serving/fastchat/__init__.py
Normal file
15
python/llm/src/bigdl/llm/serving/fastchat/__init__.py
Normal file
|
|
@ -0,0 +1,15 @@
|
|||
#
|
||||
# Copyright 2016 The BigDL Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
337
python/llm/src/bigdl/llm/serving/fastchat/bigdl_worker.py
Normal file
337
python/llm/src/bigdl/llm/serving/fastchat/bigdl_worker.py
Normal file
|
|
@ -0,0 +1,337 @@
|
|||
#
|
||||
# Copyright 2016 The BigDL Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
"""
|
||||
A model worker that executes the model based on BigDL-LLM.
|
||||
Relies on load_model method
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import atexit
|
||||
import json
|
||||
from typing import List
|
||||
import uuid
|
||||
from threading import Thread
|
||||
from fastapi import FastAPI, Request, BackgroundTasks
|
||||
from fastapi.concurrency import run_in_threadpool
|
||||
from fastapi.responses import StreamingResponse, JSONResponse
|
||||
import uvicorn
|
||||
|
||||
from fastchat.serve.base_model_worker import BaseModelWorker
|
||||
from fastchat.serve.model_worker import (
|
||||
logger,
|
||||
worker_id,
|
||||
)
|
||||
from fastchat.serve.base_model_worker import (
|
||||
create_background_tasks,
|
||||
acquire_worker_semaphore,
|
||||
release_worker_semaphore,
|
||||
)
|
||||
from fastchat.utils import get_context_length, is_partial_stop
|
||||
|
||||
from bigdl.llm.transformers.loader import load_model
|
||||
from transformers import TextIteratorStreamer
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
|
||||
class BigDLLLMWorker(BaseModelWorker):
|
||||
def __init__(
|
||||
self,
|
||||
controller_addr: str,
|
||||
worker_addr: str,
|
||||
worker_id: str,
|
||||
model_path: str,
|
||||
model_names: List[str],
|
||||
limit_worker_concurrency: int,
|
||||
conv_template: str = None,
|
||||
load_in_low_bit: str = "sym_int4",
|
||||
device: str = "cpu",
|
||||
no_register: bool = False,
|
||||
trust_remote_code: bool = False,
|
||||
stream_interval: int = 4,
|
||||
):
|
||||
super().__init__(
|
||||
controller_addr,
|
||||
worker_addr,
|
||||
worker_id,
|
||||
model_path,
|
||||
model_names,
|
||||
limit_worker_concurrency,
|
||||
conv_template,
|
||||
)
|
||||
|
||||
self.load_in_low_bit = load_in_low_bit
|
||||
logger.info(
|
||||
f"Loading the model {self.model_names} on worker {worker_id},"
|
||||
f" worker type: BigDLLLM worker..."
|
||||
)
|
||||
|
||||
logger.info(f"Using low bit format: {self.load_in_low_bit}, device: {device}")
|
||||
|
||||
self.device = device
|
||||
|
||||
self.model, self.tokenizer = load_model(
|
||||
model_path, device, self.load_in_low_bit, trust_remote_code
|
||||
)
|
||||
self.stream_interval = stream_interval
|
||||
self.context_len = get_context_length(self.model.config)
|
||||
if not no_register:
|
||||
self.init_heart_beat()
|
||||
|
||||
def generate_stream_gate(self, params):
|
||||
self.call_ct += 1
|
||||
# context length is self.context_length
|
||||
prompt = params["prompt"]
|
||||
temperature = float(params.get("temperature", 1.0))
|
||||
repetition_penalty = float(params.get("repetition_penalty", 1.0))
|
||||
top_p = float(params.get("top_p", 1.0))
|
||||
top_k = int(params.get("top_k", 0)) # 0 means disable
|
||||
max_new_tokens = int(params.get("max_new_tokens", 256))
|
||||
echo = bool(params.get("echo", True))
|
||||
stop_str = params.get("stop", None)
|
||||
stop_token_ids = params.get("stop_token_ids", None) or []
|
||||
if self.tokenizer.eos_token_id not in stop_token_ids:
|
||||
stop_token_ids.append(self.tokenizer.eos_token_id)
|
||||
|
||||
# Handle stop_str
|
||||
stop = set()
|
||||
if isinstance(stop_str, str) and stop_str != "":
|
||||
stop.add(stop_str)
|
||||
elif isinstance(stop_str, list) and stop_str != []:
|
||||
stop.update(stop_str)
|
||||
|
||||
for tid in stop_token_ids:
|
||||
if tid is not None:
|
||||
s = self.tokenizer.decode(tid)
|
||||
if s != "":
|
||||
stop.add(s)
|
||||
|
||||
input_ids = self.tokenizer.encode(prompt, return_tensors="pt")
|
||||
if self.device == "xpu":
|
||||
input_ids = input_ids.to("xpu")
|
||||
|
||||
input_echo_len = input_ids.shape[1]
|
||||
|
||||
if self.model.config.is_encoder_decoder:
|
||||
max_src_len = self.context_len
|
||||
input_ids = input_ids[:max_src_len]
|
||||
input_echo_len = len(input_ids)
|
||||
prompt = self.tokenizer.decode(
|
||||
input_ids,
|
||||
skip_special_tokens=True,
|
||||
spaces_between_special_tokens=False,
|
||||
clean_up_tokenization_spaces=True,
|
||||
)
|
||||
else:
|
||||
# Truncate the max_new_tokens if input_ids is too long
|
||||
new_max_new_tokens = min(self.context_len - input_echo_len, max_new_tokens)
|
||||
if new_max_new_tokens < max_new_tokens:
|
||||
logger.info(
|
||||
f"Warning: max_new_tokens[{max_new_tokens}] + prompt[{input_echo_len}] greater "
|
||||
f"than context_length[{self.context_len}]"
|
||||
)
|
||||
logger.info(f"Reset max_new_tokens to {new_max_new_tokens}")
|
||||
max_new_tokens = new_max_new_tokens
|
||||
|
||||
# Use TextIteratorStreamer for streaming output
|
||||
streamer = TextIteratorStreamer(
|
||||
tokenizer=self.tokenizer,
|
||||
timeout=60,
|
||||
skip_prompt=True,
|
||||
skip_special_tokens=True,
|
||||
)
|
||||
|
||||
# Generation config:
|
||||
# https://huggingface.co/docs/transformers/main/en/main_classes/text_generation#transformers.GenerationConfig
|
||||
generated_kwargs = dict(
|
||||
max_new_tokens=max_new_tokens,
|
||||
streamer=streamer,
|
||||
temperature=temperature,
|
||||
repetition_penalty=repetition_penalty,
|
||||
top_p=top_p,
|
||||
top_k=top_k,
|
||||
)
|
||||
|
||||
def model_generate():
|
||||
self.model.generate(input_ids, **generated_kwargs)
|
||||
|
||||
t1 = Thread(target=model_generate)
|
||||
t1.start()
|
||||
|
||||
stopped = False
|
||||
finish_reason = None
|
||||
if echo:
|
||||
partial_output = prompt
|
||||
rfind_start = len(prompt)
|
||||
else:
|
||||
partial_output = ""
|
||||
rfind_start = 0
|
||||
|
||||
for i in range(max_new_tokens):
|
||||
try:
|
||||
output_token = next(streamer)
|
||||
except StopIteration:
|
||||
# Stop early
|
||||
stopped = True
|
||||
break
|
||||
partial_output += output_token
|
||||
|
||||
if i % self.stream_interval == 0 or i == max_new_tokens - 1 or stopped:
|
||||
for each_stop in stop:
|
||||
pos = partial_output.rfind(each_stop, rfind_start)
|
||||
if pos != -1:
|
||||
partial_output = partial_output[:pos]
|
||||
stopped = True
|
||||
break
|
||||
else:
|
||||
partially_stopped = is_partial_stop(partial_output, each_stop)
|
||||
if partially_stopped:
|
||||
break
|
||||
if not partially_stopped:
|
||||
json_output = {
|
||||
"text": partial_output,
|
||||
"usage": {
|
||||
"prompt_tokens": input_echo_len,
|
||||
"completion_tokens": i,
|
||||
"total_tokens": input_echo_len + i,
|
||||
},
|
||||
"finish_reason": None,
|
||||
}
|
||||
ret = {
|
||||
"text": json_output["text"],
|
||||
"error_code": 0,
|
||||
}
|
||||
ret["usage"] = json_output["usage"]
|
||||
ret["finish_reason"] = json_output["finish_reason"]
|
||||
yield json.dumps(ret).encode() + b"\0"
|
||||
|
||||
if stopped:
|
||||
break
|
||||
else:
|
||||
finish_reason = "length"
|
||||
|
||||
if stopped:
|
||||
finish_reason = "stop"
|
||||
json_output = {
|
||||
"text": partial_output,
|
||||
"error_code": 0,
|
||||
"usage": {
|
||||
"prompt_tokens": input_echo_len,
|
||||
"completion_tokens": i,
|
||||
"total_tokens": input_echo_len + i,
|
||||
},
|
||||
"finish_reason": finish_reason,
|
||||
}
|
||||
yield json.dumps(json_output).encode() + b"\0"
|
||||
|
||||
def generate_gate(self, params):
|
||||
for x in self.generate_stream_gate(params):
|
||||
# for x in self.generate_stream2(params):
|
||||
pass
|
||||
return json.loads(x[:-1].decode())
|
||||
|
||||
|
||||
# Below are api interfaces
|
||||
@app.post("/worker_generate_stream")
|
||||
async def api_generate_stream(request: Request):
|
||||
params = await request.json()
|
||||
await acquire_worker_semaphore()
|
||||
generator = worker.generate_stream_gate(params)
|
||||
background_tasks = create_background_tasks()
|
||||
return StreamingResponse(generator, background=background_tasks)
|
||||
|
||||
|
||||
@app.post("/worker_generate")
|
||||
async def api_generate(request: Request):
|
||||
params = await request.json()
|
||||
await acquire_worker_semaphore()
|
||||
output = await asyncio.to_thread(worker.generate_gate, params)
|
||||
release_worker_semaphore()
|
||||
return JSONResponse(output)
|
||||
|
||||
|
||||
@app.post("/worker_get_status")
|
||||
async def api_get_status(request: Request):
|
||||
return worker.get_status()
|
||||
|
||||
|
||||
@app.post("/count_token")
|
||||
async def api_count_token(request: Request):
|
||||
params = await request.json()
|
||||
return worker.count_token(params)
|
||||
|
||||
|
||||
@app.post("/worker_get_conv_template")
|
||||
async def api_get_conv(request: Request):
|
||||
return worker.get_conv_template()
|
||||
|
||||
|
||||
@app.post("/model_details")
|
||||
async def api_model_details(request: Request):
|
||||
return {"context_length": worker.context_len}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--host", type=str, default="localhost")
|
||||
parser.add_argument("--port", type=int, default=21002)
|
||||
parser.add_argument("--worker-address", type=str, default="http://localhost:21002")
|
||||
parser.add_argument(
|
||||
"--controller-address", type=str, default="http://localhost:21001"
|
||||
)
|
||||
parser.add_argument("--model-path", type=str, default="lmsys/vicuna-7b-v1.5")
|
||||
parser.add_argument(
|
||||
"--model-names",
|
||||
type=lambda s: s.split(","),
|
||||
help="Optional display comma separated names",
|
||||
)
|
||||
parser.add_argument("--limit-worker-concurrency", type=int, default=1024)
|
||||
parser.add_argument("--no-register", action="store_true")
|
||||
parser.add_argument(
|
||||
"--conv-template", type=str, default=None, help="Conversation prompt template."
|
||||
)
|
||||
parser.add_argument("--stream-interval", type=int, default=4)
|
||||
parser.add_argument(
|
||||
"--low-bit", type=str, default="sym_int4", help="Low bit format."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--device", type=str, default="cpu", help="Device for executing model, cpu/xpu"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--trust-remote-code",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Trust remote code (e.g., from HuggingFace) when"
|
||||
"downloading the model and tokenizer.",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
worker = BigDLLLMWorker(
|
||||
args.controller_address,
|
||||
args.worker_address,
|
||||
worker_id,
|
||||
args.model_path,
|
||||
args.model_names,
|
||||
args.limit_worker_concurrency,
|
||||
args.conv_template,
|
||||
args.low_bit,
|
||||
args.device,
|
||||
args.no_register,
|
||||
args.trust_remote_code,
|
||||
)
|
||||
uvicorn.run(app, host=args.host, port=args.port, log_level="info")
|
||||
Loading…
Reference in a new issue