[LLM] Integrate FastChat as a serving framework for BigDL-LLM (#8821)
* Finish changing * format * add licence * Add licence * fix * fix * Add xpu support for fschat * Fix patch * Also install webui dependencies * change setup.py dependency installs * fiox * format * final test
This commit is contained in:
parent
cb534ed5c4
commit
0bf5857908
5 changed files with 895 additions and 1 deletions
|
|
@ -53,6 +53,7 @@ CONVERT_DEP = ['numpy >= 1.22', 'torch',
|
|||
'transformers == 4.31.0', 'sentencepiece',
|
||||
# TODO: Support accelerate 0.22.0
|
||||
'accelerate == 0.21.0', 'tabulate']
|
||||
SERVING_DEP = ['fschat[model_worker, webui] >= 0.2.24', 'protobuf']
|
||||
windows_binarys = [
|
||||
"llama.dll",
|
||||
"gptneox.dll",
|
||||
|
|
@ -253,6 +254,7 @@ def setup_package():
|
|||
|
||||
all_requires = ['py-cpuinfo', 'protobuf']
|
||||
all_requires += CONVERT_DEP
|
||||
all_requires += SERVING_DEP
|
||||
|
||||
# install with -f https://developer.intel.com/ipex-whl-stable-xpu
|
||||
xpu_requires = copy.deepcopy(all_requires)
|
||||
|
|
@ -262,6 +264,10 @@ def setup_package():
|
|||
"intel_extension_for_pytorch==2.0.110+xpu;platform_system=='Linux'",
|
||||
"bigdl-core-xe==" + VERSION + ";platform_system=='Linux'"]
|
||||
|
||||
serving_requires = ['py-cpuinfo']
|
||||
serving_requires += SERVING_DEP
|
||||
|
||||
|
||||
metadata = dict(
|
||||
name='bigdl-llm',
|
||||
version=VERSION,
|
||||
|
|
@ -283,7 +289,8 @@ def setup_package():
|
|||
]
|
||||
},
|
||||
extras_require={"all": all_requires,
|
||||
"xpu": xpu_requires},
|
||||
"xpu": xpu_requires,
|
||||
"serving": serving_requires},
|
||||
classifiers=[
|
||||
'License :: OSI Approved :: Apache Software License',
|
||||
'Programming Language :: Python :: 3',
|
||||
|
|
|
|||
97
python/llm/src/bigdl/llm/serving/README.md
Normal file
97
python/llm/src/bigdl/llm/serving/README.md
Normal file
|
|
@ -0,0 +1,97 @@
|
|||
## Serving using BigDL-LLM and FastChat
|
||||
|
||||
FastChat is an open platform for training, serving, and evaluating large language model based chatbots. You can find the detailed information at their [homepage](https://github.com/lm-sys/FastChat).
|
||||
|
||||
BigDL-LLM can be easily integrated into FastChat so that user can use `BigDL-LLM` as a serving backend in the deployment.
|
||||
|
||||
### Working with BigDL-LLM Serving
|
||||
|
||||
<details><summary>Table of Contents</summary>
|
||||
|
||||
- [Install](#install)
|
||||
- [Models](#models)
|
||||
- [Boot Service](#start-the-service)
|
||||
- [Web GUI](#serving-with-webgui)
|
||||
- [RESTful API](#serving-with-openai-compatible-restful-apis)
|
||||
</details>
|
||||
|
||||
#### Install
|
||||
|
||||
You may install **`bigdl-llm`** with `FastChat` as follows:
|
||||
|
||||
```bash
|
||||
pip install --pre --upgrade bigdl-llm[serving]
|
||||
|
||||
# Or
|
||||
pip install --pre --upgrade bigdl-llm[all]
|
||||
```
|
||||
|
||||
To add GPU support for FastChat, you may install **`bigdl-llm`** as follows:
|
||||
```bash
|
||||
pip install --pre --upgrade bigdl-llm[xpu, serving] -f https://developer.intel.com/ipex-whl-stable-xpu
|
||||
```
|
||||
|
||||
#### Models
|
||||
|
||||
Using BigDL-LLM in FastChat does not impose any new limitations on model usage. Therefore, all Hugging Face Transformer models can be utilized in FastChat.
|
||||
|
||||
FastChat determines the Model adapter to use through path matching. Therefore, in order to load models using BigDL-LLM, you need to make some modifications to the model's name.
|
||||
|
||||
For instance, assuming you have downloaded the `llama-7b-hf` from [HuggingFace](https://huggingface.co/decapoda-research/llama-7b-hf). Then, to use the `BigDL-LLM` as backend, you need to change name from `llama-7b-hf` to `bigdl-7b`.
|
||||
The key point here is that the model's path should include "bigdl" and should not include paths matched by other model adapters.
|
||||
|
||||
A special case is `ChatGLM` models. For these models, you do not need to do any changes after downloading the model and the `BigDL-LLM` backend will be used automatically.
|
||||
|
||||
|
||||
#### Start the service
|
||||
|
||||
##### Serving with WebGUI
|
||||
|
||||
To serve using the Web UI, you need three main components: web servers that interface with users, model workers that host one or more models, and a controller to coordinate the web server and model workers.
|
||||
|
||||
###### Launch the Controller
|
||||
```bash
|
||||
python3 -m fastchat.serve.controller
|
||||
```
|
||||
|
||||
This controller manages the distributed workers.
|
||||
|
||||
###### Launch the model worker(s)
|
||||
```bash
|
||||
python3 -m bigdl.llm.serving.model_worker --model-path lmsys/vicuna-7b-v1.3 --device cpu
|
||||
```
|
||||
Wait until the process finishes loading the model and you see "Uvicorn running on ...". The model worker will register itself to the controller.
|
||||
|
||||
> To run model worker using Intel GPU, simple change the --device cpu option to --device xpu
|
||||
|
||||
###### Launch the Gradio web server
|
||||
|
||||
```bash
|
||||
python3 -m fastchat.serve.gradio_web_server
|
||||
```
|
||||
|
||||
This is the user interface that users will interact with.
|
||||
|
||||
By following these steps, you will be able to serve your models using the web UI with `BigDL-LLM` as the backend. You can open your browser and chat with a model now.
|
||||
|
||||
##### Serving with OpenAI-Compatible RESTful APIs
|
||||
|
||||
To start an OpenAI API server that provides compatible APIs using `BigDL-LLM` backend, you need three main components: an OpenAI API Server that serves the in-coming requests, model workers that host one or more models, and a controller to coordinate the web server and model workers.
|
||||
|
||||
First, launch the controller
|
||||
|
||||
```bash
|
||||
python3 -m fastchat.serve.controller
|
||||
```
|
||||
|
||||
Then, launch the model worker(s):
|
||||
|
||||
```bash
|
||||
python3 -m bigdl.llm.serving.model_worker --model-path lmsys/vicuna-7b-v1.3 --device cpu
|
||||
```
|
||||
|
||||
Finally, launch the RESTful API server
|
||||
|
||||
```bash
|
||||
python3 -m fastchat.serve.openai_api_server --host localhost --port 8000
|
||||
```
|
||||
15
python/llm/src/bigdl/llm/serving/__init__.py
Normal file
15
python/llm/src/bigdl/llm/serving/__init__.py
Normal file
|
|
@ -0,0 +1,15 @@
|
|||
#
|
||||
# Copyright 2016 The BigDL Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
271
python/llm/src/bigdl/llm/serving/bigdl_llm_model.py
Normal file
271
python/llm/src/bigdl/llm/serving/bigdl_llm_model.py
Normal file
|
|
@ -0,0 +1,271 @@
|
|||
#
|
||||
# Copyright 2016 The BigDL Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from fastchat.model.model_adapter import register_model_adapter, BaseModelAdapter, ChatGLMAdapter
|
||||
from fastchat.modules.gptq import GptqConfig, load_gptq_quantized
|
||||
import accelerate
|
||||
from fastchat.modules.awq import AWQConfig, load_awq_quantized
|
||||
from fastchat.model.model_adapter import (
|
||||
get_model_adapter,
|
||||
raise_warning_for_incompatible_cpu_offloading_configuration,
|
||||
)
|
||||
from fastchat.model.monkey_patch_non_inplace import (
|
||||
replace_llama_attn_with_non_inplace_operations,
|
||||
)
|
||||
from fastchat.constants import CPU_ISA
|
||||
from fastchat.utils import get_gpu_memory
|
||||
import torch
|
||||
import warnings
|
||||
from transformers import AutoTokenizer
|
||||
from typing import Dict, List, Optional
|
||||
import math
|
||||
import psutil
|
||||
from bigdl.llm.utils.common import invalidInputError
|
||||
|
||||
is_fastchat_patched = False
|
||||
_mapping_fastchat = None
|
||||
|
||||
|
||||
def _get_patch_map():
|
||||
global _mapping_fastchat
|
||||
|
||||
if _mapping_fastchat is None:
|
||||
_mapping_fastchat = []
|
||||
|
||||
from fastchat.model import model_adapter
|
||||
_mapping_fastchat += [
|
||||
[BaseModelAdapter, "load_model", load_model_base, None],
|
||||
[ChatGLMAdapter, "load_model", load_model_chatglm, None],
|
||||
[model_adapter, "load_model", load_model, None],
|
||||
]
|
||||
|
||||
return _mapping_fastchat
|
||||
|
||||
|
||||
def load_model_base(self, model_path: str, from_pretrained_kwargs: dict):
|
||||
revision = from_pretrained_kwargs.get("revision", "main")
|
||||
print("Customized bigdl-llm loader")
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_path,
|
||||
use_fast=self.use_fast_tokenizer,
|
||||
revision=revision,
|
||||
)
|
||||
from bigdl.llm.transformers import AutoModelForCausalLM
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_path, low_cpu_mem_usage=True, **from_pretrained_kwargs
|
||||
)
|
||||
return model, tokenizer
|
||||
|
||||
|
||||
def load_model_chatglm(self, model_path: str, from_pretrained_kwargs: dict):
|
||||
revision = from_pretrained_kwargs.get("revision", "main")
|
||||
print("Customized bigdl-llm loader")
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_path, trust_remote_code=True, revision=revision
|
||||
)
|
||||
from bigdl.llm.transformers import AutoModel
|
||||
model = AutoModel.from_pretrained(
|
||||
model_path, trust_remote_code=True, load_in_4bit=True, **from_pretrained_kwargs
|
||||
)
|
||||
return model, tokenizer
|
||||
|
||||
|
||||
def load_model(
|
||||
model_path: str,
|
||||
device: str = "cuda",
|
||||
num_gpus: int = 1,
|
||||
max_gpu_memory: Optional[str] = None,
|
||||
load_8bit: bool = False,
|
||||
cpu_offloading: bool = False,
|
||||
gptq_config: Optional[GptqConfig] = None,
|
||||
awq_config: Optional[AWQConfig] = None,
|
||||
revision: str = "main",
|
||||
debug: bool = False,
|
||||
):
|
||||
"""Load a model from Hugging Face."""
|
||||
# get model adapter
|
||||
adapter = get_model_adapter(model_path)
|
||||
|
||||
# Handle device mapping
|
||||
cpu_offloading = raise_warning_for_incompatible_cpu_offloading_configuration(
|
||||
device, load_8bit, cpu_offloading
|
||||
)
|
||||
if device == "cpu":
|
||||
kwargs = {"torch_dtype": torch.float32}
|
||||
if CPU_ISA in ["avx512_bf16", "amx"]:
|
||||
try:
|
||||
import intel_extension_for_pytorch as ipex
|
||||
|
||||
kwargs = {"torch_dtype": torch.bfloat16}
|
||||
except ImportError:
|
||||
warnings.warn(
|
||||
"Intel Extension for PyTorch is not installed, "
|
||||
"it can be installed to accelerate cpu inference"
|
||||
)
|
||||
elif device == "cuda":
|
||||
kwargs = {"torch_dtype": torch.float16}
|
||||
if num_gpus != 1:
|
||||
kwargs["device_map"] = "auto"
|
||||
if max_gpu_memory is None:
|
||||
kwargs[
|
||||
"device_map"
|
||||
] = "sequential" # This is important for not the same VRAM sizes
|
||||
available_gpu_memory = get_gpu_memory(num_gpus)
|
||||
kwargs["max_memory"] = {
|
||||
i: str(int(available_gpu_memory[i] * 0.85)) + "GiB"
|
||||
for i in range(num_gpus)
|
||||
}
|
||||
else:
|
||||
kwargs["max_memory"] = {
|
||||
i: max_gpu_memory for i in range(num_gpus)}
|
||||
elif device == "mps":
|
||||
kwargs = {"torch_dtype": torch.float16}
|
||||
# Avoid bugs in mps backend by not using in-place operations.
|
||||
replace_llama_attn_with_non_inplace_operations()
|
||||
elif device == "xpu":
|
||||
kwargs = {}
|
||||
# Try to load ipex, while it looks unused, it links into torch for xpu support
|
||||
try:
|
||||
import intel_extension_for_pytorch as ipex
|
||||
except ImportError:
|
||||
warnings.warn(
|
||||
"Intel Extension for PyTorch is not installed, but is required for xpu inference."
|
||||
)
|
||||
else:
|
||||
invalidInputError(False, f"Invalid device: {device}")
|
||||
|
||||
if cpu_offloading:
|
||||
# raises an error on incompatible platforms
|
||||
from transformers import BitsAndBytesConfig
|
||||
|
||||
if "max_memory" in kwargs:
|
||||
kwargs["max_memory"]["cpu"] = (
|
||||
str(math.floor(psutil.virtual_memory().available / 2**20)) + "Mib"
|
||||
)
|
||||
kwargs["quantization_config"] = BitsAndBytesConfig(
|
||||
load_in_8bit_fp32_cpu_offload=cpu_offloading
|
||||
)
|
||||
kwargs["load_in_8bit"] = load_8bit
|
||||
elif load_8bit:
|
||||
if num_gpus != 1:
|
||||
warnings.warn(
|
||||
"8-bit quantization is not supported for multi-gpu inference."
|
||||
)
|
||||
else:
|
||||
model, tokenizer = adapter.load_compress_model(
|
||||
model_path=model_path,
|
||||
device=device,
|
||||
torch_dtype=kwargs["torch_dtype"],
|
||||
revision=revision,
|
||||
)
|
||||
if debug:
|
||||
print(model)
|
||||
return model, tokenizer
|
||||
elif awq_config and awq_config.wbits < 16:
|
||||
invalidInputError(awq_config.wbits != 4,
|
||||
"Currently we only support 4-bit inference for AWQ.")
|
||||
model, tokenizer = load_awq_quantized(model_path, awq_config, device)
|
||||
if num_gpus != 1:
|
||||
device_map = accelerate.infer_auto_device_map(
|
||||
model,
|
||||
max_memory=kwargs["max_memory"],
|
||||
no_split_module_classes=[
|
||||
"OPTDecoderLayer",
|
||||
"LlamaDecoderLayer",
|
||||
"BloomBlock",
|
||||
"MPTBlock",
|
||||
"DecoderLayer",
|
||||
],
|
||||
)
|
||||
model = accelerate.dispatch_model(
|
||||
model, device_map=device_map, offload_buffers=True
|
||||
)
|
||||
else:
|
||||
model.to(device)
|
||||
return model, tokenizer
|
||||
elif gptq_config and gptq_config.wbits < 16:
|
||||
model, tokenizer = load_gptq_quantized(model_path, gptq_config)
|
||||
if num_gpus != 1:
|
||||
device_map = accelerate.infer_auto_device_map(
|
||||
model,
|
||||
max_memory=kwargs["max_memory"],
|
||||
no_split_module_classes=["LlamaDecoderLayer"],
|
||||
)
|
||||
model = accelerate.dispatch_model(
|
||||
model, device_map=device_map, offload_buffers=True
|
||||
)
|
||||
else:
|
||||
model.to(device)
|
||||
return model, tokenizer
|
||||
kwargs["revision"] = revision
|
||||
|
||||
# Load model
|
||||
model, tokenizer = adapter.load_model(model_path, kwargs)
|
||||
|
||||
if (
|
||||
device == "cpu"
|
||||
and kwargs["torch_dtype"] is torch.bfloat16
|
||||
and CPU_ISA is not None
|
||||
):
|
||||
model = ipex.optimize(model, dtype=kwargs["torch_dtype"])
|
||||
|
||||
if (device == "cuda" and num_gpus == 1 and not cpu_offloading) or device in (
|
||||
"mps",
|
||||
"xpu",
|
||||
):
|
||||
model.to(device)
|
||||
|
||||
if debug:
|
||||
print(model)
|
||||
|
||||
return model, tokenizer
|
||||
|
||||
|
||||
class BigDLLLMAdapter(BaseModelAdapter):
|
||||
"Model adapter for bigdl-llm backend models"
|
||||
|
||||
def match(self, model_path: str):
|
||||
return "bigdl" in model_path
|
||||
|
||||
def load_model(self, model_path: str, from_pretrained_kwargs: dict):
|
||||
revision = from_pretrained_kwargs.get("revision", "main")
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_path, use_fast=False, revision=revision
|
||||
)
|
||||
print("Customized bigdl-llm loader")
|
||||
from bigdl.llm.transformers import AutoModelForCausalLM
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_path,
|
||||
load_in_4bit=True,
|
||||
low_cpu_mem_usage=True,
|
||||
**from_pretrained_kwargs,
|
||||
)
|
||||
return model, tokenizer
|
||||
|
||||
|
||||
def patch_fastchat():
|
||||
global is_fastchat_patched
|
||||
if is_fastchat_patched:
|
||||
return
|
||||
register_model_adapter(BigDLLLMAdapter)
|
||||
mapping_fastchat = _get_patch_map()
|
||||
|
||||
for mapping_iter in mapping_fastchat:
|
||||
if mapping_iter[3] is None:
|
||||
mapping_iter[3] = getattr(mapping_iter[0], mapping_iter[1], None)
|
||||
setattr(mapping_iter[0], mapping_iter[1], mapping_iter[2])
|
||||
|
||||
is_fastchat_patched = True
|
||||
504
python/llm/src/bigdl/llm/serving/model_worker.py
Normal file
504
python/llm/src/bigdl/llm/serving/model_worker.py
Normal file
|
|
@ -0,0 +1,504 @@
|
|||
#
|
||||
# Copyright 2016 The BigDL Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
"""
|
||||
A model worker that executes the model.
|
||||
Adapted from FastChat's model_worker.py
|
||||
"""
|
||||
import argparse
|
||||
import asyncio
|
||||
import dataclasses
|
||||
import logging
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from typing import List, Optional
|
||||
import threading
|
||||
import uuid
|
||||
from bigdl.llm.utils.common import invalidInputError
|
||||
|
||||
from fastapi import FastAPI, Request, BackgroundTasks
|
||||
from fastapi.responses import StreamingResponse, JSONResponse
|
||||
import requests
|
||||
|
||||
from .bigdl_llm_model import patch_fastchat
|
||||
|
||||
try:
|
||||
from transformers import (
|
||||
AutoTokenizer,
|
||||
AutoModelForCausalLM,
|
||||
LlamaTokenizer,
|
||||
AutoModel,
|
||||
)
|
||||
except ImportError:
|
||||
from transformers import (
|
||||
AutoTokenizer,
|
||||
AutoModelForCausalLM,
|
||||
LLaMATokenizer,
|
||||
AutoModel,
|
||||
)
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import uvicorn
|
||||
|
||||
from fastchat.constants import WORKER_HEART_BEAT_INTERVAL, ErrorCode, SERVER_ERROR_MSG
|
||||
from fastchat.conversation import get_conv_template
|
||||
from fastchat.model.model_adapter import (
|
||||
add_model_args,
|
||||
get_conversation_template,
|
||||
get_generate_stream_function,
|
||||
)
|
||||
from fastchat.modules.gptq import GptqConfig
|
||||
from fastchat.modules.awq import AWQConfig
|
||||
from fastchat.utils import build_logger, pretty_print_semaphore, get_context_length
|
||||
|
||||
|
||||
worker_id = str(uuid.uuid4())[:8]
|
||||
logger = build_logger("model_worker", f"model_worker_{worker_id}.log")
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
|
||||
def heart_beat_worker(obj):
|
||||
while True:
|
||||
time.sleep(WORKER_HEART_BEAT_INTERVAL)
|
||||
obj.send_heart_beat()
|
||||
|
||||
|
||||
class BaseModelWorker:
|
||||
def __init__(
|
||||
self,
|
||||
controller_addr: str,
|
||||
worker_addr: str,
|
||||
worker_id: str,
|
||||
model_path: str,
|
||||
model_names: List[str],
|
||||
limit_worker_concurrency: int,
|
||||
conv_template: str = None,
|
||||
):
|
||||
self.controller_addr = controller_addr
|
||||
self.worker_addr = worker_addr
|
||||
self.worker_id = worker_id
|
||||
if model_path.endswith("/"):
|
||||
model_path = model_path[:-1]
|
||||
self.model_names = model_names or [model_path.split("/")[-1]]
|
||||
self.limit_worker_concurrency = limit_worker_concurrency
|
||||
if conv_template:
|
||||
self.conv = get_conv_template(conv_template)
|
||||
else:
|
||||
self.conv = get_conversation_template(model_path)
|
||||
self.conv.sep_style = int(self.conv.sep_style)
|
||||
self.tokenizer = None
|
||||
self.context_len = None
|
||||
self.call_ct = 0
|
||||
self.semaphore = None
|
||||
|
||||
self.heart_beat_thread = None
|
||||
|
||||
def init_heart_beat(self):
|
||||
self.register_to_controller()
|
||||
self.heart_beat_thread = threading.Thread(
|
||||
target=heart_beat_worker, args=(self,)
|
||||
)
|
||||
self.heart_beat_thread.start()
|
||||
|
||||
def register_to_controller(self):
|
||||
logger.info("Register to controller")
|
||||
|
||||
url = self.controller_addr + "/register_worker"
|
||||
data = {
|
||||
"worker_name": self.worker_addr,
|
||||
"check_heart_beat": True,
|
||||
"worker_status": self.get_status(),
|
||||
}
|
||||
r = requests.post(url, json=data)
|
||||
invalidInputError(r.status_code == 200, "Error register to Controller")
|
||||
|
||||
def send_heart_beat(self):
|
||||
logger.info(
|
||||
f"Send heart beat. Models: {self.model_names}. "
|
||||
f"Semaphore: {pretty_print_semaphore(self.semaphore)}. "
|
||||
f"call_ct: {self.call_ct}. "
|
||||
f"worker_id: {self.worker_id}. "
|
||||
)
|
||||
|
||||
url = self.controller_addr + "/receive_heart_beat"
|
||||
|
||||
while True:
|
||||
try:
|
||||
ret = requests.post(
|
||||
url,
|
||||
json={
|
||||
"worker_name": self.worker_addr,
|
||||
"queue_length": self.get_queue_length(),
|
||||
},
|
||||
timeout=5,
|
||||
)
|
||||
exist = ret.json()["exist"]
|
||||
break
|
||||
except (requests.exceptions.RequestException, KeyError) as e:
|
||||
logger.error(f"heart beat error: {e}")
|
||||
time.sleep(5)
|
||||
|
||||
if not exist:
|
||||
self.register_to_controller()
|
||||
|
||||
def get_queue_length(self):
|
||||
if (
|
||||
self.semaphore is None
|
||||
or self.semaphore._value is None
|
||||
or self.semaphore._waiters is None
|
||||
):
|
||||
return 0
|
||||
else:
|
||||
return (
|
||||
self.limit_worker_concurrency
|
||||
- self.semaphore._value
|
||||
+ len(self.semaphore._waiters)
|
||||
)
|
||||
|
||||
def get_status(self):
|
||||
return {
|
||||
"model_names": self.model_names,
|
||||
"speed": 1,
|
||||
"queue_length": self.get_queue_length(),
|
||||
}
|
||||
|
||||
def count_token(self, params):
|
||||
prompt = params["prompt"]
|
||||
input_ids = self.tokenizer(prompt).input_ids
|
||||
input_echo_len = len(input_ids)
|
||||
|
||||
ret = {
|
||||
"count": input_echo_len,
|
||||
"error_code": 0,
|
||||
}
|
||||
return ret
|
||||
|
||||
def get_conv_template(self):
|
||||
return {"conv": self.conv}
|
||||
|
||||
|
||||
class ModelWorker(BaseModelWorker):
|
||||
def __init__(
|
||||
self,
|
||||
controller_addr: str,
|
||||
worker_addr: str,
|
||||
worker_id: str,
|
||||
model_path: str,
|
||||
model_names: List[str],
|
||||
limit_worker_concurrency: int,
|
||||
no_register: bool,
|
||||
device: str,
|
||||
num_gpus: int,
|
||||
max_gpu_memory: str,
|
||||
load_8bit: bool = False,
|
||||
cpu_offloading: bool = False,
|
||||
gptq_config: Optional[GptqConfig] = None,
|
||||
awq_config: Optional[AWQConfig] = None,
|
||||
stream_interval: int = 2,
|
||||
conv_template: str = None,
|
||||
):
|
||||
super().__init__(
|
||||
controller_addr,
|
||||
worker_addr,
|
||||
worker_id,
|
||||
model_path,
|
||||
model_names,
|
||||
limit_worker_concurrency,
|
||||
conv_template=conv_template,
|
||||
)
|
||||
|
||||
logger.info(f"Loading the model {self.model_names} on worker {worker_id} ...")
|
||||
from fastchat.model.model_adapter import load_model
|
||||
self.model, self.tokenizer = load_model(
|
||||
model_path,
|
||||
device=device,
|
||||
num_gpus=num_gpus,
|
||||
max_gpu_memory=max_gpu_memory,
|
||||
load_8bit=load_8bit,
|
||||
cpu_offloading=cpu_offloading,
|
||||
gptq_config=gptq_config,
|
||||
awq_config=awq_config,
|
||||
)
|
||||
self.device = device
|
||||
if self.tokenizer.pad_token is None:
|
||||
self.tokenizer.pad_token = self.tokenizer.eos_token
|
||||
self.context_len = get_context_length(self.model.config)
|
||||
self.generate_stream_func = get_generate_stream_function(self.model, model_path)
|
||||
self.stream_interval = stream_interval
|
||||
|
||||
if not no_register:
|
||||
self.init_heart_beat()
|
||||
|
||||
def generate_stream_gate(self, params):
|
||||
self.call_ct += 1
|
||||
|
||||
try:
|
||||
for output in self.generate_stream_func(
|
||||
self.model,
|
||||
self.tokenizer,
|
||||
params,
|
||||
self.device,
|
||||
self.context_len,
|
||||
self.stream_interval,
|
||||
):
|
||||
ret = {
|
||||
"text": output["text"],
|
||||
"error_code": 0,
|
||||
}
|
||||
if "usage" in output:
|
||||
ret["usage"] = output["usage"]
|
||||
if "finish_reason" in output:
|
||||
ret["finish_reason"] = output["finish_reason"]
|
||||
if "logprobs" in output:
|
||||
ret["logprobs"] = output["logprobs"]
|
||||
yield json.dumps(ret).encode() + b"\0"
|
||||
except torch.cuda.OutOfMemoryError as e:
|
||||
ret = {
|
||||
"text": f"{SERVER_ERROR_MSG}\n\n({e})",
|
||||
"error_code": ErrorCode.CUDA_OUT_OF_MEMORY,
|
||||
}
|
||||
yield json.dumps(ret).encode() + b"\0"
|
||||
except (ValueError, RuntimeError) as e:
|
||||
ret = {
|
||||
"text": f"{SERVER_ERROR_MSG}\n\n({e})",
|
||||
"error_code": ErrorCode.INTERNAL_ERROR,
|
||||
}
|
||||
yield json.dumps(ret).encode() + b"\0"
|
||||
|
||||
def generate_gate(self, params):
|
||||
for x in self.generate_stream_gate(params):
|
||||
pass
|
||||
return json.loads(x[:-1].decode())
|
||||
|
||||
@torch.inference_mode()
|
||||
def get_embeddings(self, params):
|
||||
self.call_ct += 1
|
||||
|
||||
try:
|
||||
tokenizer = self.tokenizer
|
||||
is_llama = "llama" in str(
|
||||
type(self.model)
|
||||
) # llama supports batch inference
|
||||
is_chatglm = "chatglm" in str(type(self.model))
|
||||
is_t5 = "t5" in str(type(self.model))
|
||||
is_bert = "bert" in str(type(self.model))
|
||||
|
||||
if is_llama:
|
||||
encoding = tokenizer.batch_encode_plus(
|
||||
params["input"], padding=True, return_tensors="pt"
|
||||
)
|
||||
input_ids = encoding["input_ids"].to(self.device)
|
||||
attention_mask = encoding["attention_mask"].to(self.device)
|
||||
model_output = self.model(
|
||||
input_ids, attention_mask, output_hidden_states=True
|
||||
)
|
||||
data = model_output.hidden_states[-1]
|
||||
mask = attention_mask.unsqueeze(-1).expand(data.size()).float()
|
||||
masked_embeddings = data * mask
|
||||
sum_embeddings = torch.sum(masked_embeddings, dim=1)
|
||||
seq_length = torch.sum(mask, dim=1)
|
||||
embedding = sum_embeddings / seq_length
|
||||
normalized_embeddings = F.normalize(embedding, p=2, dim=1)
|
||||
ret = {
|
||||
"embedding": normalized_embeddings.tolist(),
|
||||
"token_num": torch.sum(attention_mask).item(),
|
||||
}
|
||||
elif is_bert:
|
||||
embedding = []
|
||||
token_num = 0
|
||||
for text in params["input"]:
|
||||
input_ids = tokenizer.encode(text, return_tensors="pt").to(
|
||||
self.device
|
||||
)
|
||||
model_output = self.model(input_ids)
|
||||
data = model_output[0][:, 0]
|
||||
data = F.normalize(torch.mean(data, dim=0), p=2, dim=0)
|
||||
embedding.append(data.tolist())
|
||||
token_num += len(input_ids[0])
|
||||
ret = {
|
||||
"embedding": embedding,
|
||||
"token_num": token_num,
|
||||
}
|
||||
else:
|
||||
embedding = []
|
||||
token_num = 0
|
||||
for text in params["input"]:
|
||||
input_ids = tokenizer.encode(text, return_tensors="pt").to(
|
||||
self.device
|
||||
)
|
||||
if is_t5:
|
||||
model_output = self.model(
|
||||
input_ids, decoder_input_ids=input_ids
|
||||
)
|
||||
else:
|
||||
model_output = self.model(input_ids, output_hidden_states=True)
|
||||
if is_chatglm:
|
||||
data = (model_output.hidden_states[-1].transpose(0, 1))[0]
|
||||
elif is_t5:
|
||||
data = model_output.encoder_last_hidden_state[0]
|
||||
else:
|
||||
data = model_output.hidden_states[-1][0]
|
||||
data = F.normalize(torch.mean(data, dim=0), p=2, dim=0)
|
||||
embedding.append(data.tolist())
|
||||
token_num += len(input_ids[0])
|
||||
ret = {
|
||||
"embedding": embedding,
|
||||
"token_num": token_num,
|
||||
}
|
||||
except torch.cuda.OutOfMemoryError as e:
|
||||
ret = {
|
||||
"text": f"{SERVER_ERROR_MSG}\n\n({e})",
|
||||
"error_code": ErrorCode.CUDA_OUT_OF_MEMORY,
|
||||
}
|
||||
except (ValueError, RuntimeError) as e:
|
||||
ret = {
|
||||
"text": f"{SERVER_ERROR_MSG}\n\n({e})",
|
||||
"error_code": ErrorCode.INTERNAL_ERROR,
|
||||
}
|
||||
return ret
|
||||
|
||||
|
||||
def release_worker_semaphore():
|
||||
worker.semaphore.release()
|
||||
|
||||
|
||||
def acquire_worker_semaphore():
|
||||
if worker.semaphore is None:
|
||||
worker.semaphore = asyncio.Semaphore(worker.limit_worker_concurrency)
|
||||
return worker.semaphore.acquire()
|
||||
|
||||
|
||||
def create_background_tasks():
|
||||
background_tasks = BackgroundTasks()
|
||||
background_tasks.add_task(release_worker_semaphore)
|
||||
return background_tasks
|
||||
|
||||
|
||||
@app.post("/worker_generate_stream")
|
||||
async def api_generate_stream(request: Request):
|
||||
params = await request.json()
|
||||
await acquire_worker_semaphore()
|
||||
generator = worker.generate_stream_gate(params)
|
||||
background_tasks = create_background_tasks()
|
||||
return StreamingResponse(generator, background=background_tasks)
|
||||
|
||||
|
||||
@app.post("/worker_generate")
|
||||
async def api_generate(request: Request):
|
||||
params = await request.json()
|
||||
await acquire_worker_semaphore()
|
||||
output = worker.generate_gate(params)
|
||||
release_worker_semaphore()
|
||||
return JSONResponse(output)
|
||||
|
||||
|
||||
@app.post("/worker_get_embeddings")
|
||||
async def api_get_embeddings(request: Request):
|
||||
params = await request.json()
|
||||
await acquire_worker_semaphore()
|
||||
embedding = worker.get_embeddings(params)
|
||||
release_worker_semaphore()
|
||||
return JSONResponse(content=embedding)
|
||||
|
||||
|
||||
@app.post("/worker_get_status")
|
||||
async def api_get_status(request: Request):
|
||||
return worker.get_status()
|
||||
|
||||
|
||||
@app.post("/count_token")
|
||||
async def api_count_token(request: Request):
|
||||
params = await request.json()
|
||||
return worker.count_token(params)
|
||||
|
||||
|
||||
@app.post("/worker_get_conv_template")
|
||||
async def api_get_conv(request: Request):
|
||||
return worker.get_conv_template()
|
||||
|
||||
|
||||
@app.post("/model_details")
|
||||
async def api_model_details(request: Request):
|
||||
return {"context_length": worker.context_len}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
patch_fastchat()
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--host", type=str, default="localhost")
|
||||
parser.add_argument("--port", type=int, default=21002)
|
||||
parser.add_argument("--worker-address", type=str, default="http://localhost:21002")
|
||||
parser.add_argument(
|
||||
"--controller-address", type=str, default="http://localhost:21001"
|
||||
)
|
||||
add_model_args(parser)
|
||||
parser.add_argument(
|
||||
"--model-names",
|
||||
type=lambda s: s.split(","),
|
||||
help="Optional display comma separated names",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--conv-template", type=str, default=None, help="Conversation prompt template."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--limit-worker-concurrency",
|
||||
type=int,
|
||||
default=5,
|
||||
help="Limit the model concurrency to prevent OOM.",
|
||||
)
|
||||
parser.add_argument("--stream-interval", type=int, default=2)
|
||||
parser.add_argument("--no-register", action="store_true")
|
||||
args = parser.parse_args()
|
||||
logger.info(f"args: {args}")
|
||||
|
||||
if args.gpus:
|
||||
invalidInputError(len(args.gpus.split(",")) > args.num_gpus, f"Larger --num-gpus "
|
||||
"({args.num_gpus}) than --gpus {args.gpus}!")
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
|
||||
|
||||
gptq_config = GptqConfig(
|
||||
ckpt=args.gptq_ckpt or args.model_path,
|
||||
wbits=args.gptq_wbits,
|
||||
groupsize=args.gptq_groupsize,
|
||||
act_order=args.gptq_act_order,
|
||||
)
|
||||
awq_config = AWQConfig(
|
||||
ckpt=args.awq_ckpt or args.model_path,
|
||||
wbits=args.awq_wbits,
|
||||
groupsize=args.awq_groupsize,
|
||||
)
|
||||
|
||||
worker = ModelWorker(
|
||||
args.controller_address,
|
||||
args.worker_address,
|
||||
worker_id,
|
||||
args.model_path,
|
||||
args.model_names,
|
||||
args.limit_worker_concurrency,
|
||||
no_register=args.no_register,
|
||||
device=args.device,
|
||||
num_gpus=args.num_gpus,
|
||||
max_gpu_memory=args.max_gpu_memory,
|
||||
load_8bit=args.load_8bit,
|
||||
cpu_offloading=args.cpu_offloading,
|
||||
gptq_config=gptq_config,
|
||||
awq_config=awq_config,
|
||||
stream_interval=args.stream_interval,
|
||||
conv_template=args.conv_template,
|
||||
)
|
||||
uvicorn.run(app, host=args.host, port=args.port, log_level="info")
|
||||
Loading…
Reference in a new issue