[LLM] Integrate FastChat as a serving framework for BigDL-LLM (#8821)
* Finish changing * format * add licence * Add licence * fix * fix * Add xpu support for fschat * Fix patch * Also install webui dependencies * change setup.py dependency installs * fiox * format * final test
This commit is contained in:
		
							parent
							
								
									cb534ed5c4
								
							
						
					
					
						commit
						0bf5857908
					
				
					 5 changed files with 895 additions and 1 deletions
				
			
		| 
						 | 
				
			
			@ -53,6 +53,7 @@ CONVERT_DEP = ['numpy >= 1.22', 'torch',
 | 
			
		|||
               'transformers == 4.31.0', 'sentencepiece',
 | 
			
		||||
               # TODO: Support accelerate 0.22.0
 | 
			
		||||
               'accelerate == 0.21.0', 'tabulate']
 | 
			
		||||
SERVING_DEP = ['fschat[model_worker, webui] >= 0.2.24', 'protobuf']
 | 
			
		||||
windows_binarys = [
 | 
			
		||||
    "llama.dll",
 | 
			
		||||
    "gptneox.dll",
 | 
			
		||||
| 
						 | 
				
			
			@ -253,6 +254,7 @@ def setup_package():
 | 
			
		|||
 | 
			
		||||
    all_requires = ['py-cpuinfo', 'protobuf']
 | 
			
		||||
    all_requires += CONVERT_DEP
 | 
			
		||||
    all_requires += SERVING_DEP
 | 
			
		||||
 | 
			
		||||
    # install with -f https://developer.intel.com/ipex-whl-stable-xpu
 | 
			
		||||
    xpu_requires = copy.deepcopy(all_requires)
 | 
			
		||||
| 
						 | 
				
			
			@ -262,6 +264,10 @@ def setup_package():
 | 
			
		|||
                     "intel_extension_for_pytorch==2.0.110+xpu;platform_system=='Linux'",
 | 
			
		||||
                     "bigdl-core-xe==" + VERSION + ";platform_system=='Linux'"]
 | 
			
		||||
 | 
			
		||||
    serving_requires = ['py-cpuinfo']
 | 
			
		||||
    serving_requires += SERVING_DEP
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
    metadata = dict(
 | 
			
		||||
        name='bigdl-llm',
 | 
			
		||||
        version=VERSION,
 | 
			
		||||
| 
						 | 
				
			
			@ -283,7 +289,8 @@ def setup_package():
 | 
			
		|||
            ]
 | 
			
		||||
        },
 | 
			
		||||
        extras_require={"all": all_requires,
 | 
			
		||||
                        "xpu": xpu_requires},
 | 
			
		||||
                        "xpu": xpu_requires,
 | 
			
		||||
                        "serving": serving_requires},
 | 
			
		||||
        classifiers=[
 | 
			
		||||
            'License :: OSI Approved :: Apache Software License',
 | 
			
		||||
            'Programming Language :: Python :: 3',
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
							
								
								
									
										97
									
								
								python/llm/src/bigdl/llm/serving/README.md
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										97
									
								
								python/llm/src/bigdl/llm/serving/README.md
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1,97 @@
 | 
			
		|||
## Serving using BigDL-LLM and FastChat
 | 
			
		||||
 | 
			
		||||
FastChat is an open platform for training, serving, and evaluating large language model based chatbots. You can find the detailed information at their [homepage](https://github.com/lm-sys/FastChat).
 | 
			
		||||
 | 
			
		||||
BigDL-LLM can be easily integrated into FastChat so that user can use `BigDL-LLM` as a serving backend in the deployment.
 | 
			
		||||
 | 
			
		||||
### Working with BigDL-LLM Serving
 | 
			
		||||
 | 
			
		||||
<details><summary>Table of Contents</summary>
 | 
			
		||||
 | 
			
		||||
- [Install](#install)
 | 
			
		||||
- [Models](#models)
 | 
			
		||||
- [Boot Service](#start-the-service)
 | 
			
		||||
  - [Web GUI](#serving-with-webgui)
 | 
			
		||||
  - [RESTful API](#serving-with-openai-compatible-restful-apis)
 | 
			
		||||
</details>
 | 
			
		||||
 | 
			
		||||
#### Install
 | 
			
		||||
 | 
			
		||||
You may install **`bigdl-llm`** with `FastChat` as follows:
 | 
			
		||||
 | 
			
		||||
```bash
 | 
			
		||||
pip install --pre --upgrade bigdl-llm[serving]
 | 
			
		||||
 | 
			
		||||
# Or
 | 
			
		||||
pip install --pre --upgrade bigdl-llm[all]
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
To add GPU support for FastChat, you may install **`bigdl-llm`** as follows:
 | 
			
		||||
```bash
 | 
			
		||||
pip install --pre --upgrade bigdl-llm[xpu, serving] -f https://developer.intel.com/ipex-whl-stable-xpu
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
#### Models
 | 
			
		||||
 | 
			
		||||
Using BigDL-LLM in FastChat does not impose any new limitations on model usage. Therefore, all Hugging Face Transformer models can be utilized in FastChat.
 | 
			
		||||
 | 
			
		||||
FastChat determines the Model adapter to use through path matching. Therefore, in order to load models using BigDL-LLM, you need to make some modifications to the model's name.
 | 
			
		||||
 | 
			
		||||
For instance, assuming you have downloaded the `llama-7b-hf` from [HuggingFace](https://huggingface.co/decapoda-research/llama-7b-hf).  Then, to use the `BigDL-LLM` as backend, you need to change name from `llama-7b-hf` to `bigdl-7b`.
 | 
			
		||||
The key point here is that the model's path should include "bigdl" and should not include paths matched by other model adapters.
 | 
			
		||||
 | 
			
		||||
A special case is `ChatGLM` models. For these models, you do not need to do any changes after downloading the model and the `BigDL-LLM` backend will be used automatically.
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
#### Start the service
 | 
			
		||||
 | 
			
		||||
##### Serving with WebGUI
 | 
			
		||||
 | 
			
		||||
To serve using the Web UI, you need three main components: web servers that interface with users, model workers that host one or more models, and a controller to coordinate the web server and model workers.
 | 
			
		||||
 | 
			
		||||
###### Launch the Controller
 | 
			
		||||
```bash
 | 
			
		||||
python3 -m fastchat.serve.controller
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
This controller manages the distributed workers.
 | 
			
		||||
 | 
			
		||||
###### Launch the model worker(s)
 | 
			
		||||
```bash
 | 
			
		||||
python3 -m bigdl.llm.serving.model_worker --model-path lmsys/vicuna-7b-v1.3 --device cpu
 | 
			
		||||
```
 | 
			
		||||
Wait until the process finishes loading the model and you see "Uvicorn running on ...". The model worker will register itself to the controller.
 | 
			
		||||
 | 
			
		||||
> To run model worker using Intel GPU, simple change the --device cpu option to --device xpu
 | 
			
		||||
 | 
			
		||||
###### Launch the Gradio web server
 | 
			
		||||
 | 
			
		||||
```bash
 | 
			
		||||
python3 -m fastchat.serve.gradio_web_server
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
This is the user interface that users will interact with.
 | 
			
		||||
 | 
			
		||||
By following these steps, you will be able to serve your models using the web UI with `BigDL-LLM` as the backend. You can open your browser and chat with a model now.
 | 
			
		||||
 | 
			
		||||
##### Serving with OpenAI-Compatible RESTful APIs
 | 
			
		||||
 | 
			
		||||
To start an OpenAI API server that provides compatible APIs using `BigDL-LLM` backend, you need three main components: an OpenAI API Server that serves the in-coming requests, model workers that host one or more models, and a controller to coordinate the web server and model workers.
 | 
			
		||||
 | 
			
		||||
First, launch the controller
 | 
			
		||||
 | 
			
		||||
```bash
 | 
			
		||||
python3 -m fastchat.serve.controller
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
Then, launch the model worker(s):
 | 
			
		||||
 | 
			
		||||
```bash
 | 
			
		||||
python3 -m bigdl.llm.serving.model_worker --model-path lmsys/vicuna-7b-v1.3 --device cpu
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
Finally, launch the RESTful API server
 | 
			
		||||
 | 
			
		||||
```bash
 | 
			
		||||
python3 -m fastchat.serve.openai_api_server --host localhost --port 8000
 | 
			
		||||
```
 | 
			
		||||
							
								
								
									
										15
									
								
								python/llm/src/bigdl/llm/serving/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										15
									
								
								python/llm/src/bigdl/llm/serving/__init__.py
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1,15 @@
 | 
			
		|||
#
 | 
			
		||||
# Copyright 2016 The BigDL Authors.
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
# you may not use this file except in compliance with the License.
 | 
			
		||||
# You may obtain a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
#     http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
# Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
# distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
# See the License for the specific language governing permissions and
 | 
			
		||||
# limitations under the License.
 | 
			
		||||
#
 | 
			
		||||
							
								
								
									
										271
									
								
								python/llm/src/bigdl/llm/serving/bigdl_llm_model.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										271
									
								
								python/llm/src/bigdl/llm/serving/bigdl_llm_model.py
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1,271 @@
 | 
			
		|||
#
 | 
			
		||||
# Copyright 2016 The BigDL Authors.
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
# you may not use this file except in compliance with the License.
 | 
			
		||||
# You may obtain a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
#     http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
# Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
# distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
# See the License for the specific language governing permissions and
 | 
			
		||||
# limitations under the License.
 | 
			
		||||
#
 | 
			
		||||
 | 
			
		||||
from fastchat.model.model_adapter import register_model_adapter, BaseModelAdapter, ChatGLMAdapter
 | 
			
		||||
from fastchat.modules.gptq import GptqConfig, load_gptq_quantized
 | 
			
		||||
import accelerate
 | 
			
		||||
from fastchat.modules.awq import AWQConfig, load_awq_quantized
 | 
			
		||||
from fastchat.model.model_adapter import (
 | 
			
		||||
    get_model_adapter,
 | 
			
		||||
    raise_warning_for_incompatible_cpu_offloading_configuration,
 | 
			
		||||
)
 | 
			
		||||
from fastchat.model.monkey_patch_non_inplace import (
 | 
			
		||||
    replace_llama_attn_with_non_inplace_operations,
 | 
			
		||||
)
 | 
			
		||||
from fastchat.constants import CPU_ISA
 | 
			
		||||
from fastchat.utils import get_gpu_memory
 | 
			
		||||
import torch
 | 
			
		||||
import warnings
 | 
			
		||||
from transformers import AutoTokenizer
 | 
			
		||||
from typing import Dict, List, Optional
 | 
			
		||||
import math
 | 
			
		||||
import psutil
 | 
			
		||||
from bigdl.llm.utils.common import invalidInputError
 | 
			
		||||
 | 
			
		||||
is_fastchat_patched = False
 | 
			
		||||
_mapping_fastchat = None
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _get_patch_map():
 | 
			
		||||
    global _mapping_fastchat
 | 
			
		||||
 | 
			
		||||
    if _mapping_fastchat is None:
 | 
			
		||||
        _mapping_fastchat = []
 | 
			
		||||
 | 
			
		||||
    from fastchat.model import model_adapter
 | 
			
		||||
    _mapping_fastchat += [
 | 
			
		||||
        [BaseModelAdapter, "load_model", load_model_base, None],
 | 
			
		||||
        [ChatGLMAdapter, "load_model", load_model_chatglm, None],
 | 
			
		||||
        [model_adapter, "load_model", load_model, None],
 | 
			
		||||
    ]
 | 
			
		||||
 | 
			
		||||
    return _mapping_fastchat
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def load_model_base(self, model_path: str, from_pretrained_kwargs: dict):
 | 
			
		||||
    revision = from_pretrained_kwargs.get("revision", "main")
 | 
			
		||||
    print("Customized bigdl-llm loader")
 | 
			
		||||
    tokenizer = AutoTokenizer.from_pretrained(
 | 
			
		||||
        model_path,
 | 
			
		||||
        use_fast=self.use_fast_tokenizer,
 | 
			
		||||
        revision=revision,
 | 
			
		||||
    )
 | 
			
		||||
    from bigdl.llm.transformers import AutoModelForCausalLM
 | 
			
		||||
    model = AutoModelForCausalLM.from_pretrained(
 | 
			
		||||
        model_path, low_cpu_mem_usage=True, **from_pretrained_kwargs
 | 
			
		||||
    )
 | 
			
		||||
    return model, tokenizer
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def load_model_chatglm(self, model_path: str, from_pretrained_kwargs: dict):
 | 
			
		||||
    revision = from_pretrained_kwargs.get("revision", "main")
 | 
			
		||||
    print("Customized bigdl-llm loader")
 | 
			
		||||
    tokenizer = AutoTokenizer.from_pretrained(
 | 
			
		||||
        model_path, trust_remote_code=True, revision=revision
 | 
			
		||||
    )
 | 
			
		||||
    from bigdl.llm.transformers import AutoModel
 | 
			
		||||
    model = AutoModel.from_pretrained(
 | 
			
		||||
        model_path, trust_remote_code=True, load_in_4bit=True, **from_pretrained_kwargs
 | 
			
		||||
    )
 | 
			
		||||
    return model, tokenizer
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def load_model(
 | 
			
		||||
    model_path: str,
 | 
			
		||||
    device: str = "cuda",
 | 
			
		||||
    num_gpus: int = 1,
 | 
			
		||||
    max_gpu_memory: Optional[str] = None,
 | 
			
		||||
    load_8bit: bool = False,
 | 
			
		||||
    cpu_offloading: bool = False,
 | 
			
		||||
    gptq_config: Optional[GptqConfig] = None,
 | 
			
		||||
    awq_config: Optional[AWQConfig] = None,
 | 
			
		||||
    revision: str = "main",
 | 
			
		||||
    debug: bool = False,
 | 
			
		||||
):
 | 
			
		||||
    """Load a model from Hugging Face."""
 | 
			
		||||
    # get model adapter
 | 
			
		||||
    adapter = get_model_adapter(model_path)
 | 
			
		||||
 | 
			
		||||
    # Handle device mapping
 | 
			
		||||
    cpu_offloading = raise_warning_for_incompatible_cpu_offloading_configuration(
 | 
			
		||||
        device, load_8bit, cpu_offloading
 | 
			
		||||
    )
 | 
			
		||||
    if device == "cpu":
 | 
			
		||||
        kwargs = {"torch_dtype": torch.float32}
 | 
			
		||||
        if CPU_ISA in ["avx512_bf16", "amx"]:
 | 
			
		||||
            try:
 | 
			
		||||
                import intel_extension_for_pytorch as ipex
 | 
			
		||||
 | 
			
		||||
                kwargs = {"torch_dtype": torch.bfloat16}
 | 
			
		||||
            except ImportError:
 | 
			
		||||
                warnings.warn(
 | 
			
		||||
                    "Intel Extension for PyTorch is not installed, "
 | 
			
		||||
                    "it can be installed to accelerate cpu inference"
 | 
			
		||||
                )
 | 
			
		||||
    elif device == "cuda":
 | 
			
		||||
        kwargs = {"torch_dtype": torch.float16}
 | 
			
		||||
        if num_gpus != 1:
 | 
			
		||||
            kwargs["device_map"] = "auto"
 | 
			
		||||
            if max_gpu_memory is None:
 | 
			
		||||
                kwargs[
 | 
			
		||||
                    "device_map"
 | 
			
		||||
                ] = "sequential"  # This is important for not the same VRAM sizes
 | 
			
		||||
                available_gpu_memory = get_gpu_memory(num_gpus)
 | 
			
		||||
                kwargs["max_memory"] = {
 | 
			
		||||
                    i: str(int(available_gpu_memory[i] * 0.85)) + "GiB"
 | 
			
		||||
                    for i in range(num_gpus)
 | 
			
		||||
                }
 | 
			
		||||
            else:
 | 
			
		||||
                kwargs["max_memory"] = {
 | 
			
		||||
                    i: max_gpu_memory for i in range(num_gpus)}
 | 
			
		||||
    elif device == "mps":
 | 
			
		||||
        kwargs = {"torch_dtype": torch.float16}
 | 
			
		||||
        # Avoid bugs in mps backend by not using in-place operations.
 | 
			
		||||
        replace_llama_attn_with_non_inplace_operations()
 | 
			
		||||
    elif device == "xpu":
 | 
			
		||||
        kwargs = {}
 | 
			
		||||
        # Try to load ipex, while it looks unused, it links into torch for xpu support
 | 
			
		||||
        try:
 | 
			
		||||
            import intel_extension_for_pytorch as ipex
 | 
			
		||||
        except ImportError:
 | 
			
		||||
            warnings.warn(
 | 
			
		||||
                "Intel Extension for PyTorch is not installed, but is required for xpu inference."
 | 
			
		||||
            )
 | 
			
		||||
    else:
 | 
			
		||||
        invalidInputError(False, f"Invalid device: {device}")
 | 
			
		||||
 | 
			
		||||
    if cpu_offloading:
 | 
			
		||||
        # raises an error on incompatible platforms
 | 
			
		||||
        from transformers import BitsAndBytesConfig
 | 
			
		||||
 | 
			
		||||
        if "max_memory" in kwargs:
 | 
			
		||||
            kwargs["max_memory"]["cpu"] = (
 | 
			
		||||
                str(math.floor(psutil.virtual_memory().available / 2**20)) + "Mib"
 | 
			
		||||
            )
 | 
			
		||||
        kwargs["quantization_config"] = BitsAndBytesConfig(
 | 
			
		||||
            load_in_8bit_fp32_cpu_offload=cpu_offloading
 | 
			
		||||
        )
 | 
			
		||||
        kwargs["load_in_8bit"] = load_8bit
 | 
			
		||||
    elif load_8bit:
 | 
			
		||||
        if num_gpus != 1:
 | 
			
		||||
            warnings.warn(
 | 
			
		||||
                "8-bit quantization is not supported for multi-gpu inference."
 | 
			
		||||
            )
 | 
			
		||||
        else:
 | 
			
		||||
            model, tokenizer = adapter.load_compress_model(
 | 
			
		||||
                model_path=model_path,
 | 
			
		||||
                device=device,
 | 
			
		||||
                torch_dtype=kwargs["torch_dtype"],
 | 
			
		||||
                revision=revision,
 | 
			
		||||
            )
 | 
			
		||||
            if debug:
 | 
			
		||||
                print(model)
 | 
			
		||||
            return model, tokenizer
 | 
			
		||||
    elif awq_config and awq_config.wbits < 16:
 | 
			
		||||
        invalidInputError(awq_config.wbits != 4,
 | 
			
		||||
                          "Currently we only support 4-bit inference for AWQ.")
 | 
			
		||||
        model, tokenizer = load_awq_quantized(model_path, awq_config, device)
 | 
			
		||||
        if num_gpus != 1:
 | 
			
		||||
            device_map = accelerate.infer_auto_device_map(
 | 
			
		||||
                model,
 | 
			
		||||
                max_memory=kwargs["max_memory"],
 | 
			
		||||
                no_split_module_classes=[
 | 
			
		||||
                    "OPTDecoderLayer",
 | 
			
		||||
                    "LlamaDecoderLayer",
 | 
			
		||||
                    "BloomBlock",
 | 
			
		||||
                    "MPTBlock",
 | 
			
		||||
                    "DecoderLayer",
 | 
			
		||||
                ],
 | 
			
		||||
            )
 | 
			
		||||
            model = accelerate.dispatch_model(
 | 
			
		||||
                model, device_map=device_map, offload_buffers=True
 | 
			
		||||
            )
 | 
			
		||||
        else:
 | 
			
		||||
            model.to(device)
 | 
			
		||||
        return model, tokenizer
 | 
			
		||||
    elif gptq_config and gptq_config.wbits < 16:
 | 
			
		||||
        model, tokenizer = load_gptq_quantized(model_path, gptq_config)
 | 
			
		||||
        if num_gpus != 1:
 | 
			
		||||
            device_map = accelerate.infer_auto_device_map(
 | 
			
		||||
                model,
 | 
			
		||||
                max_memory=kwargs["max_memory"],
 | 
			
		||||
                no_split_module_classes=["LlamaDecoderLayer"],
 | 
			
		||||
            )
 | 
			
		||||
            model = accelerate.dispatch_model(
 | 
			
		||||
                model, device_map=device_map, offload_buffers=True
 | 
			
		||||
            )
 | 
			
		||||
        else:
 | 
			
		||||
            model.to(device)
 | 
			
		||||
        return model, tokenizer
 | 
			
		||||
    kwargs["revision"] = revision
 | 
			
		||||
 | 
			
		||||
    # Load model
 | 
			
		||||
    model, tokenizer = adapter.load_model(model_path, kwargs)
 | 
			
		||||
 | 
			
		||||
    if (
 | 
			
		||||
        device == "cpu"
 | 
			
		||||
        and kwargs["torch_dtype"] is torch.bfloat16
 | 
			
		||||
        and CPU_ISA is not None
 | 
			
		||||
    ):
 | 
			
		||||
        model = ipex.optimize(model, dtype=kwargs["torch_dtype"])
 | 
			
		||||
 | 
			
		||||
    if (device == "cuda" and num_gpus == 1 and not cpu_offloading) or device in (
 | 
			
		||||
        "mps",
 | 
			
		||||
        "xpu",
 | 
			
		||||
    ):
 | 
			
		||||
        model.to(device)
 | 
			
		||||
 | 
			
		||||
    if debug:
 | 
			
		||||
        print(model)
 | 
			
		||||
 | 
			
		||||
    return model, tokenizer
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class BigDLLLMAdapter(BaseModelAdapter):
 | 
			
		||||
    "Model adapter for bigdl-llm backend models"
 | 
			
		||||
 | 
			
		||||
    def match(self, model_path: str):
 | 
			
		||||
        return "bigdl" in model_path
 | 
			
		||||
 | 
			
		||||
    def load_model(self, model_path: str, from_pretrained_kwargs: dict):
 | 
			
		||||
        revision = from_pretrained_kwargs.get("revision", "main")
 | 
			
		||||
        tokenizer = AutoTokenizer.from_pretrained(
 | 
			
		||||
            model_path, use_fast=False, revision=revision
 | 
			
		||||
        )
 | 
			
		||||
        print("Customized bigdl-llm loader")
 | 
			
		||||
        from bigdl.llm.transformers import AutoModelForCausalLM
 | 
			
		||||
        model = AutoModelForCausalLM.from_pretrained(
 | 
			
		||||
            model_path,
 | 
			
		||||
            load_in_4bit=True,
 | 
			
		||||
            low_cpu_mem_usage=True,
 | 
			
		||||
            **from_pretrained_kwargs,
 | 
			
		||||
        )
 | 
			
		||||
        return model, tokenizer
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def patch_fastchat():
 | 
			
		||||
    global is_fastchat_patched
 | 
			
		||||
    if is_fastchat_patched:
 | 
			
		||||
        return
 | 
			
		||||
    register_model_adapter(BigDLLLMAdapter)
 | 
			
		||||
    mapping_fastchat = _get_patch_map()
 | 
			
		||||
 | 
			
		||||
    for mapping_iter in mapping_fastchat:
 | 
			
		||||
        if mapping_iter[3] is None:
 | 
			
		||||
            mapping_iter[3] = getattr(mapping_iter[0], mapping_iter[1], None)
 | 
			
		||||
        setattr(mapping_iter[0], mapping_iter[1], mapping_iter[2])
 | 
			
		||||
 | 
			
		||||
    is_fastchat_patched = True
 | 
			
		||||
							
								
								
									
										504
									
								
								python/llm/src/bigdl/llm/serving/model_worker.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										504
									
								
								python/llm/src/bigdl/llm/serving/model_worker.py
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1,504 @@
 | 
			
		|||
#
 | 
			
		||||
# Copyright 2016 The BigDL Authors.
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
# you may not use this file except in compliance with the License.
 | 
			
		||||
# You may obtain a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
#     http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
# Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
# distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
# See the License for the specific language governing permissions and
 | 
			
		||||
# limitations under the License.
 | 
			
		||||
#
 | 
			
		||||
 | 
			
		||||
"""
 | 
			
		||||
A model worker that executes the model.
 | 
			
		||||
Adapted from FastChat's model_worker.py
 | 
			
		||||
"""
 | 
			
		||||
import argparse
 | 
			
		||||
import asyncio
 | 
			
		||||
import dataclasses
 | 
			
		||||
import logging
 | 
			
		||||
import json
 | 
			
		||||
import os
 | 
			
		||||
import time
 | 
			
		||||
from typing import List, Optional
 | 
			
		||||
import threading
 | 
			
		||||
import uuid
 | 
			
		||||
from bigdl.llm.utils.common import invalidInputError
 | 
			
		||||
 | 
			
		||||
from fastapi import FastAPI, Request, BackgroundTasks
 | 
			
		||||
from fastapi.responses import StreamingResponse, JSONResponse
 | 
			
		||||
import requests
 | 
			
		||||
 | 
			
		||||
from .bigdl_llm_model import patch_fastchat
 | 
			
		||||
 | 
			
		||||
try:
 | 
			
		||||
    from transformers import (
 | 
			
		||||
        AutoTokenizer,
 | 
			
		||||
        AutoModelForCausalLM,
 | 
			
		||||
        LlamaTokenizer,
 | 
			
		||||
        AutoModel,
 | 
			
		||||
    )
 | 
			
		||||
except ImportError:
 | 
			
		||||
    from transformers import (
 | 
			
		||||
        AutoTokenizer,
 | 
			
		||||
        AutoModelForCausalLM,
 | 
			
		||||
        LLaMATokenizer,
 | 
			
		||||
        AutoModel,
 | 
			
		||||
    )
 | 
			
		||||
import torch
 | 
			
		||||
import torch.nn.functional as F
 | 
			
		||||
import uvicorn
 | 
			
		||||
 | 
			
		||||
from fastchat.constants import WORKER_HEART_BEAT_INTERVAL, ErrorCode, SERVER_ERROR_MSG
 | 
			
		||||
from fastchat.conversation import get_conv_template
 | 
			
		||||
from fastchat.model.model_adapter import (
 | 
			
		||||
    add_model_args,
 | 
			
		||||
    get_conversation_template,
 | 
			
		||||
    get_generate_stream_function,
 | 
			
		||||
)
 | 
			
		||||
from fastchat.modules.gptq import GptqConfig
 | 
			
		||||
from fastchat.modules.awq import AWQConfig
 | 
			
		||||
from fastchat.utils import build_logger, pretty_print_semaphore, get_context_length
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
worker_id = str(uuid.uuid4())[:8]
 | 
			
		||||
logger = build_logger("model_worker", f"model_worker_{worker_id}.log")
 | 
			
		||||
 | 
			
		||||
app = FastAPI()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def heart_beat_worker(obj):
 | 
			
		||||
    while True:
 | 
			
		||||
        time.sleep(WORKER_HEART_BEAT_INTERVAL)
 | 
			
		||||
        obj.send_heart_beat()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class BaseModelWorker:
 | 
			
		||||
    def __init__(
 | 
			
		||||
        self,
 | 
			
		||||
        controller_addr: str,
 | 
			
		||||
        worker_addr: str,
 | 
			
		||||
        worker_id: str,
 | 
			
		||||
        model_path: str,
 | 
			
		||||
        model_names: List[str],
 | 
			
		||||
        limit_worker_concurrency: int,
 | 
			
		||||
        conv_template: str = None,
 | 
			
		||||
    ):
 | 
			
		||||
        self.controller_addr = controller_addr
 | 
			
		||||
        self.worker_addr = worker_addr
 | 
			
		||||
        self.worker_id = worker_id
 | 
			
		||||
        if model_path.endswith("/"):
 | 
			
		||||
            model_path = model_path[:-1]
 | 
			
		||||
        self.model_names = model_names or [model_path.split("/")[-1]]
 | 
			
		||||
        self.limit_worker_concurrency = limit_worker_concurrency
 | 
			
		||||
        if conv_template:
 | 
			
		||||
            self.conv = get_conv_template(conv_template)
 | 
			
		||||
        else:
 | 
			
		||||
            self.conv = get_conversation_template(model_path)
 | 
			
		||||
        self.conv.sep_style = int(self.conv.sep_style)
 | 
			
		||||
        self.tokenizer = None
 | 
			
		||||
        self.context_len = None
 | 
			
		||||
        self.call_ct = 0
 | 
			
		||||
        self.semaphore = None
 | 
			
		||||
 | 
			
		||||
        self.heart_beat_thread = None
 | 
			
		||||
 | 
			
		||||
    def init_heart_beat(self):
 | 
			
		||||
        self.register_to_controller()
 | 
			
		||||
        self.heart_beat_thread = threading.Thread(
 | 
			
		||||
            target=heart_beat_worker, args=(self,)
 | 
			
		||||
        )
 | 
			
		||||
        self.heart_beat_thread.start()
 | 
			
		||||
 | 
			
		||||
    def register_to_controller(self):
 | 
			
		||||
        logger.info("Register to controller")
 | 
			
		||||
 | 
			
		||||
        url = self.controller_addr + "/register_worker"
 | 
			
		||||
        data = {
 | 
			
		||||
            "worker_name": self.worker_addr,
 | 
			
		||||
            "check_heart_beat": True,
 | 
			
		||||
            "worker_status": self.get_status(),
 | 
			
		||||
        }
 | 
			
		||||
        r = requests.post(url, json=data)
 | 
			
		||||
        invalidInputError(r.status_code == 200, "Error register to Controller")
 | 
			
		||||
 | 
			
		||||
    def send_heart_beat(self):
 | 
			
		||||
        logger.info(
 | 
			
		||||
            f"Send heart beat. Models: {self.model_names}. "
 | 
			
		||||
            f"Semaphore: {pretty_print_semaphore(self.semaphore)}. "
 | 
			
		||||
            f"call_ct: {self.call_ct}. "
 | 
			
		||||
            f"worker_id: {self.worker_id}. "
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        url = self.controller_addr + "/receive_heart_beat"
 | 
			
		||||
 | 
			
		||||
        while True:
 | 
			
		||||
            try:
 | 
			
		||||
                ret = requests.post(
 | 
			
		||||
                    url,
 | 
			
		||||
                    json={
 | 
			
		||||
                        "worker_name": self.worker_addr,
 | 
			
		||||
                        "queue_length": self.get_queue_length(),
 | 
			
		||||
                    },
 | 
			
		||||
                    timeout=5,
 | 
			
		||||
                )
 | 
			
		||||
                exist = ret.json()["exist"]
 | 
			
		||||
                break
 | 
			
		||||
            except (requests.exceptions.RequestException, KeyError) as e:
 | 
			
		||||
                logger.error(f"heart beat error: {e}")
 | 
			
		||||
            time.sleep(5)
 | 
			
		||||
 | 
			
		||||
        if not exist:
 | 
			
		||||
            self.register_to_controller()
 | 
			
		||||
 | 
			
		||||
    def get_queue_length(self):
 | 
			
		||||
        if (
 | 
			
		||||
            self.semaphore is None
 | 
			
		||||
            or self.semaphore._value is None
 | 
			
		||||
            or self.semaphore._waiters is None
 | 
			
		||||
        ):
 | 
			
		||||
            return 0
 | 
			
		||||
        else:
 | 
			
		||||
            return (
 | 
			
		||||
                self.limit_worker_concurrency
 | 
			
		||||
                - self.semaphore._value
 | 
			
		||||
                + len(self.semaphore._waiters)
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
    def get_status(self):
 | 
			
		||||
        return {
 | 
			
		||||
            "model_names": self.model_names,
 | 
			
		||||
            "speed": 1,
 | 
			
		||||
            "queue_length": self.get_queue_length(),
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
    def count_token(self, params):
 | 
			
		||||
        prompt = params["prompt"]
 | 
			
		||||
        input_ids = self.tokenizer(prompt).input_ids
 | 
			
		||||
        input_echo_len = len(input_ids)
 | 
			
		||||
 | 
			
		||||
        ret = {
 | 
			
		||||
            "count": input_echo_len,
 | 
			
		||||
            "error_code": 0,
 | 
			
		||||
        }
 | 
			
		||||
        return ret
 | 
			
		||||
 | 
			
		||||
    def get_conv_template(self):
 | 
			
		||||
        return {"conv": self.conv}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ModelWorker(BaseModelWorker):
 | 
			
		||||
    def __init__(
 | 
			
		||||
        self,
 | 
			
		||||
        controller_addr: str,
 | 
			
		||||
        worker_addr: str,
 | 
			
		||||
        worker_id: str,
 | 
			
		||||
        model_path: str,
 | 
			
		||||
        model_names: List[str],
 | 
			
		||||
        limit_worker_concurrency: int,
 | 
			
		||||
        no_register: bool,
 | 
			
		||||
        device: str,
 | 
			
		||||
        num_gpus: int,
 | 
			
		||||
        max_gpu_memory: str,
 | 
			
		||||
        load_8bit: bool = False,
 | 
			
		||||
        cpu_offloading: bool = False,
 | 
			
		||||
        gptq_config: Optional[GptqConfig] = None,
 | 
			
		||||
        awq_config: Optional[AWQConfig] = None,
 | 
			
		||||
        stream_interval: int = 2,
 | 
			
		||||
        conv_template: str = None,
 | 
			
		||||
    ):
 | 
			
		||||
        super().__init__(
 | 
			
		||||
            controller_addr,
 | 
			
		||||
            worker_addr,
 | 
			
		||||
            worker_id,
 | 
			
		||||
            model_path,
 | 
			
		||||
            model_names,
 | 
			
		||||
            limit_worker_concurrency,
 | 
			
		||||
            conv_template=conv_template,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        logger.info(f"Loading the model {self.model_names} on worker {worker_id} ...")
 | 
			
		||||
        from fastchat.model.model_adapter import load_model
 | 
			
		||||
        self.model, self.tokenizer = load_model(
 | 
			
		||||
            model_path,
 | 
			
		||||
            device=device,
 | 
			
		||||
            num_gpus=num_gpus,
 | 
			
		||||
            max_gpu_memory=max_gpu_memory,
 | 
			
		||||
            load_8bit=load_8bit,
 | 
			
		||||
            cpu_offloading=cpu_offloading,
 | 
			
		||||
            gptq_config=gptq_config,
 | 
			
		||||
            awq_config=awq_config,
 | 
			
		||||
        )
 | 
			
		||||
        self.device = device
 | 
			
		||||
        if self.tokenizer.pad_token is None:
 | 
			
		||||
            self.tokenizer.pad_token = self.tokenizer.eos_token
 | 
			
		||||
        self.context_len = get_context_length(self.model.config)
 | 
			
		||||
        self.generate_stream_func = get_generate_stream_function(self.model, model_path)
 | 
			
		||||
        self.stream_interval = stream_interval
 | 
			
		||||
 | 
			
		||||
        if not no_register:
 | 
			
		||||
            self.init_heart_beat()
 | 
			
		||||
 | 
			
		||||
    def generate_stream_gate(self, params):
 | 
			
		||||
        self.call_ct += 1
 | 
			
		||||
 | 
			
		||||
        try:
 | 
			
		||||
            for output in self.generate_stream_func(
 | 
			
		||||
                self.model,
 | 
			
		||||
                self.tokenizer,
 | 
			
		||||
                params,
 | 
			
		||||
                self.device,
 | 
			
		||||
                self.context_len,
 | 
			
		||||
                self.stream_interval,
 | 
			
		||||
            ):
 | 
			
		||||
                ret = {
 | 
			
		||||
                    "text": output["text"],
 | 
			
		||||
                    "error_code": 0,
 | 
			
		||||
                }
 | 
			
		||||
                if "usage" in output:
 | 
			
		||||
                    ret["usage"] = output["usage"]
 | 
			
		||||
                if "finish_reason" in output:
 | 
			
		||||
                    ret["finish_reason"] = output["finish_reason"]
 | 
			
		||||
                if "logprobs" in output:
 | 
			
		||||
                    ret["logprobs"] = output["logprobs"]
 | 
			
		||||
                yield json.dumps(ret).encode() + b"\0"
 | 
			
		||||
        except torch.cuda.OutOfMemoryError as e:
 | 
			
		||||
            ret = {
 | 
			
		||||
                "text": f"{SERVER_ERROR_MSG}\n\n({e})",
 | 
			
		||||
                "error_code": ErrorCode.CUDA_OUT_OF_MEMORY,
 | 
			
		||||
            }
 | 
			
		||||
            yield json.dumps(ret).encode() + b"\0"
 | 
			
		||||
        except (ValueError, RuntimeError) as e:
 | 
			
		||||
            ret = {
 | 
			
		||||
                "text": f"{SERVER_ERROR_MSG}\n\n({e})",
 | 
			
		||||
                "error_code": ErrorCode.INTERNAL_ERROR,
 | 
			
		||||
            }
 | 
			
		||||
            yield json.dumps(ret).encode() + b"\0"
 | 
			
		||||
 | 
			
		||||
    def generate_gate(self, params):
 | 
			
		||||
        for x in self.generate_stream_gate(params):
 | 
			
		||||
            pass
 | 
			
		||||
        return json.loads(x[:-1].decode())
 | 
			
		||||
 | 
			
		||||
    @torch.inference_mode()
 | 
			
		||||
    def get_embeddings(self, params):
 | 
			
		||||
        self.call_ct += 1
 | 
			
		||||
 | 
			
		||||
        try:
 | 
			
		||||
            tokenizer = self.tokenizer
 | 
			
		||||
            is_llama = "llama" in str(
 | 
			
		||||
                type(self.model)
 | 
			
		||||
            )  # llama supports batch inference
 | 
			
		||||
            is_chatglm = "chatglm" in str(type(self.model))
 | 
			
		||||
            is_t5 = "t5" in str(type(self.model))
 | 
			
		||||
            is_bert = "bert" in str(type(self.model))
 | 
			
		||||
 | 
			
		||||
            if is_llama:
 | 
			
		||||
                encoding = tokenizer.batch_encode_plus(
 | 
			
		||||
                    params["input"], padding=True, return_tensors="pt"
 | 
			
		||||
                )
 | 
			
		||||
                input_ids = encoding["input_ids"].to(self.device)
 | 
			
		||||
                attention_mask = encoding["attention_mask"].to(self.device)
 | 
			
		||||
                model_output = self.model(
 | 
			
		||||
                    input_ids, attention_mask, output_hidden_states=True
 | 
			
		||||
                )
 | 
			
		||||
                data = model_output.hidden_states[-1]
 | 
			
		||||
                mask = attention_mask.unsqueeze(-1).expand(data.size()).float()
 | 
			
		||||
                masked_embeddings = data * mask
 | 
			
		||||
                sum_embeddings = torch.sum(masked_embeddings, dim=1)
 | 
			
		||||
                seq_length = torch.sum(mask, dim=1)
 | 
			
		||||
                embedding = sum_embeddings / seq_length
 | 
			
		||||
                normalized_embeddings = F.normalize(embedding, p=2, dim=1)
 | 
			
		||||
                ret = {
 | 
			
		||||
                    "embedding": normalized_embeddings.tolist(),
 | 
			
		||||
                    "token_num": torch.sum(attention_mask).item(),
 | 
			
		||||
                }
 | 
			
		||||
            elif is_bert:
 | 
			
		||||
                embedding = []
 | 
			
		||||
                token_num = 0
 | 
			
		||||
                for text in params["input"]:
 | 
			
		||||
                    input_ids = tokenizer.encode(text, return_tensors="pt").to(
 | 
			
		||||
                        self.device
 | 
			
		||||
                    )
 | 
			
		||||
                    model_output = self.model(input_ids)
 | 
			
		||||
                    data = model_output[0][:, 0]
 | 
			
		||||
                    data = F.normalize(torch.mean(data, dim=0), p=2, dim=0)
 | 
			
		||||
                    embedding.append(data.tolist())
 | 
			
		||||
                    token_num += len(input_ids[0])
 | 
			
		||||
                ret = {
 | 
			
		||||
                    "embedding": embedding,
 | 
			
		||||
                    "token_num": token_num,
 | 
			
		||||
                }
 | 
			
		||||
            else:
 | 
			
		||||
                embedding = []
 | 
			
		||||
                token_num = 0
 | 
			
		||||
                for text in params["input"]:
 | 
			
		||||
                    input_ids = tokenizer.encode(text, return_tensors="pt").to(
 | 
			
		||||
                        self.device
 | 
			
		||||
                    )
 | 
			
		||||
                    if is_t5:
 | 
			
		||||
                        model_output = self.model(
 | 
			
		||||
                            input_ids, decoder_input_ids=input_ids
 | 
			
		||||
                        )
 | 
			
		||||
                    else:
 | 
			
		||||
                        model_output = self.model(input_ids, output_hidden_states=True)
 | 
			
		||||
                    if is_chatglm:
 | 
			
		||||
                        data = (model_output.hidden_states[-1].transpose(0, 1))[0]
 | 
			
		||||
                    elif is_t5:
 | 
			
		||||
                        data = model_output.encoder_last_hidden_state[0]
 | 
			
		||||
                    else:
 | 
			
		||||
                        data = model_output.hidden_states[-1][0]
 | 
			
		||||
                    data = F.normalize(torch.mean(data, dim=0), p=2, dim=0)
 | 
			
		||||
                    embedding.append(data.tolist())
 | 
			
		||||
                    token_num += len(input_ids[0])
 | 
			
		||||
                ret = {
 | 
			
		||||
                    "embedding": embedding,
 | 
			
		||||
                    "token_num": token_num,
 | 
			
		||||
                }
 | 
			
		||||
        except torch.cuda.OutOfMemoryError as e:
 | 
			
		||||
            ret = {
 | 
			
		||||
                "text": f"{SERVER_ERROR_MSG}\n\n({e})",
 | 
			
		||||
                "error_code": ErrorCode.CUDA_OUT_OF_MEMORY,
 | 
			
		||||
            }
 | 
			
		||||
        except (ValueError, RuntimeError) as e:
 | 
			
		||||
            ret = {
 | 
			
		||||
                "text": f"{SERVER_ERROR_MSG}\n\n({e})",
 | 
			
		||||
                "error_code": ErrorCode.INTERNAL_ERROR,
 | 
			
		||||
            }
 | 
			
		||||
        return ret
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def release_worker_semaphore():
 | 
			
		||||
    worker.semaphore.release()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def acquire_worker_semaphore():
 | 
			
		||||
    if worker.semaphore is None:
 | 
			
		||||
        worker.semaphore = asyncio.Semaphore(worker.limit_worker_concurrency)
 | 
			
		||||
    return worker.semaphore.acquire()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def create_background_tasks():
 | 
			
		||||
    background_tasks = BackgroundTasks()
 | 
			
		||||
    background_tasks.add_task(release_worker_semaphore)
 | 
			
		||||
    return background_tasks
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@app.post("/worker_generate_stream")
 | 
			
		||||
async def api_generate_stream(request: Request):
 | 
			
		||||
    params = await request.json()
 | 
			
		||||
    await acquire_worker_semaphore()
 | 
			
		||||
    generator = worker.generate_stream_gate(params)
 | 
			
		||||
    background_tasks = create_background_tasks()
 | 
			
		||||
    return StreamingResponse(generator, background=background_tasks)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@app.post("/worker_generate")
 | 
			
		||||
async def api_generate(request: Request):
 | 
			
		||||
    params = await request.json()
 | 
			
		||||
    await acquire_worker_semaphore()
 | 
			
		||||
    output = worker.generate_gate(params)
 | 
			
		||||
    release_worker_semaphore()
 | 
			
		||||
    return JSONResponse(output)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@app.post("/worker_get_embeddings")
 | 
			
		||||
async def api_get_embeddings(request: Request):
 | 
			
		||||
    params = await request.json()
 | 
			
		||||
    await acquire_worker_semaphore()
 | 
			
		||||
    embedding = worker.get_embeddings(params)
 | 
			
		||||
    release_worker_semaphore()
 | 
			
		||||
    return JSONResponse(content=embedding)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@app.post("/worker_get_status")
 | 
			
		||||
async def api_get_status(request: Request):
 | 
			
		||||
    return worker.get_status()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@app.post("/count_token")
 | 
			
		||||
async def api_count_token(request: Request):
 | 
			
		||||
    params = await request.json()
 | 
			
		||||
    return worker.count_token(params)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@app.post("/worker_get_conv_template")
 | 
			
		||||
async def api_get_conv(request: Request):
 | 
			
		||||
    return worker.get_conv_template()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@app.post("/model_details")
 | 
			
		||||
async def api_model_details(request: Request):
 | 
			
		||||
    return {"context_length": worker.context_len}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
    patch_fastchat()
 | 
			
		||||
    parser = argparse.ArgumentParser()
 | 
			
		||||
    parser.add_argument("--host", type=str, default="localhost")
 | 
			
		||||
    parser.add_argument("--port", type=int, default=21002)
 | 
			
		||||
    parser.add_argument("--worker-address", type=str, default="http://localhost:21002")
 | 
			
		||||
    parser.add_argument(
 | 
			
		||||
        "--controller-address", type=str, default="http://localhost:21001"
 | 
			
		||||
    )
 | 
			
		||||
    add_model_args(parser)
 | 
			
		||||
    parser.add_argument(
 | 
			
		||||
        "--model-names",
 | 
			
		||||
        type=lambda s: s.split(","),
 | 
			
		||||
        help="Optional display comma separated names",
 | 
			
		||||
    )
 | 
			
		||||
    parser.add_argument(
 | 
			
		||||
        "--conv-template", type=str, default=None, help="Conversation prompt template."
 | 
			
		||||
    )
 | 
			
		||||
    parser.add_argument(
 | 
			
		||||
        "--limit-worker-concurrency",
 | 
			
		||||
        type=int,
 | 
			
		||||
        default=5,
 | 
			
		||||
        help="Limit the model concurrency to prevent OOM.",
 | 
			
		||||
    )
 | 
			
		||||
    parser.add_argument("--stream-interval", type=int, default=2)
 | 
			
		||||
    parser.add_argument("--no-register", action="store_true")
 | 
			
		||||
    args = parser.parse_args()
 | 
			
		||||
    logger.info(f"args: {args}")
 | 
			
		||||
 | 
			
		||||
    if args.gpus:
 | 
			
		||||
        invalidInputError(len(args.gpus.split(",")) > args.num_gpus, f"Larger --num-gpus "
 | 
			
		||||
                          "({args.num_gpus}) than --gpus {args.gpus}!")
 | 
			
		||||
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
 | 
			
		||||
 | 
			
		||||
    gptq_config = GptqConfig(
 | 
			
		||||
        ckpt=args.gptq_ckpt or args.model_path,
 | 
			
		||||
        wbits=args.gptq_wbits,
 | 
			
		||||
        groupsize=args.gptq_groupsize,
 | 
			
		||||
        act_order=args.gptq_act_order,
 | 
			
		||||
    )
 | 
			
		||||
    awq_config = AWQConfig(
 | 
			
		||||
        ckpt=args.awq_ckpt or args.model_path,
 | 
			
		||||
        wbits=args.awq_wbits,
 | 
			
		||||
        groupsize=args.awq_groupsize,
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    worker = ModelWorker(
 | 
			
		||||
        args.controller_address,
 | 
			
		||||
        args.worker_address,
 | 
			
		||||
        worker_id,
 | 
			
		||||
        args.model_path,
 | 
			
		||||
        args.model_names,
 | 
			
		||||
        args.limit_worker_concurrency,
 | 
			
		||||
        no_register=args.no_register,
 | 
			
		||||
        device=args.device,
 | 
			
		||||
        num_gpus=args.num_gpus,
 | 
			
		||||
        max_gpu_memory=args.max_gpu_memory,
 | 
			
		||||
        load_8bit=args.load_8bit,
 | 
			
		||||
        cpu_offloading=args.cpu_offloading,
 | 
			
		||||
        gptq_config=gptq_config,
 | 
			
		||||
        awq_config=awq_config,
 | 
			
		||||
        stream_interval=args.stream_interval,
 | 
			
		||||
        conv_template=args.conv_template,
 | 
			
		||||
    )
 | 
			
		||||
    uvicorn.run(app, host=args.host, port=args.port, log_level="info")
 | 
			
		||||
		Loading…
	
		Reference in a new issue