From 0bf5857908db4b4318e65c30034fa0622813a92c Mon Sep 17 00:00:00 2001 From: Guancheng Fu <110874468+gc-fu@users.noreply.github.com> Date: Wed, 13 Sep 2023 09:28:05 +0800 Subject: [PATCH] [LLM] Integrate FastChat as a serving framework for BigDL-LLM (#8821) * Finish changing * format * add licence * Add licence * fix * fix * Add xpu support for fschat * Fix patch * Also install webui dependencies * change setup.py dependency installs * fiox * format * final test --- python/llm/setup.py | 9 +- python/llm/src/bigdl/llm/serving/README.md | 97 ++++ python/llm/src/bigdl/llm/serving/__init__.py | 15 + .../src/bigdl/llm/serving/bigdl_llm_model.py | 271 ++++++++++ .../llm/src/bigdl/llm/serving/model_worker.py | 504 ++++++++++++++++++ 5 files changed, 895 insertions(+), 1 deletion(-) create mode 100644 python/llm/src/bigdl/llm/serving/README.md create mode 100644 python/llm/src/bigdl/llm/serving/__init__.py create mode 100644 python/llm/src/bigdl/llm/serving/bigdl_llm_model.py create mode 100644 python/llm/src/bigdl/llm/serving/model_worker.py diff --git a/python/llm/setup.py b/python/llm/setup.py index 3af962b1..99957a00 100644 --- a/python/llm/setup.py +++ b/python/llm/setup.py @@ -53,6 +53,7 @@ CONVERT_DEP = ['numpy >= 1.22', 'torch', 'transformers == 4.31.0', 'sentencepiece', # TODO: Support accelerate 0.22.0 'accelerate == 0.21.0', 'tabulate'] +SERVING_DEP = ['fschat[model_worker, webui] >= 0.2.24', 'protobuf'] windows_binarys = [ "llama.dll", "gptneox.dll", @@ -253,6 +254,7 @@ def setup_package(): all_requires = ['py-cpuinfo', 'protobuf'] all_requires += CONVERT_DEP + all_requires += SERVING_DEP # install with -f https://developer.intel.com/ipex-whl-stable-xpu xpu_requires = copy.deepcopy(all_requires) @@ -262,6 +264,10 @@ def setup_package(): "intel_extension_for_pytorch==2.0.110+xpu;platform_system=='Linux'", "bigdl-core-xe==" + VERSION + ";platform_system=='Linux'"] + serving_requires = ['py-cpuinfo'] + serving_requires += SERVING_DEP + + metadata = dict( name='bigdl-llm', version=VERSION, @@ -283,7 +289,8 @@ def setup_package(): ] }, extras_require={"all": all_requires, - "xpu": xpu_requires}, + "xpu": xpu_requires, + "serving": serving_requires}, classifiers=[ 'License :: OSI Approved :: Apache Software License', 'Programming Language :: Python :: 3', diff --git a/python/llm/src/bigdl/llm/serving/README.md b/python/llm/src/bigdl/llm/serving/README.md new file mode 100644 index 00000000..14aa0253 --- /dev/null +++ b/python/llm/src/bigdl/llm/serving/README.md @@ -0,0 +1,97 @@ +## Serving using BigDL-LLM and FastChat + +FastChat is an open platform for training, serving, and evaluating large language model based chatbots. You can find the detailed information at their [homepage](https://github.com/lm-sys/FastChat). + +BigDL-LLM can be easily integrated into FastChat so that user can use `BigDL-LLM` as a serving backend in the deployment. + +### Working with BigDL-LLM Serving + +
Table of Contents + +- [Install](#install) +- [Models](#models) +- [Boot Service](#start-the-service) + - [Web GUI](#serving-with-webgui) + - [RESTful API](#serving-with-openai-compatible-restful-apis) +
+ +#### Install + +You may install **`bigdl-llm`** with `FastChat` as follows: + +```bash +pip install --pre --upgrade bigdl-llm[serving] + +# Or +pip install --pre --upgrade bigdl-llm[all] +``` + +To add GPU support for FastChat, you may install **`bigdl-llm`** as follows: +```bash +pip install --pre --upgrade bigdl-llm[xpu, serving] -f https://developer.intel.com/ipex-whl-stable-xpu +``` + +#### Models + +Using BigDL-LLM in FastChat does not impose any new limitations on model usage. Therefore, all Hugging Face Transformer models can be utilized in FastChat. + +FastChat determines the Model adapter to use through path matching. Therefore, in order to load models using BigDL-LLM, you need to make some modifications to the model's name. + +For instance, assuming you have downloaded the `llama-7b-hf` from [HuggingFace](https://huggingface.co/decapoda-research/llama-7b-hf). Then, to use the `BigDL-LLM` as backend, you need to change name from `llama-7b-hf` to `bigdl-7b`. +The key point here is that the model's path should include "bigdl" and should not include paths matched by other model adapters. + +A special case is `ChatGLM` models. For these models, you do not need to do any changes after downloading the model and the `BigDL-LLM` backend will be used automatically. + + +#### Start the service + +##### Serving with WebGUI + +To serve using the Web UI, you need three main components: web servers that interface with users, model workers that host one or more models, and a controller to coordinate the web server and model workers. + +###### Launch the Controller +```bash +python3 -m fastchat.serve.controller +``` + +This controller manages the distributed workers. + +###### Launch the model worker(s) +```bash +python3 -m bigdl.llm.serving.model_worker --model-path lmsys/vicuna-7b-v1.3 --device cpu +``` +Wait until the process finishes loading the model and you see "Uvicorn running on ...". The model worker will register itself to the controller. + +> To run model worker using Intel GPU, simple change the --device cpu option to --device xpu + +###### Launch the Gradio web server + +```bash +python3 -m fastchat.serve.gradio_web_server +``` + +This is the user interface that users will interact with. + +By following these steps, you will be able to serve your models using the web UI with `BigDL-LLM` as the backend. You can open your browser and chat with a model now. + +##### Serving with OpenAI-Compatible RESTful APIs + +To start an OpenAI API server that provides compatible APIs using `BigDL-LLM` backend, you need three main components: an OpenAI API Server that serves the in-coming requests, model workers that host one or more models, and a controller to coordinate the web server and model workers. + +First, launch the controller + +```bash +python3 -m fastchat.serve.controller +``` + +Then, launch the model worker(s): + +```bash +python3 -m bigdl.llm.serving.model_worker --model-path lmsys/vicuna-7b-v1.3 --device cpu +``` + +Finally, launch the RESTful API server + +```bash +python3 -m fastchat.serve.openai_api_server --host localhost --port 8000 +``` \ No newline at end of file diff --git a/python/llm/src/bigdl/llm/serving/__init__.py b/python/llm/src/bigdl/llm/serving/__init__.py new file mode 100644 index 00000000..2151a805 --- /dev/null +++ b/python/llm/src/bigdl/llm/serving/__init__.py @@ -0,0 +1,15 @@ +# +# Copyright 2016 The BigDL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/python/llm/src/bigdl/llm/serving/bigdl_llm_model.py b/python/llm/src/bigdl/llm/serving/bigdl_llm_model.py new file mode 100644 index 00000000..0f09b346 --- /dev/null +++ b/python/llm/src/bigdl/llm/serving/bigdl_llm_model.py @@ -0,0 +1,271 @@ +# +# Copyright 2016 The BigDL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from fastchat.model.model_adapter import register_model_adapter, BaseModelAdapter, ChatGLMAdapter +from fastchat.modules.gptq import GptqConfig, load_gptq_quantized +import accelerate +from fastchat.modules.awq import AWQConfig, load_awq_quantized +from fastchat.model.model_adapter import ( + get_model_adapter, + raise_warning_for_incompatible_cpu_offloading_configuration, +) +from fastchat.model.monkey_patch_non_inplace import ( + replace_llama_attn_with_non_inplace_operations, +) +from fastchat.constants import CPU_ISA +from fastchat.utils import get_gpu_memory +import torch +import warnings +from transformers import AutoTokenizer +from typing import Dict, List, Optional +import math +import psutil +from bigdl.llm.utils.common import invalidInputError + +is_fastchat_patched = False +_mapping_fastchat = None + + +def _get_patch_map(): + global _mapping_fastchat + + if _mapping_fastchat is None: + _mapping_fastchat = [] + + from fastchat.model import model_adapter + _mapping_fastchat += [ + [BaseModelAdapter, "load_model", load_model_base, None], + [ChatGLMAdapter, "load_model", load_model_chatglm, None], + [model_adapter, "load_model", load_model, None], + ] + + return _mapping_fastchat + + +def load_model_base(self, model_path: str, from_pretrained_kwargs: dict): + revision = from_pretrained_kwargs.get("revision", "main") + print("Customized bigdl-llm loader") + tokenizer = AutoTokenizer.from_pretrained( + model_path, + use_fast=self.use_fast_tokenizer, + revision=revision, + ) + from bigdl.llm.transformers import AutoModelForCausalLM + model = AutoModelForCausalLM.from_pretrained( + model_path, low_cpu_mem_usage=True, **from_pretrained_kwargs + ) + return model, tokenizer + + +def load_model_chatglm(self, model_path: str, from_pretrained_kwargs: dict): + revision = from_pretrained_kwargs.get("revision", "main") + print("Customized bigdl-llm loader") + tokenizer = AutoTokenizer.from_pretrained( + model_path, trust_remote_code=True, revision=revision + ) + from bigdl.llm.transformers import AutoModel + model = AutoModel.from_pretrained( + model_path, trust_remote_code=True, load_in_4bit=True, **from_pretrained_kwargs + ) + return model, tokenizer + + +def load_model( + model_path: str, + device: str = "cuda", + num_gpus: int = 1, + max_gpu_memory: Optional[str] = None, + load_8bit: bool = False, + cpu_offloading: bool = False, + gptq_config: Optional[GptqConfig] = None, + awq_config: Optional[AWQConfig] = None, + revision: str = "main", + debug: bool = False, +): + """Load a model from Hugging Face.""" + # get model adapter + adapter = get_model_adapter(model_path) + + # Handle device mapping + cpu_offloading = raise_warning_for_incompatible_cpu_offloading_configuration( + device, load_8bit, cpu_offloading + ) + if device == "cpu": + kwargs = {"torch_dtype": torch.float32} + if CPU_ISA in ["avx512_bf16", "amx"]: + try: + import intel_extension_for_pytorch as ipex + + kwargs = {"torch_dtype": torch.bfloat16} + except ImportError: + warnings.warn( + "Intel Extension for PyTorch is not installed, " + "it can be installed to accelerate cpu inference" + ) + elif device == "cuda": + kwargs = {"torch_dtype": torch.float16} + if num_gpus != 1: + kwargs["device_map"] = "auto" + if max_gpu_memory is None: + kwargs[ + "device_map" + ] = "sequential" # This is important for not the same VRAM sizes + available_gpu_memory = get_gpu_memory(num_gpus) + kwargs["max_memory"] = { + i: str(int(available_gpu_memory[i] * 0.85)) + "GiB" + for i in range(num_gpus) + } + else: + kwargs["max_memory"] = { + i: max_gpu_memory for i in range(num_gpus)} + elif device == "mps": + kwargs = {"torch_dtype": torch.float16} + # Avoid bugs in mps backend by not using in-place operations. + replace_llama_attn_with_non_inplace_operations() + elif device == "xpu": + kwargs = {} + # Try to load ipex, while it looks unused, it links into torch for xpu support + try: + import intel_extension_for_pytorch as ipex + except ImportError: + warnings.warn( + "Intel Extension for PyTorch is not installed, but is required for xpu inference." + ) + else: + invalidInputError(False, f"Invalid device: {device}") + + if cpu_offloading: + # raises an error on incompatible platforms + from transformers import BitsAndBytesConfig + + if "max_memory" in kwargs: + kwargs["max_memory"]["cpu"] = ( + str(math.floor(psutil.virtual_memory().available / 2**20)) + "Mib" + ) + kwargs["quantization_config"] = BitsAndBytesConfig( + load_in_8bit_fp32_cpu_offload=cpu_offloading + ) + kwargs["load_in_8bit"] = load_8bit + elif load_8bit: + if num_gpus != 1: + warnings.warn( + "8-bit quantization is not supported for multi-gpu inference." + ) + else: + model, tokenizer = adapter.load_compress_model( + model_path=model_path, + device=device, + torch_dtype=kwargs["torch_dtype"], + revision=revision, + ) + if debug: + print(model) + return model, tokenizer + elif awq_config and awq_config.wbits < 16: + invalidInputError(awq_config.wbits != 4, + "Currently we only support 4-bit inference for AWQ.") + model, tokenizer = load_awq_quantized(model_path, awq_config, device) + if num_gpus != 1: + device_map = accelerate.infer_auto_device_map( + model, + max_memory=kwargs["max_memory"], + no_split_module_classes=[ + "OPTDecoderLayer", + "LlamaDecoderLayer", + "BloomBlock", + "MPTBlock", + "DecoderLayer", + ], + ) + model = accelerate.dispatch_model( + model, device_map=device_map, offload_buffers=True + ) + else: + model.to(device) + return model, tokenizer + elif gptq_config and gptq_config.wbits < 16: + model, tokenizer = load_gptq_quantized(model_path, gptq_config) + if num_gpus != 1: + device_map = accelerate.infer_auto_device_map( + model, + max_memory=kwargs["max_memory"], + no_split_module_classes=["LlamaDecoderLayer"], + ) + model = accelerate.dispatch_model( + model, device_map=device_map, offload_buffers=True + ) + else: + model.to(device) + return model, tokenizer + kwargs["revision"] = revision + + # Load model + model, tokenizer = adapter.load_model(model_path, kwargs) + + if ( + device == "cpu" + and kwargs["torch_dtype"] is torch.bfloat16 + and CPU_ISA is not None + ): + model = ipex.optimize(model, dtype=kwargs["torch_dtype"]) + + if (device == "cuda" and num_gpus == 1 and not cpu_offloading) or device in ( + "mps", + "xpu", + ): + model.to(device) + + if debug: + print(model) + + return model, tokenizer + + +class BigDLLLMAdapter(BaseModelAdapter): + "Model adapter for bigdl-llm backend models" + + def match(self, model_path: str): + return "bigdl" in model_path + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + revision = from_pretrained_kwargs.get("revision", "main") + tokenizer = AutoTokenizer.from_pretrained( + model_path, use_fast=False, revision=revision + ) + print("Customized bigdl-llm loader") + from bigdl.llm.transformers import AutoModelForCausalLM + model = AutoModelForCausalLM.from_pretrained( + model_path, + load_in_4bit=True, + low_cpu_mem_usage=True, + **from_pretrained_kwargs, + ) + return model, tokenizer + + +def patch_fastchat(): + global is_fastchat_patched + if is_fastchat_patched: + return + register_model_adapter(BigDLLLMAdapter) + mapping_fastchat = _get_patch_map() + + for mapping_iter in mapping_fastchat: + if mapping_iter[3] is None: + mapping_iter[3] = getattr(mapping_iter[0], mapping_iter[1], None) + setattr(mapping_iter[0], mapping_iter[1], mapping_iter[2]) + + is_fastchat_patched = True diff --git a/python/llm/src/bigdl/llm/serving/model_worker.py b/python/llm/src/bigdl/llm/serving/model_worker.py new file mode 100644 index 00000000..f2eef6b4 --- /dev/null +++ b/python/llm/src/bigdl/llm/serving/model_worker.py @@ -0,0 +1,504 @@ +# +# Copyright 2016 The BigDL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +A model worker that executes the model. +Adapted from FastChat's model_worker.py +""" +import argparse +import asyncio +import dataclasses +import logging +import json +import os +import time +from typing import List, Optional +import threading +import uuid +from bigdl.llm.utils.common import invalidInputError + +from fastapi import FastAPI, Request, BackgroundTasks +from fastapi.responses import StreamingResponse, JSONResponse +import requests + +from .bigdl_llm_model import patch_fastchat + +try: + from transformers import ( + AutoTokenizer, + AutoModelForCausalLM, + LlamaTokenizer, + AutoModel, + ) +except ImportError: + from transformers import ( + AutoTokenizer, + AutoModelForCausalLM, + LLaMATokenizer, + AutoModel, + ) +import torch +import torch.nn.functional as F +import uvicorn + +from fastchat.constants import WORKER_HEART_BEAT_INTERVAL, ErrorCode, SERVER_ERROR_MSG +from fastchat.conversation import get_conv_template +from fastchat.model.model_adapter import ( + add_model_args, + get_conversation_template, + get_generate_stream_function, +) +from fastchat.modules.gptq import GptqConfig +from fastchat.modules.awq import AWQConfig +from fastchat.utils import build_logger, pretty_print_semaphore, get_context_length + + +worker_id = str(uuid.uuid4())[:8] +logger = build_logger("model_worker", f"model_worker_{worker_id}.log") + +app = FastAPI() + + +def heart_beat_worker(obj): + while True: + time.sleep(WORKER_HEART_BEAT_INTERVAL) + obj.send_heart_beat() + + +class BaseModelWorker: + def __init__( + self, + controller_addr: str, + worker_addr: str, + worker_id: str, + model_path: str, + model_names: List[str], + limit_worker_concurrency: int, + conv_template: str = None, + ): + self.controller_addr = controller_addr + self.worker_addr = worker_addr + self.worker_id = worker_id + if model_path.endswith("/"): + model_path = model_path[:-1] + self.model_names = model_names or [model_path.split("/")[-1]] + self.limit_worker_concurrency = limit_worker_concurrency + if conv_template: + self.conv = get_conv_template(conv_template) + else: + self.conv = get_conversation_template(model_path) + self.conv.sep_style = int(self.conv.sep_style) + self.tokenizer = None + self.context_len = None + self.call_ct = 0 + self.semaphore = None + + self.heart_beat_thread = None + + def init_heart_beat(self): + self.register_to_controller() + self.heart_beat_thread = threading.Thread( + target=heart_beat_worker, args=(self,) + ) + self.heart_beat_thread.start() + + def register_to_controller(self): + logger.info("Register to controller") + + url = self.controller_addr + "/register_worker" + data = { + "worker_name": self.worker_addr, + "check_heart_beat": True, + "worker_status": self.get_status(), + } + r = requests.post(url, json=data) + invalidInputError(r.status_code == 200, "Error register to Controller") + + def send_heart_beat(self): + logger.info( + f"Send heart beat. Models: {self.model_names}. " + f"Semaphore: {pretty_print_semaphore(self.semaphore)}. " + f"call_ct: {self.call_ct}. " + f"worker_id: {self.worker_id}. " + ) + + url = self.controller_addr + "/receive_heart_beat" + + while True: + try: + ret = requests.post( + url, + json={ + "worker_name": self.worker_addr, + "queue_length": self.get_queue_length(), + }, + timeout=5, + ) + exist = ret.json()["exist"] + break + except (requests.exceptions.RequestException, KeyError) as e: + logger.error(f"heart beat error: {e}") + time.sleep(5) + + if not exist: + self.register_to_controller() + + def get_queue_length(self): + if ( + self.semaphore is None + or self.semaphore._value is None + or self.semaphore._waiters is None + ): + return 0 + else: + return ( + self.limit_worker_concurrency + - self.semaphore._value + + len(self.semaphore._waiters) + ) + + def get_status(self): + return { + "model_names": self.model_names, + "speed": 1, + "queue_length": self.get_queue_length(), + } + + def count_token(self, params): + prompt = params["prompt"] + input_ids = self.tokenizer(prompt).input_ids + input_echo_len = len(input_ids) + + ret = { + "count": input_echo_len, + "error_code": 0, + } + return ret + + def get_conv_template(self): + return {"conv": self.conv} + + +class ModelWorker(BaseModelWorker): + def __init__( + self, + controller_addr: str, + worker_addr: str, + worker_id: str, + model_path: str, + model_names: List[str], + limit_worker_concurrency: int, + no_register: bool, + device: str, + num_gpus: int, + max_gpu_memory: str, + load_8bit: bool = False, + cpu_offloading: bool = False, + gptq_config: Optional[GptqConfig] = None, + awq_config: Optional[AWQConfig] = None, + stream_interval: int = 2, + conv_template: str = None, + ): + super().__init__( + controller_addr, + worker_addr, + worker_id, + model_path, + model_names, + limit_worker_concurrency, + conv_template=conv_template, + ) + + logger.info(f"Loading the model {self.model_names} on worker {worker_id} ...") + from fastchat.model.model_adapter import load_model + self.model, self.tokenizer = load_model( + model_path, + device=device, + num_gpus=num_gpus, + max_gpu_memory=max_gpu_memory, + load_8bit=load_8bit, + cpu_offloading=cpu_offloading, + gptq_config=gptq_config, + awq_config=awq_config, + ) + self.device = device + if self.tokenizer.pad_token is None: + self.tokenizer.pad_token = self.tokenizer.eos_token + self.context_len = get_context_length(self.model.config) + self.generate_stream_func = get_generate_stream_function(self.model, model_path) + self.stream_interval = stream_interval + + if not no_register: + self.init_heart_beat() + + def generate_stream_gate(self, params): + self.call_ct += 1 + + try: + for output in self.generate_stream_func( + self.model, + self.tokenizer, + params, + self.device, + self.context_len, + self.stream_interval, + ): + ret = { + "text": output["text"], + "error_code": 0, + } + if "usage" in output: + ret["usage"] = output["usage"] + if "finish_reason" in output: + ret["finish_reason"] = output["finish_reason"] + if "logprobs" in output: + ret["logprobs"] = output["logprobs"] + yield json.dumps(ret).encode() + b"\0" + except torch.cuda.OutOfMemoryError as e: + ret = { + "text": f"{SERVER_ERROR_MSG}\n\n({e})", + "error_code": ErrorCode.CUDA_OUT_OF_MEMORY, + } + yield json.dumps(ret).encode() + b"\0" + except (ValueError, RuntimeError) as e: + ret = { + "text": f"{SERVER_ERROR_MSG}\n\n({e})", + "error_code": ErrorCode.INTERNAL_ERROR, + } + yield json.dumps(ret).encode() + b"\0" + + def generate_gate(self, params): + for x in self.generate_stream_gate(params): + pass + return json.loads(x[:-1].decode()) + + @torch.inference_mode() + def get_embeddings(self, params): + self.call_ct += 1 + + try: + tokenizer = self.tokenizer + is_llama = "llama" in str( + type(self.model) + ) # llama supports batch inference + is_chatglm = "chatglm" in str(type(self.model)) + is_t5 = "t5" in str(type(self.model)) + is_bert = "bert" in str(type(self.model)) + + if is_llama: + encoding = tokenizer.batch_encode_plus( + params["input"], padding=True, return_tensors="pt" + ) + input_ids = encoding["input_ids"].to(self.device) + attention_mask = encoding["attention_mask"].to(self.device) + model_output = self.model( + input_ids, attention_mask, output_hidden_states=True + ) + data = model_output.hidden_states[-1] + mask = attention_mask.unsqueeze(-1).expand(data.size()).float() + masked_embeddings = data * mask + sum_embeddings = torch.sum(masked_embeddings, dim=1) + seq_length = torch.sum(mask, dim=1) + embedding = sum_embeddings / seq_length + normalized_embeddings = F.normalize(embedding, p=2, dim=1) + ret = { + "embedding": normalized_embeddings.tolist(), + "token_num": torch.sum(attention_mask).item(), + } + elif is_bert: + embedding = [] + token_num = 0 + for text in params["input"]: + input_ids = tokenizer.encode(text, return_tensors="pt").to( + self.device + ) + model_output = self.model(input_ids) + data = model_output[0][:, 0] + data = F.normalize(torch.mean(data, dim=0), p=2, dim=0) + embedding.append(data.tolist()) + token_num += len(input_ids[0]) + ret = { + "embedding": embedding, + "token_num": token_num, + } + else: + embedding = [] + token_num = 0 + for text in params["input"]: + input_ids = tokenizer.encode(text, return_tensors="pt").to( + self.device + ) + if is_t5: + model_output = self.model( + input_ids, decoder_input_ids=input_ids + ) + else: + model_output = self.model(input_ids, output_hidden_states=True) + if is_chatglm: + data = (model_output.hidden_states[-1].transpose(0, 1))[0] + elif is_t5: + data = model_output.encoder_last_hidden_state[0] + else: + data = model_output.hidden_states[-1][0] + data = F.normalize(torch.mean(data, dim=0), p=2, dim=0) + embedding.append(data.tolist()) + token_num += len(input_ids[0]) + ret = { + "embedding": embedding, + "token_num": token_num, + } + except torch.cuda.OutOfMemoryError as e: + ret = { + "text": f"{SERVER_ERROR_MSG}\n\n({e})", + "error_code": ErrorCode.CUDA_OUT_OF_MEMORY, + } + except (ValueError, RuntimeError) as e: + ret = { + "text": f"{SERVER_ERROR_MSG}\n\n({e})", + "error_code": ErrorCode.INTERNAL_ERROR, + } + return ret + + +def release_worker_semaphore(): + worker.semaphore.release() + + +def acquire_worker_semaphore(): + if worker.semaphore is None: + worker.semaphore = asyncio.Semaphore(worker.limit_worker_concurrency) + return worker.semaphore.acquire() + + +def create_background_tasks(): + background_tasks = BackgroundTasks() + background_tasks.add_task(release_worker_semaphore) + return background_tasks + + +@app.post("/worker_generate_stream") +async def api_generate_stream(request: Request): + params = await request.json() + await acquire_worker_semaphore() + generator = worker.generate_stream_gate(params) + background_tasks = create_background_tasks() + return StreamingResponse(generator, background=background_tasks) + + +@app.post("/worker_generate") +async def api_generate(request: Request): + params = await request.json() + await acquire_worker_semaphore() + output = worker.generate_gate(params) + release_worker_semaphore() + return JSONResponse(output) + + +@app.post("/worker_get_embeddings") +async def api_get_embeddings(request: Request): + params = await request.json() + await acquire_worker_semaphore() + embedding = worker.get_embeddings(params) + release_worker_semaphore() + return JSONResponse(content=embedding) + + +@app.post("/worker_get_status") +async def api_get_status(request: Request): + return worker.get_status() + + +@app.post("/count_token") +async def api_count_token(request: Request): + params = await request.json() + return worker.count_token(params) + + +@app.post("/worker_get_conv_template") +async def api_get_conv(request: Request): + return worker.get_conv_template() + + +@app.post("/model_details") +async def api_model_details(request: Request): + return {"context_length": worker.context_len} + + +if __name__ == "__main__": + patch_fastchat() + parser = argparse.ArgumentParser() + parser.add_argument("--host", type=str, default="localhost") + parser.add_argument("--port", type=int, default=21002) + parser.add_argument("--worker-address", type=str, default="http://localhost:21002") + parser.add_argument( + "--controller-address", type=str, default="http://localhost:21001" + ) + add_model_args(parser) + parser.add_argument( + "--model-names", + type=lambda s: s.split(","), + help="Optional display comma separated names", + ) + parser.add_argument( + "--conv-template", type=str, default=None, help="Conversation prompt template." + ) + parser.add_argument( + "--limit-worker-concurrency", + type=int, + default=5, + help="Limit the model concurrency to prevent OOM.", + ) + parser.add_argument("--stream-interval", type=int, default=2) + parser.add_argument("--no-register", action="store_true") + args = parser.parse_args() + logger.info(f"args: {args}") + + if args.gpus: + invalidInputError(len(args.gpus.split(",")) > args.num_gpus, f"Larger --num-gpus " + "({args.num_gpus}) than --gpus {args.gpus}!") + os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus + + gptq_config = GptqConfig( + ckpt=args.gptq_ckpt or args.model_path, + wbits=args.gptq_wbits, + groupsize=args.gptq_groupsize, + act_order=args.gptq_act_order, + ) + awq_config = AWQConfig( + ckpt=args.awq_ckpt or args.model_path, + wbits=args.awq_wbits, + groupsize=args.awq_groupsize, + ) + + worker = ModelWorker( + args.controller_address, + args.worker_address, + worker_id, + args.model_path, + args.model_names, + args.limit_worker_concurrency, + no_register=args.no_register, + device=args.device, + num_gpus=args.num_gpus, + max_gpu_memory=args.max_gpu_memory, + load_8bit=args.load_8bit, + cpu_offloading=args.cpu_offloading, + gptq_config=gptq_config, + awq_config=awq_config, + stream_interval=args.stream_interval, + conv_template=args.conv_template, + ) + uvicorn.run(app, host=args.host, port=args.port, log_level="info")