Adding load_low_bit interface for ipex_llm_worker (#11000)

* initial implementation, needs tests

* fix

* fix baichuan issue

* fix typo
Guancheng Fu 2024-05-13 15:30:19 +08:00 committed by GitHub
parent 1b3c7a6928
commit 74997a3ed1
4 changed files with 81 additions and 38 deletions


@@ -46,7 +46,7 @@ pip install --pre --upgrade ipex-llm[xpu,serving] --extra-index-url https://pyto
You need to first run the FastChat controller:
```bash
python3 -m fastchat.serve.controller
python -m fastchat.serve.controller
```
### Launch model worker(s) and load models
@@ -63,14 +63,22 @@ To run the `ipex_llm_worker` on CPU, use the following command:
source ipex-llm-init -t
# Available low_bit formats include sym_int4, sym_int8, bf16, etc.
python3 -m ipex_llm.serving.fastchat.ipex_llm_worker --model-path lmsys/vicuna-7b-v1.5 --low-bit "sym_int4" --trust-remote-code --device "cpu"
python -m ipex_llm.serving.fastchat.ipex_llm_worker --model-path lmsys/vicuna-7b-v1.5 --low-bit "sym_int4" --trust-remote-code --device "cpu"
```
For a GPU example:
```bash
# Available low_bit formats include sym_int4, sym_int8, fp16, etc.
python3 -m ipex_llm.serving.fastchat.ipex_llm_worker --model-path lmsys/vicuna-7b-v1.5 --low-bit "sym_int4" --trust-remote-code --device "xpu"
python -m ipex_llm.serving.fastchat.ipex_llm_worker --model-path lmsys/vicuna-7b-v1.5 --low-bit "sym_int4" --trust-remote-code --device "xpu"
```
We also provide an option `--load-low-bit-model` to load models that have been converted and saved to disk using the `save_low_bit` interface, as introduced in this [document](https://github.com/intel-analytics/ipex-llm/blob/main/python/llm/example/CPU/HF-Transformers-AutoModels/Save-Load/README.md).
Check the following example:
```bash
# Or --device "cpu"
python -m ipex_llm.serving.fastchat.ipex_llm_worker --model-path /Low/Bit/Model/Path --trust-remote-code --device "xpu" --load-low-bit-model
```
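The low-bit checkpoint used above is produced beforehand with the `save_low_bit` interface described in the linked document. The following is a minimal sketch of that conversion step, assuming ipex-llm's transformers-style `AutoModelForCausalLM`; the model id and output path are placeholders:
```python
# Convert a Hugging Face checkpoint to low-bit format and save it to disk
# (sketch; "/Low/Bit/Model/Path" is a placeholder directory).
from transformers import AutoTokenizer
from ipex_llm.transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "lmsys/vicuna-7b-v1.5",
    load_in_low_bit="sym_int4",   # same low-bit format the worker will serve
    trust_remote_code=True,
)
model.save_low_bit("/Low/Bit/Model/Path")

tokenizer = AutoTokenizer.from_pretrained("lmsys/vicuna-7b-v1.5", trust_remote_code=True)
tokenizer.save_pretrained("/Low/Bit/Model/Path")
```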
#### Self-speculative decoding example
@@ -80,14 +88,14 @@ You can use IPEX-LLM to run the `self-speculative decoding` example. Refer to [here]
```bash
# On CPU, the only available low_bit format is bf16.
source ipex-llm-init -t
python3 -m ipex_llm.serving.fastchat.ipex_llm_worker --model-path lmsys/vicuna-7b-v1.5 --low-bit "bf16" --trust-remote-code --device "cpu" --speculative
python -m ipex_llm.serving.fastchat.ipex_llm_worker --model-path lmsys/vicuna-7b-v1.5 --low-bit "bf16" --trust-remote-code --device "cpu" --speculative
# On GPU, the only available low_bit format is fp16.
source /opt/intel/oneapi/setvars.sh
export ENABLE_SDP_FUSION=1
export SYCL_CACHE_PERSISTENT=1
export SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=1
python3 -m ipex_llm.serving.fastchat.ipex_llm_worker --model-path lmsys/vicuna-7b-v1.5 --low-bit "fp16" --trust-remote-code --device "xpu" --speculative
python -m ipex_llm.serving.fastchat.ipex_llm_worker --model-path lmsys/vicuna-7b-v1.5 --low-bit "fp16" --trust-remote-code --device "xpu" --speculative
```
For a full list of accepted arguments, refer to the main method of `ipex_llm_worker.py`.
@@ -100,16 +108,16 @@ To run using the `vLLM_worker`, we don't need to change model name, just simply
```bash
# On CPU
python3 -m ipex_llm.serving.fastchat.vllm_worker --model-path REPO_ID_OR_YOUR_MODEL_PATH --device cpu
python -m ipex_llm.serving.fastchat.vllm_worker --model-path REPO_ID_OR_YOUR_MODEL_PATH --device cpu
# On GPU
python3 -m ipex_llm.serving.fastchat.vllm_worker --model-path REPO_ID_OR_YOUR_MODEL_PATH --device xpu
python -m ipex_llm.serving.fastchat.vllm_worker --model-path REPO_ID_OR_YOUR_MODEL_PATH --device xpu
```
### Launch Gradio web server
```bash
python3 -m fastchat.serve.gradio_web_server
python -m fastchat.serve.gradio_web_server
```
This is the user interface that users will interact with.
@@ -121,5 +129,5 @@ By following these steps, you will be able to serve your models using the web UI
To start an OpenAI API server that provides compatible APIs using the IPEX-LLM backend, launch the `openai_api_server` and follow this [doc](https://github.com/lm-sys/FastChat/blob/main/docs/openai_api.md) to use it.
```bash
python3 -m fastchat.serve.openai_api_server --host localhost --port 8000
python -m fastchat.serve.openai_api_server --host localhost --port 8000
```
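Once the server is up, any OpenAI-compatible client should be able to query it. Below is a minimal sketch using the `openai` Python package (v1 or later); it assumes a worker has registered the model name `vicuna-7b-v1.5`, so adjust the model name, host, and port to your setup:
```python
# Query the OpenAI-compatible endpoint started above (sketch).
from openai import OpenAI

# An API key is typically not enforced by default; "EMPTY" is a common placeholder.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
response = client.chat.completions.create(
    model="vicuna-7b-v1.5",
    messages=[{"role": "user", "content": "What is AI?"}],
)
print(response.choices[0].message.content)
```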


@@ -69,6 +69,7 @@ class BigDLLLMWorker(BaseModelWorker):
        trust_remote_code: bool = False,
        embed_in_truncate: bool = False,
        speculative: bool = False,
        load_low_bit_model: bool = False,
        stream_interval: int = 4,
    ):
        super().__init__(
@@ -82,6 +83,7 @@ class BigDLLLMWorker(BaseModelWorker):
        )
        self.load_in_low_bit = load_in_low_bit
        self.load_low_bit_model = load_low_bit_model
        logger.info(
            f"Loading the model {self.model_names} on worker {worker_id},"
            f" worker type: BigDLLLM worker..."
@@ -94,7 +96,12 @@ class BigDLLLMWorker(BaseModelWorker):
        self.device = device
        self.speculative = speculative
        self.model, self.tokenizer = load_model(
            model_path, device, self.load_in_low_bit, trust_remote_code, speculative
            model_path,
            device,
            self.load_in_low_bit,
            trust_remote_code,
            speculative,
            load_low_bit_model,
        )
        self.stream_interval = stream_interval
        self.context_len = get_context_length(self.model.config)
@@ -495,6 +502,12 @@ if __name__ == "__main__":
        help="Trust remote code (e.g., from HuggingFace) when "
        "downloading the model and tokenizer.",
    )
    parser.add_argument(
        "--load-low-bit-model",
        action="store_true",
        default=False,
        help="Load models that have been converted/saved using ipex-llm's save_low_bit interface.",
    )
    parser.add_argument("--embed-in-truncate", action="store_true")
    args = parser.parse_args()
@@ -512,5 +525,7 @@ if __name__ == "__main__":
        args.trust_remote_code,
        args.embed_in_truncate,
        args.speculative,
        args.load_low_bit_model,
        args.stream_interval,
    )
    uvicorn.run(app, host=args.host, port=args.port, log_level="info")


@@ -3,13 +3,14 @@ repo_id:
  # - 'THUDM/chatglm-6b'
  # - 'THUDM/chatglm2-6b'
  - 'meta-llama/Llama-2-7b-chat-hf'
  - 'baichuan-inc/Baichuan2-7B-Chat'
  - 'Qwen/Qwen-7B-Chat'
  # - 'baichuan-inc/Baichuan2-7B-Chat'
  # - 'Qwen/Qwen-7B-Chat'
  # - 'liuhaotian/llava-v1.5-7b' # requires a LLAVA_REPO_DIR env variable pointing to the llava dir; added only for gpu win related test_api now
local_model_hub: 'path to your local model hub'
local_model_hub: '/mnt/disk1/models'
low_bit:
  - 'sym_int4' # default to use 'sym_int4' (i.e. symmetric int4)
  - 'bf16'
device:
  - 'cpu'
  # - 'xpu'
  # - 'cpu'
  - 'xpu'
load_low_bit_model: False


@@ -46,6 +46,7 @@ def load_model(
    low_bit: str = 'sym_int4',
    trust_remote_code: bool = True,
    speculative: bool = False,
    load_low_bit_model: bool = False,
):
    """Load a model using BigDL LLM backend."""
@@ -53,26 +54,38 @@ def load_model(
    invalidInputError(device == 'cpu' or device == 'xpu',
                      "BigDL-LLM only supports device cpu or xpu")
    tokenizer_cls = get_tokenizer_cls(model_path)
    model_cls = get_model_cls(model_path, low_bit)
    model_kwargs = {"use_cache": True}
    if trust_remote_code:
        model_kwargs["trust_remote_code"] = True
    if low_bit == "bf16":
        model_kwargs.update({"load_in_low_bit": low_bit, "torch_dtype": torch.bfloat16})
    elif low_bit == "fp16":
        model_kwargs.update({"load_in_low_bit": low_bit, "torch_dtype": torch.float16})
    else:
        model_kwargs.update({"load_in_low_bit": low_bit, "torch_dtype": 'auto'})
    # Load tokenizer
    tokenizer_cls = get_tokenizer_cls(model_path)
    model_kwargs = {"use_cache": True}
    if speculative:
        invalidInputError(load_low_bit_model is not True,
                          "Self-Speculative currently does not support loading low-bit format models")
        invalidInputError(low_bit == "fp16" or low_bit == "bf16",
                          "Self-Speculative only supports low_bit fp16 or bf16")
        model_kwargs["speculative"] = True
    # Load tokenizer
    tokenizer = tokenizer_cls.from_pretrained(model_path, trust_remote_code=True)
    model = model_cls.from_pretrained(model_path, **model_kwargs)
    if trust_remote_code:
        model_kwargs["trust_remote_code"] = True
    if load_low_bit_model:
        # After save_low_bit, the from_pretrained interface does not accept trust_remote_code=True
        tokenizer = tokenizer_cls.from_pretrained(model_path)
        model = model_cls.load_low_bit(model_path, **model_kwargs)
    else:
        if trust_remote_code:
            tokenizer = tokenizer_cls.from_pretrained(model_path, trust_remote_code=True)
        else:
            tokenizer = tokenizer_cls.from_pretrained(model_path)
        if low_bit == "bf16":
            model_kwargs.update({"load_in_low_bit": low_bit, "torch_dtype": torch.bfloat16})
        elif low_bit == "fp16":
            model_kwargs.update({"load_in_low_bit": low_bit, "torch_dtype": torch.float16})
        else:
            model_kwargs.update({"load_in_low_bit": low_bit, "torch_dtype": 'auto'})
        model = model_cls.from_pretrained(model_path, **model_kwargs)
    if not get_enable_ipex():
        model = model.eval()
@@ -83,13 +96,14 @@ def load_model(
    return model, tokenizer
def try_run_test_generation(local_model_hub, model_path, device, low_bit):
def try_run_test_generation(local_model_hub, model_path, device, low_bit, load_low_bit_model):
    path = get_model_path(model_path, local_model_hub)
    try:
        run_test_generation(path, device, low_bit)
        run_test_generation(path, device, low_bit, load_low_bit_model)
    except Exception:
        print(f"Loading model failed for model {model_path} \
              with device:{device} and low_bit:{low_bit}")
              with device:{device} and low_bit:{low_bit} \
              and load_low_bit_model {load_low_bit_model}")
        return "False"
    return "True"
@@ -105,11 +119,11 @@ def get_model_path(repo_id, local_model_hub):
        return repo_id
def run_test_generation(model_path, device, low_bit):
    model, tokenizer = load_model(model_path, device, low_bit, True)
def run_test_generation(model_path, device, low_bit, load_low_bit_model):
    # Disable speculative by default
    model, tokenizer = load_model(model_path, device, low_bit, True, False, load_low_bit_model)
    with torch.inference_mode():
        prompt = "What is AI?"
        # TODO: if gpu, will need to move the tensor to xpu
        input_ids = tokenizer.encode(prompt, return_tensors="pt")
        if device == 'xpu':
            input_ids = input_ids.to('xpu')
@@ -133,7 +147,6 @@ def run_test_generation(model_path, device, low_bit):
# Note that this only tests loading models, not generation correctness
if __name__ == '__main__':
    import os
    # TODO: move config.yaml to a different folder
    current_dir = os.path.dirname(os.path.realpath(__file__))
    results = []
    from omegaconf import OmegaConf
@@ -144,9 +157,15 @@ if __name__ == '__main__':
    for model in conf.repo_id:
        for low_bit in conf.low_bit:
            for device in conf.device:
                result = try_run_test_generation(conf['local_model_hub'], model, device, low_bit)
                results.append([model, device, low_bit, result])
                result = try_run_test_generation(conf['local_model_hub'],
                                                 model,
                                                 device,
                                                 low_bit,
                                                 conf["load_low_bit_model"]
                                                 )
                results.append([model, device, low_bit, conf["load_low_bit_model"], result])
    df = pd.DataFrame(results, columns=['model', 'device', 'low_bit', 'result'])
    df = pd.DataFrame(results,
                      columns=['model', 'device', 'low_bit', 'use_low_bit_model', 'result'])
    df.to_csv(csv_name)
    results = []
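For reference, the two loading paths exercised by `load_model` above correspond roughly to the following standalone usage. This is a sketch based on ipex-llm's transformers-style `AutoModelForCausalLM`; the checkpoint paths are placeholders:
```python
# Sketch of the two load paths handled by load_model (paths are placeholders).
from transformers import AutoTokenizer
from ipex_llm.transformers import AutoModelForCausalLM

# Path 1 (--load-low-bit-model not set): quantize a regular checkpoint on the fly.
model = AutoModelForCausalLM.from_pretrained(
    "lmsys/vicuna-7b-v1.5",
    load_in_low_bit="sym_int4",
    torch_dtype="auto",
    use_cache=True,
    trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained("lmsys/vicuna-7b-v1.5", trust_remote_code=True)

# Path 2 (--load-low-bit-model set): load a checkpoint previously saved with save_low_bit.
# Note that trust_remote_code is not passed here, matching the worker's behavior.
model = AutoModelForCausalLM.load_low_bit("/Low/Bit/Model/Path", use_cache=True)
tokenizer = AutoTokenizer.from_pretrained("/Low/Bit/Model/Path")
```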