LLM: Enable Speculative Decoding on FastChat (#10909)

* init

* enable streamer

* update

* update

* remove deprecated

* update

* update

* add gpu example
Wang, Jian4 2024-05-06 10:06:20 +08:00 committed by GitHub
parent 8379f02a74
commit 0e0bd309e2
5 changed files with 66 additions and 41 deletions

@@ -61,6 +61,23 @@ export SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=1
python3 -m ipex_llm.serving.fastchat.ipex_llm_worker --model-path REPO_ID_OR_YOUR_MODEL_PATH --low-bit "sym_int4" --trust-remote-code --device "xpu"
```
#### Self-speculative decoding example
You can use IPEX-LLM to run a `self-speculative decoding` example. Refer to [here](https://github.com/intel-analytics/ipex-llm/tree/c9fac8c26bf1e1e8f7376fa9a62b32951dd9e85d/python/llm/example/GPU/Speculative-Decoding) for more details on Intel MAX GPUs and [here](https://github.com/intel-analytics/ipex-llm/tree/c9fac8c26bf1e1e8f7376fa9a62b32951dd9e85d/python/llm/example/GPU/Speculative-Decoding) for more details on Intel CPUs.
```bash
# On CPU, the only supported low_bit format for speculative decoding is bf16.
source ipex-llm-init -t
python3 -m ipex_llm.serving.fastchat.ipex_llm_worker --model-path lmsys/vicuna-7b-v1.5 --low-bit "bf16" --trust-remote-code --device "cpu" --speculative
# On GPU, the only supported low_bit format for speculative decoding is fp16.
source /opt/intel/oneapi/setvars.sh
export ENABLE_SDP_FUSION=1
export SYCL_CACHE_PERSISTENT=1
export SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=1
python3 -m ipex_llm.serving.fastchat.ipex_llm_worker --model-path lmsys/vicuna-7b-v1.5 --low-bit "fp16" --trust-remote-code --device "xpu" --speculative
```
You can get output like this:
```bash

@@ -25,9 +25,11 @@ You may install **`ipex-llm`** with `FastChat` as follows:
```bash
pip install --pre --upgrade ipex-llm[serving]
pip install transformers==4.36.0
# Or
pip install --pre --upgrade ipex-llm[all]
```
To add GPU support for FastChat, you may install **`ipex-llm`** as follows:
@@ -51,39 +53,6 @@ python3 -m fastchat.serve.controller
Using IPEX-LLM in FastChat does not impose any new limitations on model usage. Therefore, all Hugging Face Transformer models can be utilized in FastChat.
#### IPEX-LLM model worker (deprecated)
> Warning: This method has been deprecated. Please switch to the `IPEX-LLM` [worker](#ipex-llm-worker) instead.
FastChat determines the model adapter to use through path matching. Therefore, in order to load models using IPEX-LLM, you need to make some modifications to the model's name.
For instance, assuming you have downloaded `llama-7b-hf` from [HuggingFace](https://huggingface.co/decapoda-research/llama-7b-hf), to use `IPEX-LLM` as the backend you need to rename it from `llama-7b-hf` to `ipex-llm-7b`. The key point here is that the model's path should include "ipex" and **should not include paths matched by other model adapters**.
We will then use `ipex-llm-7b` as the model path.
> Note: this is caused by the priority of the name-matching list. The newly added `IPEX-LLM` adapter is at the tail of the list, so it has the lowest priority. If the model path contains another keyword such as `vicuna` that matches an adapter with higher priority, the `IPEX-LLM` adapter will not be used.
A special case is `ChatGLM` models: for these models, you do not need to make any changes after downloading, and the `IPEX-LLM` backend will be used automatically.
Then we can run the model workers:
```bash
# On CPU
python3 -m ipex_llm.serving.fastchat.model_worker --model-path PATH/TO/ipex-llm-7b --device cpu
# On GPU
python3 -m ipex_llm.serving.fastchat.model_worker --model-path PATH/TO/ipex-llm-7b --device xpu
```
If the model worker runs successfully with the `ipex_llm` backend, you will see output in the log like this:
```bash
INFO - Converting the current model to sym_int4 format......
```
> Note: we currently only support int4 quantization for this method.
</details>
#### IPEX-LLM worker
To integrate IPEX-LLM with `FastChat` efficiently, we have provided a new model_worker implementation named `ipex_llm_worker.py`.
@@ -104,6 +73,23 @@ For GPU example:
python3 -m ipex_llm.serving.fastchat.ipex_llm_worker --model-path lmsys/vicuna-7b-v1.5 --low-bit "sym_int4" --trust-remote-code --device "xpu"
```
#### Self-speculative decoding example
You can use IPEX-LLM to run a `self-speculative decoding` example. Refer to [here](https://github.com/intel-analytics/ipex-llm/tree/c9fac8c26bf1e1e8f7376fa9a62b32951dd9e85d/python/llm/example/GPU/Speculative-Decoding) for more details on Intel MAX GPUs and [here](https://github.com/intel-analytics/ipex-llm/tree/c9fac8c26bf1e1e8f7376fa9a62b32951dd9e85d/python/llm/example/GPU/Speculative-Decoding) for more details on Intel CPUs.
```bash
# On CPU, the only supported low_bit format for speculative decoding is bf16.
source ipex-llm-init -t
python3 -m ipex_llm.serving.fastchat.ipex_llm_worker --model-path lmsys/vicuna-7b-v1.5 --low-bit "bf16" --trust-remote-code --device "cpu" --speculative
# On GPU, the only supported low_bit format for speculative decoding is fp16.
source /opt/intel/oneapi/setvars.sh
export ENABLE_SDP_FUSION=1
export SYCL_CACHE_PERSISTENT=1
export SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=1
python3 -m ipex_llm.serving.fastchat.ipex_llm_worker --model-path lmsys/vicuna-7b-v1.5 --low-bit "fp16" --trust-remote-code --device "xpu" --speculative
```
For a full list of accepted arguments, you can refer to the main method of `ipex_llm_worker.py`.
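After the worker has registered with the controller, you can send requests through FastChat's OpenAI-compatible API server. The snippet below is a minimal client sketch: it assumes you have also started `fastchat.serve.openai_api_server` on its default port 8000, and that the served model name follows FastChat's usual convention of using the last component of `--model-path` (here `vicuna-7b-v1.5`); adjust both for your deployment.
```python
import requests

# Ask the speculative-decoding worker a question via the OpenAI-compatible endpoint.
resp = requests.post(
    "http://localhost:8000/v1/chat/completions",
    json={
        "model": "vicuna-7b-v1.5",  # assumed served model name
        "messages": [{"role": "user", "content": "What is speculative decoding?"}],
        "temperature": 0.7,
    },
    timeout=600,
)
print(resp.json()["choices"][0]["message"]["content"])
```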
#### IPEX-LLM vLLM worker

@@ -63,6 +63,7 @@ class BigDLLLMWorker(BaseModelWorker):
device: str = "cpu",
no_register: bool = False,
trust_remote_code: bool = False,
speculative: bool = False,
stream_interval: int = 4,
):
super().__init__(
@@ -82,11 +83,13 @@ class BigDLLLMWorker(BaseModelWorker):
)
logger.info(f"Using low bit format: {self.load_in_low_bit}, device: {device}")
if speculative:
logger.info("Using Self-Speculative decoding to generate")
self.device = device
self.speculative = speculative
self.model, self.tokenizer = load_model(
model_path, device, self.load_in_low_bit, trust_remote_code
model_path, device, self.load_in_low_bit, trust_remote_code, speculative
)
self.stream_interval = stream_interval
self.context_len = get_context_length(self.model.config)
@@ -98,6 +101,7 @@ class BigDLLLMWorker(BaseModelWorker):
# context length is self.context_length
prompt = params["prompt"]
temperature = float(params.get("temperature", 1.0))
do_sample = bool(params.get("do_sample", False))
repetition_penalty = float(params.get("repetition_penalty", 1.0))
top_p = float(params.get("top_p", 1.0))
top_k = int(params.get("top_k", 1))
@@ -165,6 +169,7 @@ class BigDLLLMWorker(BaseModelWorker):
streamer=streamer,
temperature=temperature,
repetition_penalty=repetition_penalty,
do_sample=do_sample,
top_p=top_p,
top_k=top_k,
)
@@ -314,6 +319,12 @@ if __name__ == "__main__":
parser.add_argument(
"--device", type=str, default="cpu", help="Device for executing model, cpu/xpu"
)
parser.add_argument(
"--speculative",
action="store_true",
default=False,
help="Whether to use self-speculative decoding",
)
parser.add_argument(
"--trust-remote-code",
action="store_true",
@@ -335,5 +346,6 @@ if __name__ == "__main__":
args.device,
args.no_register,
args.trust_remote_code,
args.speculative,
)
uvicorn.run(app, host=args.host, port=args.port, log_level="info")
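As an illustration of how the sampling parameters read in `generate_stream` above (`prompt`, `temperature`, `do_sample`, `top_p`, `top_k`, `repetition_penalty`) arrive from a client, here is a hedged sketch that posts directly to the worker's streaming endpoint. The port (21002), the `/worker_generate_stream` route, and the null-delimited JSON chunk format follow FastChat's usual worker conventions and may differ in your setup; `model` and `max_new_tokens` are the standard FastChat request fields.
```python
import json
import requests

params = {
    "model": "vicuna-7b-v1.5",
    "prompt": "USER: What is speculative decoding?\nASSISTANT:",
    "temperature": 0.7,
    "do_sample": True,
    "top_p": 0.9,
    "top_k": 40,
    "repetition_penalty": 1.0,
    "max_new_tokens": 64,
}
resp = requests.post("http://localhost:21002/worker_generate_stream",
                     json=params, stream=True, timeout=600)
text = ""
for chunk in resp.iter_lines(decode_unicode=False, delimiter=b"\0"):
    if chunk:
        # Each chunk is a JSON object whose "text" field holds the output so far.
        text = json.loads(chunk).get("text", "")
print(text)
```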

@@ -45,6 +45,7 @@ def load_model(
device: str = "cpu",
low_bit: str = 'sym_int4',
trust_remote_code: bool = True,
speculative: bool = False,
):
"""Load a model using BigDL LLM backend."""
@@ -64,6 +65,11 @@ def load_model(
else:
model_kwargs.update({"load_in_low_bit": low_bit, "torch_dtype": 'auto'})
if speculative:
invalidInputError(low_bit == "fp16" or low_bit == "bf16",
"Self-Speculative only supports low_bit fp16 or bf16")
model_kwargs["speculative"] = True
# Load tokenizer
tokenizer = tokenizer_cls.from_pretrained(model_path, trust_remote_code=True)
model = model_cls.from_pretrained(model_path, **model_kwargs)
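For illustration, a minimal sketch of calling the `load_model` helper shown above with speculative decoding enabled; the import path is an assumption (adjust it to wherever this helper lives in your installation). Note that `low_bit` must be `"fp16"` (GPU) or `"bf16"` (CPU), otherwise the `invalidInputError` check above is raised.
```python
# NOTE: the import path below is an assumption for illustration only.
from ipex_llm.serving.fastchat.model_worker import load_model

model, tokenizer = load_model(
    "lmsys/vicuna-7b-v1.5",   # model used in the README examples
    device="xpu",
    low_bit="fp16",           # speculative mode only accepts fp16 (GPU) or bf16 (CPU)
    trust_remote_code=True,
    speculative=True,         # passed through to from_pretrained as model_kwargs["speculative"]
)
```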

@@ -97,6 +97,7 @@ def generate(
new_speculative_kwargs[var] = value
return self.speculative_generate(inputs=inputs,
draft_model=self.draft_model,
streamer=streamer,
**new_speculative_kwargs)
else:
# When `draft_model` is false, these attributes
@@ -512,7 +513,7 @@ def _crop_past_key_values(self, past_key_values, new_cache_size, _enable_ipex=Fa
return past_key_values
def _prepare_generate_args(self, inputs, generation_config, **sampling_kwargs):
def _prepare_generate_args(self, inputs, generation_config, streamer=None, **sampling_kwargs):
if generation_config is None:
generation_config = self.generation_config
@@ -591,8 +592,8 @@ def _prepare_generate_args(self, inputs, generation_config, **sampling_kwargs):
# 5. Prepare `input_ids` which will be used for auto-regressive generation
input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs.pop("input_ids")
# if streamer is not None:
# streamer.put(input_ids.cpu())
if streamer is not None:
streamer.put(input_ids.cpu())
input_ids_length = input_ids.shape[-1]
@@ -658,6 +659,7 @@ def speculative_generate(self,
min_step_draft=3,
generation_config: Optional[GenerationConfig] = None,
attention_mask=None,
streamer: Optional["BaseStreamer"] = None,
**sampling_kwargs):
invalidInputError(draft_model is not None,
"Draft model should be provided.")
@@ -666,7 +668,7 @@ def speculative_generate(self,
min_step_draft = min_step_draft if min_step_draft >= 1 else 1
input_ids, generation_config, logits_processor, stopping_criteria, \
model_kwargs = _prepare_generate_args(self, inputs, generation_config,
model_kwargs = _prepare_generate_args(self, inputs, generation_config, streamer,
**sampling_kwargs)
step = 0
@@ -1061,7 +1063,8 @@ def speculative_generate(self,
generate_ids[:, step:step+output_ids.size(1)] = output_ids
current_input_ids = output_ids[:, -1:]
if streamer is not None:
streamer.put(output_ids.cpu())
step += output_ids.size(1)
# remove one generated by the base model
@@ -1094,7 +1097,8 @@ def speculative_generate(self,
idx = output_ids_list.index(generation_config.eos_token_id)
step -= (len(output_ids_list) - idx - 1)
break
if streamer is not None:
streamer.end()
step = min(step, max_new_tokens)
e2e_toc = time.time()
self.n_token_generated = step
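To show what the new streamer plumbing enables, here is a hedged sketch of consuming tokens incrementally from a speculative model using `transformers`' `TextIteratorStreamer`, whose `put`/`end` methods are exactly what `speculative_generate` now drives. It assumes `model` and `tokenizer` were loaded with speculative decoding enabled (e.g. via `load_model(..., speculative=True)`); everything else is standard `transformers` API.
```python
from threading import Thread
from transformers import TextIteratorStreamer

# Assumes `model`/`tokenizer` are an IPEX-LLM model loaded with speculative decoding enabled.
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
inputs = tokenizer("What is speculative decoding?", return_tensors="pt").to(model.device)

# generate() forwards `streamer` into speculative_generate (see the diff above), which
# calls streamer.put(...) for each verified block and streamer.end() when generation stops.
thread = Thread(
    target=model.generate,
    kwargs=dict(**inputs, streamer=streamer, max_new_tokens=64),
)
thread.start()
for new_text in streamer:  # yields decoded text as tokens are accepted
    print(new_text, end="", flush=True)
thread.join()
```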