[LLM] Integrate FastChat as a serving framework for BigDL-LLM (#8821)

* Finish changing

* format

* add licence

* Add licence

* fix

* fix

* Add xpu support for fschat

* Fix patch

* Also install webui dependencies

* change setup.py dependency installs

* fix

* format

* final test
Guancheng Fu 2023-09-13 09:28:05 +08:00 committed by GitHub
parent cb534ed5c4
commit 0bf5857908
5 changed files with 895 additions and 1 deletion

View file

@@ -53,6 +53,7 @@ CONVERT_DEP = ['numpy >= 1.22', 'torch',
'transformers == 4.31.0', 'sentencepiece',
# TODO: Support accelerate 0.22.0
'accelerate == 0.21.0', 'tabulate']
SERVING_DEP = ['fschat[model_worker, webui] >= 0.2.24', 'protobuf']
windows_binarys = [
"llama.dll",
"gptneox.dll",
@@ -253,6 +254,7 @@ def setup_package():
all_requires = ['py-cpuinfo', 'protobuf']
all_requires += CONVERT_DEP
all_requires += SERVING_DEP
# install with -f https://developer.intel.com/ipex-whl-stable-xpu
xpu_requires = copy.deepcopy(all_requires)
@@ -262,6 +264,10 @@ def setup_package():
"intel_extension_for_pytorch==2.0.110+xpu;platform_system=='Linux'",
"bigdl-core-xe==" + VERSION + ";platform_system=='Linux'"]
serving_requires = ['py-cpuinfo']
serving_requires += SERVING_DEP
metadata = dict(
name='bigdl-llm',
version=VERSION,
@@ -283,7 +289,8 @@ def setup_package():
]
},
extras_require={"all": all_requires,
"xpu": xpu_requires},
"xpu": xpu_requires,
"serving": serving_requires},
classifiers=[
'License :: OSI Approved :: Apache Software License',
'Programming Language :: Python :: 3',

View file

@@ -0,0 +1,97 @@
## Serving using BigDL-LLM and FastChat
FastChat is an open platform for training, serving, and evaluating large language model-based chatbots. You can find detailed information on their [homepage](https://github.com/lm-sys/FastChat).
BigDL-LLM can be easily integrated into FastChat so that users can use `BigDL-LLM` as a serving backend in their deployments.
### Working with BigDL-LLM Serving
<details><summary>Table of Contents</summary>
- [Install](#install)
- [Models](#models)
- [Start the Service](#start-the-service)
- [Web GUI](#serving-with-webgui)
- [RESTful API](#serving-with-openai-compatible-restful-apis)
</details>
#### Install
You may install **`bigdl-llm`** with `FastChat` as follows:
```bash
pip install --pre --upgrade bigdl-llm[serving]
# Or
pip install --pre --upgrade bigdl-llm[all]
```
To add GPU support for FastChat, you may install **`bigdl-llm`** as follows:
```bash
pip install --pre --upgrade bigdl-llm[xpu, serving] -f https://developer.intel.com/ipex-whl-stable-xpu
```
#### Models
Using BigDL-LLM in FastChat does not impose any new limitations on model usage, so all Hugging Face Transformers models can be used with FastChat.
FastChat determines which model adapter to use through path matching. Therefore, to load models with BigDL-LLM, you need to make some modifications to the model's name.
For instance, assume you have downloaded `llama-7b-hf` from [HuggingFace](https://huggingface.co/decapoda-research/llama-7b-hf). To use `BigDL-LLM` as the backend, rename the model directory from `llama-7b-hf` to `bigdl-7b`.
The key point is that the model's path should include "bigdl" and should not match the path patterns of other model adapters.
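For example, if the checkpoint was downloaded to a local folder, a rename along the following lines is enough (the paths here are illustrative; any destination name containing "bigdl" works):
```bash
# Illustrative local paths -- adjust to wherever you stored the checkpoint
mv ./llama-7b-hf ./bigdl-7b
# The worker can then be pointed at the renamed directory, e.g.
# python3 -m bigdl.llm.serving.model_worker --model-path ./bigdl-7b --device cpu
```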
A special case is `ChatGLM` models: for these, you do not need to make any changes after downloading the model, and the `BigDL-LLM` backend will be used automatically.
#### Start the service
##### Serving with WebGUI
To serve using the Web UI, you need three main components: web servers that interface with users, model workers that host one or more models, and a controller to coordinate the web server and model workers.
###### Launch the Controller
```bash
python3 -m fastchat.serve.controller
```
This controller manages the distributed workers.
###### Launch the model worker(s)
```bash
python3 -m bigdl.llm.serving.model_worker --model-path lmsys/vicuna-7b-v1.3 --device cpu
```
Wait until the process finishes loading the model and you see "Uvicorn running on ...". The model worker will register itself to the controller.
> To run the model worker on an Intel GPU, simply change the `--device cpu` option to `--device xpu`, for example:
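```bash
python3 -m bigdl.llm.serving.model_worker --model-path lmsys/vicuna-7b-v1.3 --device xpu
```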
###### Launch the Gradio web server
```bash
python3 -m fastchat.serve.gradio_web_server
```
This is the user interface that users will interact with.
By following these steps, you will be able to serve your models using the web UI with `BigDL-LLM` as the backend. You can open your browser and chat with a model now.
##### Serving with OpenAI-Compatible RESTful APIs
To start an OpenAI API server that provides compatible APIs using the `BigDL-LLM` backend, you need three main components: an OpenAI API server that serves incoming requests, model workers that host one or more models, and a controller to coordinate the API server and model workers.
First, launch the controller:
```bash
python3 -m fastchat.serve.controller
```
Then, launch the model worker(s):
```bash
python3 -m bigdl.llm.serving.model_worker --model-path lmsys/vicuna-7b-v1.3 --device cpu
```
Finally, launch the RESTful API server:
```bash
python3 -m fastchat.serve.openai_api_server --host localhost --port 8000
```
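Once all three processes are running, the endpoints can be exercised with any OpenAI-compatible client. A minimal `curl` sketch, assuming the worker registered the model under its default name `vicuna-7b-v1.3` and the API server listens on port 8000:
```bash
# List the models currently registered with the controller
curl http://localhost:8000/v1/models

# Request a chat completion from the served model
curl http://localhost:8000/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{
        "model": "vicuna-7b-v1.3",
        "messages": [{"role": "user", "content": "Hello! What can you do?"}]
      }'
```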

View file

@@ -0,0 +1,15 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

View file

@@ -0,0 +1,271 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from fastchat.model.model_adapter import register_model_adapter, BaseModelAdapter, ChatGLMAdapter
from fastchat.modules.gptq import GptqConfig, load_gptq_quantized
import accelerate
from fastchat.modules.awq import AWQConfig, load_awq_quantized
from fastchat.model.model_adapter import (
get_model_adapter,
raise_warning_for_incompatible_cpu_offloading_configuration,
)
from fastchat.model.monkey_patch_non_inplace import (
replace_llama_attn_with_non_inplace_operations,
)
from fastchat.constants import CPU_ISA
from fastchat.utils import get_gpu_memory
import torch
import warnings
from transformers import AutoTokenizer
from typing import Dict, List, Optional
import math
import psutil
from bigdl.llm.utils.common import invalidInputError
is_fastchat_patched = False
_mapping_fastchat = None
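# Each patch entry is [owner, attribute name, replacement, original]; the original
# slot is left as None here and filled in by patch_fastchat() so the pre-patch
# function is preserved.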
def _get_patch_map():
global _mapping_fastchat
if _mapping_fastchat is None:
_mapping_fastchat = []
from fastchat.model import model_adapter
_mapping_fastchat += [
[BaseModelAdapter, "load_model", load_model_base, None],
[ChatGLMAdapter, "load_model", load_model_chatglm, None],
[model_adapter, "load_model", load_model, None],
]
return _mapping_fastchat
def load_model_base(self, model_path: str, from_pretrained_kwargs: dict):
revision = from_pretrained_kwargs.get("revision", "main")
print("Customized bigdl-llm loader")
tokenizer = AutoTokenizer.from_pretrained(
model_path,
use_fast=self.use_fast_tokenizer,
revision=revision,
)
from bigdl.llm.transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained(
model_path, low_cpu_mem_usage=True, **from_pretrained_kwargs
)
return model, tokenizer
def load_model_chatglm(self, model_path: str, from_pretrained_kwargs: dict):
revision = from_pretrained_kwargs.get("revision", "main")
print("Customized bigdl-llm loader")
tokenizer = AutoTokenizer.from_pretrained(
model_path, trust_remote_code=True, revision=revision
)
from bigdl.llm.transformers import AutoModel
model = AutoModel.from_pretrained(
model_path, trust_remote_code=True, load_in_4bit=True, **from_pretrained_kwargs
)
return model, tokenizer
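# Replacement for fastchat.model.model_adapter.load_model. It mirrors the upstream
# device handling (cpu / cuda / mps) and adds an "xpu" branch, then delegates the
# actual loading to the (patched) model adapters above.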
def load_model(
model_path: str,
device: str = "cuda",
num_gpus: int = 1,
max_gpu_memory: Optional[str] = None,
load_8bit: bool = False,
cpu_offloading: bool = False,
gptq_config: Optional[GptqConfig] = None,
awq_config: Optional[AWQConfig] = None,
revision: str = "main",
debug: bool = False,
):
"""Load a model from Hugging Face."""
# get model adapter
adapter = get_model_adapter(model_path)
# Handle device mapping
cpu_offloading = raise_warning_for_incompatible_cpu_offloading_configuration(
device, load_8bit, cpu_offloading
)
if device == "cpu":
kwargs = {"torch_dtype": torch.float32}
if CPU_ISA in ["avx512_bf16", "amx"]:
try:
import intel_extension_for_pytorch as ipex
kwargs = {"torch_dtype": torch.bfloat16}
except ImportError:
warnings.warn(
"Intel Extension for PyTorch is not installed, "
"it can be installed to accelerate cpu inference"
)
elif device == "cuda":
kwargs = {"torch_dtype": torch.float16}
if num_gpus != 1:
kwargs["device_map"] = "auto"
if max_gpu_memory is None:
kwargs[
"device_map"
] = "sequential" # This is important for not the same VRAM sizes
available_gpu_memory = get_gpu_memory(num_gpus)
kwargs["max_memory"] = {
i: str(int(available_gpu_memory[i] * 0.85)) + "GiB"
for i in range(num_gpus)
}
else:
kwargs["max_memory"] = {
i: max_gpu_memory for i in range(num_gpus)}
elif device == "mps":
kwargs = {"torch_dtype": torch.float16}
# Avoid bugs in mps backend by not using in-place operations.
replace_llama_attn_with_non_inplace_operations()
elif device == "xpu":
kwargs = {}
# Try to load ipex, while it looks unused, it links into torch for xpu support
try:
import intel_extension_for_pytorch as ipex
except ImportError:
warnings.warn(
"Intel Extension for PyTorch is not installed, but is required for xpu inference."
)
else:
invalidInputError(False, f"Invalid device: {device}")
if cpu_offloading:
# raises an error on incompatible platforms
from transformers import BitsAndBytesConfig
if "max_memory" in kwargs:
kwargs["max_memory"]["cpu"] = (
str(math.floor(psutil.virtual_memory().available / 2**20)) + "Mib"
)
kwargs["quantization_config"] = BitsAndBytesConfig(
load_in_8bit_fp32_cpu_offload=cpu_offloading
)
kwargs["load_in_8bit"] = load_8bit
elif load_8bit:
if num_gpus != 1:
warnings.warn(
"8-bit quantization is not supported for multi-gpu inference."
)
else:
model, tokenizer = adapter.load_compress_model(
model_path=model_path,
device=device,
torch_dtype=kwargs["torch_dtype"],
revision=revision,
)
if debug:
print(model)
return model, tokenizer
elif awq_config and awq_config.wbits < 16:
        invalidInputError(awq_config.wbits == 4,
                          "Currently we only support 4-bit inference for AWQ.")
model, tokenizer = load_awq_quantized(model_path, awq_config, device)
if num_gpus != 1:
device_map = accelerate.infer_auto_device_map(
model,
max_memory=kwargs["max_memory"],
no_split_module_classes=[
"OPTDecoderLayer",
"LlamaDecoderLayer",
"BloomBlock",
"MPTBlock",
"DecoderLayer",
],
)
model = accelerate.dispatch_model(
model, device_map=device_map, offload_buffers=True
)
else:
model.to(device)
return model, tokenizer
elif gptq_config and gptq_config.wbits < 16:
model, tokenizer = load_gptq_quantized(model_path, gptq_config)
if num_gpus != 1:
device_map = accelerate.infer_auto_device_map(
model,
max_memory=kwargs["max_memory"],
no_split_module_classes=["LlamaDecoderLayer"],
)
model = accelerate.dispatch_model(
model, device_map=device_map, offload_buffers=True
)
else:
model.to(device)
return model, tokenizer
kwargs["revision"] = revision
# Load model
model, tokenizer = adapter.load_model(model_path, kwargs)
if (
device == "cpu"
and kwargs["torch_dtype"] is torch.bfloat16
and CPU_ISA is not None
):
model = ipex.optimize(model, dtype=kwargs["torch_dtype"])
if (device == "cuda" and num_gpus == 1 and not cpu_offloading) or device in (
"mps",
"xpu",
):
model.to(device)
if debug:
print(model)
return model, tokenizer
class BigDLLLMAdapter(BaseModelAdapter):
"Model adapter for bigdl-llm backend models"
def match(self, model_path: str):
return "bigdl" in model_path
def load_model(self, model_path: str, from_pretrained_kwargs: dict):
revision = from_pretrained_kwargs.get("revision", "main")
tokenizer = AutoTokenizer.from_pretrained(
model_path, use_fast=False, revision=revision
)
print("Customized bigdl-llm loader")
from bigdl.llm.transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained(
model_path,
load_in_4bit=True,
low_cpu_mem_usage=True,
**from_pretrained_kwargs,
)
return model, tokenizer
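# Register the BigDL-LLM adapter and monkey-patch FastChat's load_model entry points
# with the BigDL-aware versions defined above. Safe to call more than once thanks to
# the is_fastchat_patched guard.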
def patch_fastchat():
global is_fastchat_patched
if is_fastchat_patched:
return
register_model_adapter(BigDLLLMAdapter)
mapping_fastchat = _get_patch_map()
for mapping_iter in mapping_fastchat:
if mapping_iter[3] is None:
mapping_iter[3] = getattr(mapping_iter[0], mapping_iter[1], None)
setattr(mapping_iter[0], mapping_iter[1], mapping_iter[2])
is_fastchat_patched = True

View file

@@ -0,0 +1,504 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
A model worker that executes the model.
Adapted from FastChat's model_worker.py
"""
import argparse
import asyncio
import dataclasses
import logging
import json
import os
import time
from typing import List, Optional
import threading
import uuid
from bigdl.llm.utils.common import invalidInputError
from fastapi import FastAPI, Request, BackgroundTasks
from fastapi.responses import StreamingResponse, JSONResponse
import requests
from .bigdl_llm_model import patch_fastchat
try:
from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
LlamaTokenizer,
AutoModel,
)
except ImportError:
from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
LLaMATokenizer,
AutoModel,
)
import torch
import torch.nn.functional as F
import uvicorn
from fastchat.constants import WORKER_HEART_BEAT_INTERVAL, ErrorCode, SERVER_ERROR_MSG
from fastchat.conversation import get_conv_template
from fastchat.model.model_adapter import (
add_model_args,
get_conversation_template,
get_generate_stream_function,
)
from fastchat.modules.gptq import GptqConfig
from fastchat.modules.awq import AWQConfig
from fastchat.utils import build_logger, pretty_print_semaphore, get_context_length
worker_id = str(uuid.uuid4())[:8]
logger = build_logger("model_worker", f"model_worker_{worker_id}.log")
app = FastAPI()
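# Background thread body: periodically report this worker's liveness and queue
# length to the controller.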
def heart_beat_worker(obj):
while True:
time.sleep(WORKER_HEART_BEAT_INTERVAL)
obj.send_heart_beat()
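# Common worker bookkeeping: controller registration, heart beats, queue-length
# reporting, and conversation-template lookup.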
class BaseModelWorker:
def __init__(
self,
controller_addr: str,
worker_addr: str,
worker_id: str,
model_path: str,
model_names: List[str],
limit_worker_concurrency: int,
conv_template: str = None,
):
self.controller_addr = controller_addr
self.worker_addr = worker_addr
self.worker_id = worker_id
if model_path.endswith("/"):
model_path = model_path[:-1]
self.model_names = model_names or [model_path.split("/")[-1]]
self.limit_worker_concurrency = limit_worker_concurrency
if conv_template:
self.conv = get_conv_template(conv_template)
else:
self.conv = get_conversation_template(model_path)
self.conv.sep_style = int(self.conv.sep_style)
self.tokenizer = None
self.context_len = None
self.call_ct = 0
self.semaphore = None
self.heart_beat_thread = None
def init_heart_beat(self):
self.register_to_controller()
self.heart_beat_thread = threading.Thread(
target=heart_beat_worker, args=(self,)
)
self.heart_beat_thread.start()
def register_to_controller(self):
logger.info("Register to controller")
url = self.controller_addr + "/register_worker"
data = {
"worker_name": self.worker_addr,
"check_heart_beat": True,
"worker_status": self.get_status(),
}
r = requests.post(url, json=data)
invalidInputError(r.status_code == 200, "Error register to Controller")
def send_heart_beat(self):
logger.info(
f"Send heart beat. Models: {self.model_names}. "
f"Semaphore: {pretty_print_semaphore(self.semaphore)}. "
f"call_ct: {self.call_ct}. "
f"worker_id: {self.worker_id}. "
)
url = self.controller_addr + "/receive_heart_beat"
while True:
try:
ret = requests.post(
url,
json={
"worker_name": self.worker_addr,
"queue_length": self.get_queue_length(),
},
timeout=5,
)
exist = ret.json()["exist"]
break
except (requests.exceptions.RequestException, KeyError) as e:
logger.error(f"heart beat error: {e}")
time.sleep(5)
if not exist:
self.register_to_controller()
def get_queue_length(self):
if (
self.semaphore is None
or self.semaphore._value is None
or self.semaphore._waiters is None
):
return 0
else:
return (
self.limit_worker_concurrency
- self.semaphore._value
+ len(self.semaphore._waiters)
)
def get_status(self):
return {
"model_names": self.model_names,
"speed": 1,
"queue_length": self.get_queue_length(),
}
def count_token(self, params):
prompt = params["prompt"]
input_ids = self.tokenizer(prompt).input_ids
input_echo_len = len(input_ids)
ret = {
"count": input_echo_len,
"error_code": 0,
}
return ret
def get_conv_template(self):
return {"conv": self.conv}
class ModelWorker(BaseModelWorker):
def __init__(
self,
controller_addr: str,
worker_addr: str,
worker_id: str,
model_path: str,
model_names: List[str],
limit_worker_concurrency: int,
no_register: bool,
device: str,
num_gpus: int,
max_gpu_memory: str,
load_8bit: bool = False,
cpu_offloading: bool = False,
gptq_config: Optional[GptqConfig] = None,
awq_config: Optional[AWQConfig] = None,
stream_interval: int = 2,
conv_template: str = None,
):
super().__init__(
controller_addr,
worker_addr,
worker_id,
model_path,
model_names,
limit_worker_concurrency,
conv_template=conv_template,
)
logger.info(f"Loading the model {self.model_names} on worker {worker_id} ...")
from fastchat.model.model_adapter import load_model
self.model, self.tokenizer = load_model(
model_path,
device=device,
num_gpus=num_gpus,
max_gpu_memory=max_gpu_memory,
load_8bit=load_8bit,
cpu_offloading=cpu_offloading,
gptq_config=gptq_config,
awq_config=awq_config,
)
self.device = device
if self.tokenizer.pad_token is None:
self.tokenizer.pad_token = self.tokenizer.eos_token
self.context_len = get_context_length(self.model.config)
self.generate_stream_func = get_generate_stream_function(self.model, model_path)
self.stream_interval = stream_interval
if not no_register:
self.init_heart_beat()
def generate_stream_gate(self, params):
self.call_ct += 1
try:
for output in self.generate_stream_func(
self.model,
self.tokenizer,
params,
self.device,
self.context_len,
self.stream_interval,
):
ret = {
"text": output["text"],
"error_code": 0,
}
if "usage" in output:
ret["usage"] = output["usage"]
if "finish_reason" in output:
ret["finish_reason"] = output["finish_reason"]
if "logprobs" in output:
ret["logprobs"] = output["logprobs"]
yield json.dumps(ret).encode() + b"\0"
except torch.cuda.OutOfMemoryError as e:
ret = {
"text": f"{SERVER_ERROR_MSG}\n\n({e})",
"error_code": ErrorCode.CUDA_OUT_OF_MEMORY,
}
yield json.dumps(ret).encode() + b"\0"
except (ValueError, RuntimeError) as e:
ret = {
"text": f"{SERVER_ERROR_MSG}\n\n({e})",
"error_code": ErrorCode.INTERNAL_ERROR,
}
yield json.dumps(ret).encode() + b"\0"
def generate_gate(self, params):
for x in self.generate_stream_gate(params):
pass
return json.loads(x[:-1].decode())
@torch.inference_mode()
def get_embeddings(self, params):
self.call_ct += 1
try:
tokenizer = self.tokenizer
is_llama = "llama" in str(
type(self.model)
) # llama supports batch inference
is_chatglm = "chatglm" in str(type(self.model))
is_t5 = "t5" in str(type(self.model))
is_bert = "bert" in str(type(self.model))
if is_llama:
encoding = tokenizer.batch_encode_plus(
params["input"], padding=True, return_tensors="pt"
)
input_ids = encoding["input_ids"].to(self.device)
attention_mask = encoding["attention_mask"].to(self.device)
model_output = self.model(
input_ids, attention_mask, output_hidden_states=True
)
data = model_output.hidden_states[-1]
mask = attention_mask.unsqueeze(-1).expand(data.size()).float()
masked_embeddings = data * mask
sum_embeddings = torch.sum(masked_embeddings, dim=1)
seq_length = torch.sum(mask, dim=1)
embedding = sum_embeddings / seq_length
normalized_embeddings = F.normalize(embedding, p=2, dim=1)
ret = {
"embedding": normalized_embeddings.tolist(),
"token_num": torch.sum(attention_mask).item(),
}
elif is_bert:
embedding = []
token_num = 0
for text in params["input"]:
input_ids = tokenizer.encode(text, return_tensors="pt").to(
self.device
)
model_output = self.model(input_ids)
data = model_output[0][:, 0]
data = F.normalize(torch.mean(data, dim=0), p=2, dim=0)
embedding.append(data.tolist())
token_num += len(input_ids[0])
ret = {
"embedding": embedding,
"token_num": token_num,
}
else:
embedding = []
token_num = 0
for text in params["input"]:
input_ids = tokenizer.encode(text, return_tensors="pt").to(
self.device
)
if is_t5:
model_output = self.model(
input_ids, decoder_input_ids=input_ids
)
else:
model_output = self.model(input_ids, output_hidden_states=True)
if is_chatglm:
data = (model_output.hidden_states[-1].transpose(0, 1))[0]
elif is_t5:
data = model_output.encoder_last_hidden_state[0]
else:
data = model_output.hidden_states[-1][0]
data = F.normalize(torch.mean(data, dim=0), p=2, dim=0)
embedding.append(data.tolist())
token_num += len(input_ids[0])
ret = {
"embedding": embedding,
"token_num": token_num,
}
except torch.cuda.OutOfMemoryError as e:
ret = {
"text": f"{SERVER_ERROR_MSG}\n\n({e})",
"error_code": ErrorCode.CUDA_OUT_OF_MEMORY,
}
except (ValueError, RuntimeError) as e:
ret = {
"text": f"{SERVER_ERROR_MSG}\n\n({e})",
"error_code": ErrorCode.INTERNAL_ERROR,
}
return ret
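# Per-worker concurrency control: each request below acquires the semaphore before
# running the model and releases it once the response is finished.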
def release_worker_semaphore():
worker.semaphore.release()
def acquire_worker_semaphore():
if worker.semaphore is None:
worker.semaphore = asyncio.Semaphore(worker.limit_worker_concurrency)
return worker.semaphore.acquire()
def create_background_tasks():
background_tasks = BackgroundTasks()
background_tasks.add_task(release_worker_semaphore)
return background_tasks
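# HTTP routes exposed by the worker; the FastChat controller and the
# OpenAI-compatible API server call these endpoints.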
@app.post("/worker_generate_stream")
async def api_generate_stream(request: Request):
params = await request.json()
await acquire_worker_semaphore()
generator = worker.generate_stream_gate(params)
background_tasks = create_background_tasks()
return StreamingResponse(generator, background=background_tasks)
@app.post("/worker_generate")
async def api_generate(request: Request):
params = await request.json()
await acquire_worker_semaphore()
output = worker.generate_gate(params)
release_worker_semaphore()
return JSONResponse(output)
@app.post("/worker_get_embeddings")
async def api_get_embeddings(request: Request):
params = await request.json()
await acquire_worker_semaphore()
embedding = worker.get_embeddings(params)
release_worker_semaphore()
return JSONResponse(content=embedding)
@app.post("/worker_get_status")
async def api_get_status(request: Request):
return worker.get_status()
@app.post("/count_token")
async def api_count_token(request: Request):
params = await request.json()
return worker.count_token(params)
@app.post("/worker_get_conv_template")
async def api_get_conv(request: Request):
return worker.get_conv_template()
@app.post("/model_details")
async def api_model_details(request: Request):
return {"context_length": worker.context_len}
if __name__ == "__main__":
patch_fastchat()
parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default="localhost")
parser.add_argument("--port", type=int, default=21002)
parser.add_argument("--worker-address", type=str, default="http://localhost:21002")
parser.add_argument(
"--controller-address", type=str, default="http://localhost:21001"
)
add_model_args(parser)
parser.add_argument(
"--model-names",
type=lambda s: s.split(","),
help="Optional display comma separated names",
)
parser.add_argument(
"--conv-template", type=str, default=None, help="Conversation prompt template."
)
parser.add_argument(
"--limit-worker-concurrency",
type=int,
default=5,
help="Limit the model concurrency to prevent OOM.",
)
parser.add_argument("--stream-interval", type=int, default=2)
parser.add_argument("--no-register", action="store_true")
args = parser.parse_args()
logger.info(f"args: {args}")
if args.gpus:
        invalidInputError(len(args.gpus.split(",")) >= args.num_gpus,
                          f"Larger --num-gpus ({args.num_gpus}) than --gpus {args.gpus}!")
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
gptq_config = GptqConfig(
ckpt=args.gptq_ckpt or args.model_path,
wbits=args.gptq_wbits,
groupsize=args.gptq_groupsize,
act_order=args.gptq_act_order,
)
awq_config = AWQConfig(
ckpt=args.awq_ckpt or args.model_path,
wbits=args.awq_wbits,
groupsize=args.awq_groupsize,
)
worker = ModelWorker(
args.controller_address,
args.worker_address,
worker_id,
args.model_path,
args.model_names,
args.limit_worker_concurrency,
no_register=args.no_register,
device=args.device,
num_gpus=args.num_gpus,
max_gpu_memory=args.max_gpu_memory,
load_8bit=args.load_8bit,
cpu_offloading=args.cpu_offloading,
gptq_config=gptq_config,
awq_config=awq_config,
stream_interval=args.stream_interval,
conv_template=args.conv_template,
)
uvicorn.run(app, host=args.host, port=args.port, log_level="info")