LLM: Enable Speculative on Fastchat (#10909)
* init
* enable streamer
* update
* update
* remove deprecated
* update
* update
* add gpu example
This commit is contained in:
parent
8379f02a74
commit
0e0bd309e2
5 changed files with 66 additions and 41 deletions
@@ -61,6 +61,23 @@ export SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=1
python3 -m ipex_llm.serving.fastchat.ipex_llm_worker --model-path REPO_ID_OR_YOUR_MODEL_PATH --low-bit "sym_int4" --trust-remote-code --device "xpu"
```

#### Self-speculative decoding example

You can use IPEX-LLM to run a `self-speculative decoding` example. Refer to [here](https://github.com/intel-analytics/ipex-llm/tree/c9fac8c26bf1e1e8f7376fa9a62b32951dd9e85d/python/llm/example/GPU/Speculative-Decoding) for more details on Intel MAX GPUs. Refer to [here](https://github.com/intel-analytics/ipex-llm/tree/c9fac8c26bf1e1e8f7376fa9a62b32951dd9e85d/python/llm/example/GPU/Speculative-Decoding) for more details on Intel CPUs.

```bash
# On CPU, the only supported low_bit format is bf16.
source ipex-llm-init -t
python3 -m ipex_llm.serving.fastchat.ipex_llm_worker --model-path lmsys/vicuna-7b-v1.5 --low-bit "bf16" --trust-remote-code --device "cpu" --speculative

# On GPU, the only supported low_bit format is fp16.
source /opt/intel/oneapi/setvars.sh
export ENABLE_SDP_FUSION=1
export SYCL_CACHE_PERSISTENT=1
export SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=1
python3 -m ipex_llm.serving.fastchat.ipex_llm_worker --model-path lmsys/vicuna-7b-v1.5 --low-bit "fp16" --trust-remote-code --device "xpu" --speculative
```

You can get output like this:

```bash
@@ -25,9 +25,11 @@ You may install **`ipex-llm`** with `FastChat` as follows:

```bash
pip install --pre --upgrade ipex-llm[serving]
pip install transformers==4.36.0

# Or
pip install --pre --upgrade ipex-llm[all]

```

To add GPU support for FastChat, you may install **`ipex-llm`** as follows:
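The exact GPU install command lives in a part of the README that this hunk does not show. As a rough, non-authoritative sketch, it typically looks like the following (the `xpu` extra and the extra index URL are assumptions to verify against the IPEX-LLM installation guide):

```bash
# Hedged sketch only; confirm against the official IPEX-LLM GPU installation docs.
pip install --pre --upgrade ipex-llm[xpu] --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
```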
@@ -51,39 +53,6 @@ python3 -m fastchat.serve.controller

Using IPEX-LLM in FastChat does not impose any new limitations on model usage. Therefore, all Hugging Face Transformer models can be utilized in FastChat.

#### IPEX-LLM model worker (deprecated)

> Warning: This method has been deprecated. Please use the `IPEX-LLM` [worker](#ipex-llm-worker) instead.

FastChat determines the model adapter to use through path matching. Therefore, in order to load models using IPEX-LLM, you need to make some modifications to the model's name.

For instance, assume you have downloaded `llama-7b-hf` from [HuggingFace](https://huggingface.co/decapoda-research/llama-7b-hf). To use `IPEX-LLM` as the backend, you need to rename `llama-7b-hf` to `ipex-llm-7b`. The key point here is that the model's path should include "ipex" and **should not include paths matched by other model adapters**.

Then we will use `ipex-llm-7b` as the model path.

> Note: This is caused by the priority of the name-matching list. The newly added `IPEX-LLM` adapter is at the tail of the name-matching list, so it has the lowest priority. If the model path contains other keywords like `vicuna` that match another adapter with higher priority, the `IPEX-LLM` adapter will not be used.

A special case is `ChatGLM` models. For these models, you do not need to make any changes after downloading the model, and the `IPEX-LLM` backend will be used automatically.

Then we can run the model workers:

```bash
# On CPU
python3 -m ipex_llm.serving.fastchat.model_worker --model-path PATH/TO/ipex-llm-7b --device cpu

# On GPU
python3 -m ipex_llm.serving.fastchat.model_worker --model-path PATH/TO/ipex-llm-7b --device xpu
```

If you run successfully using the `ipex_llm` backend, you can see output in the log like this:

```bash
INFO - Converting the current model to sym_int4 format......
```

> Note: We currently only support int4 quantization for this method.
</details>

#### IPEX-LLM worker

To integrate IPEX-LLM with `FastChat` efficiently, we have provided a new model worker implementation named `ipex_llm_worker.py`.
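For orientation, here is a minimal sketch of the usual three-process FastChat setup with this worker (the controller and worker commands come from this README; the OpenAI-compatible API server is FastChat's standard module, with default ports assumed):

```bash
# 1. Start the FastChat controller.
python3 -m fastchat.serve.controller

# 2. Start the IPEX-LLM worker (CPU, sym_int4 shown; see below for GPU and speculative variants).
python3 -m ipex_llm.serving.fastchat.ipex_llm_worker --model-path lmsys/vicuna-7b-v1.5 --low-bit "sym_int4" --trust-remote-code --device "cpu"

# 3. Start an OpenAI-compatible API server in front of the controller.
python3 -m fastchat.serve.openai_api_server --host localhost --port 8000
```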
@@ -104,6 +73,23 @@ For GPU example:
python3 -m ipex_llm.serving.fastchat.ipex_llm_worker --model-path lmsys/vicuna-7b-v1.5 --low-bit "sym_int4" --trust-remote-code --device "xpu"
```

#### Self-speculative decoding example

You can use IPEX-LLM to run a `self-speculative decoding` example. Refer to [here](https://github.com/intel-analytics/ipex-llm/tree/c9fac8c26bf1e1e8f7376fa9a62b32951dd9e85d/python/llm/example/GPU/Speculative-Decoding) for more details on Intel MAX GPUs. Refer to [here](https://github.com/intel-analytics/ipex-llm/tree/c9fac8c26bf1e1e8f7376fa9a62b32951dd9e85d/python/llm/example/GPU/Speculative-Decoding) for more details on Intel CPUs.

```bash
# On CPU, the only supported low_bit format is bf16.
source ipex-llm-init -t
python3 -m ipex_llm.serving.fastchat.ipex_llm_worker --model-path lmsys/vicuna-7b-v1.5 --low-bit "bf16" --trust-remote-code --device "cpu" --speculative

# On GPU, the only supported low_bit format is fp16.
source /opt/intel/oneapi/setvars.sh
export ENABLE_SDP_FUSION=1
export SYCL_CACHE_PERSISTENT=1
export SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=1
python3 -m ipex_llm.serving.fastchat.ipex_llm_worker --model-path lmsys/vicuna-7b-v1.5 --low-bit "fp16" --trust-remote-code --device "xpu" --speculative
```

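Once the worker has registered with the controller, you can exercise it through FastChat's OpenAI-compatible API server. A minimal sketch, assuming the API server from the setup above and that the worker registered the model under the name `vicuna-7b-v1.5`:

```bash
curl http://localhost:8000/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{"model": "vicuna-7b-v1.5", "messages": [{"role": "user", "content": "What is IPEX-LLM?"}]}'
```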
For a full list of accepted arguments, you can refer to the main method of `ipex_llm_worker.py`.
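For instance, a quick way to list them:

```bash
python3 -m ipex_llm.serving.fastchat.ipex_llm_worker --help
```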
#### IPEX-LLM vLLM worker
@@ -63,6 +63,7 @@ class BigDLLLMWorker(BaseModelWorker):
        device: str = "cpu",
        no_register: bool = False,
        trust_remote_code: bool = False,
+        speculative: bool = False,
        stream_interval: int = 4,
    ):
        super().__init__(
@@ -82,11 +83,13 @@ class BigDLLLMWorker(BaseModelWorker):
        )

        logger.info(f"Using low bit format: {self.load_in_low_bit}, device: {device}")
+        if speculative:
+            logger.info(f"Using Self-Speculative decoding to generate")

        self.device = device
+        self.speculative = speculative
        self.model, self.tokenizer = load_model(
-            model_path, device, self.load_in_low_bit, trust_remote_code
+            model_path, device, self.load_in_low_bit, trust_remote_code, speculative
        )
        self.stream_interval = stream_interval
        self.context_len = get_context_length(self.model.config)
@@ -98,6 +101,7 @@ class BigDLLLMWorker(BaseModelWorker):
        # context length is self.context_length
        prompt = params["prompt"]
        temperature = float(params.get("temperature", 1.0))
+        do_sample = bool(params.get("do_sample", False))
        repetition_penalty = float(params.get("repetition_penalty", 1.0))
        top_p = float(params.get("top_p", 1.0))
        top_k = int(params.get("top_k", 1))
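For context, an illustrative example of the generation parameters a FastChat request can carry into this method; the keys mirror what the hunk above reads, and the values are made up:

```python
# Illustrative only: keys match the params.get(...) calls above, values are arbitrary.
params = {
    "prompt": "What is speculative decoding?",
    "temperature": 0.7,
    "do_sample": True,
    "repetition_penalty": 1.0,
    "top_p": 0.9,
    "top_k": 40,
}
```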
@@ -165,6 +169,7 @@ class BigDLLLMWorker(BaseModelWorker):
            streamer=streamer,
            temperature=temperature,
            repetition_penalty=repetition_penalty,
+            do_sample=do_sample,
            top_p=top_p,
            top_k=top_k,
        )
@@ -314,6 +319,12 @@ if __name__ == "__main__":
    parser.add_argument(
        "--device", type=str, default="cpu", help="Device for executing model, cpu/xpu"
    )
+    parser.add_argument(
+        "--speculative",
+        action="store_true",
+        default=False,
+        help="To use self-speculative or not",
+    )
    parser.add_argument(
        "--trust-remote-code",
        action="store_true",
@@ -335,5 +346,6 @@ if __name__ == "__main__":
        args.device,
        args.no_register,
        args.trust_remote_code,
+        args.speculative,
    )
    uvicorn.run(app, host=args.host, port=args.port, log_level="info")
@@ -45,6 +45,7 @@ def load_model(
    device: str = "cpu",
    low_bit: str = 'sym_int4',
    trust_remote_code: bool = True,
+    speculative: bool = False,
):
    """Load a model using BigDL LLM backend."""
@@ -64,6 +65,11 @@ def load_model(
    else:
        model_kwargs.update({"load_in_low_bit": low_bit, "torch_dtype": 'auto'})

+    if speculative:
+        invalidInputError(low_bit == "fp16" or low_bit == "bf16",
+                          "Self-Speculative only supports low_bit fp16 or bf16")
+        model_kwargs["speculative"] = True

    # Load tokenizer
    tokenizer = tokenizer_cls.from_pretrained(model_path, trust_remote_code=True)
    model = model_cls.from_pretrained(model_path, **model_kwargs)
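For context, a minimal sketch of the load this helper effectively performs once `model_kwargs["speculative"] = True` is set, assuming the ipex_llm transformers-style `AutoModelForCausalLM` API (the actual model and tokenizer classes are resolved elsewhere in the loader):

```python
# Hedged sketch of the resulting from_pretrained call for the GPU case.
from transformers import AutoTokenizer
from ipex_llm.transformers import AutoModelForCausalLM

model_path = "lmsys/vicuna-7b-v1.5"
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    load_in_low_bit="fp16",   # self-speculative requires fp16 (GPU) or bf16 (CPU), as checked above
    speculative=True,         # enables the low-bit draft model used for self-speculative decoding
    torch_dtype="auto",
    trust_remote_code=True,
)
```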
@@ -97,6 +97,7 @@ def generate(
            new_speculative_kwargs[var] = value
        return self.speculative_generate(inputs=inputs,
                                         draft_model=self.draft_model,
+                                         streamer=streamer,
                                         **new_speculative_kwargs)
    else:
        # When `draft_model` is false, these attributes
@@ -512,7 +513,7 @@ def _crop_past_key_values(self, past_key_values, new_cache_size, _enable_ipex=Fa
    return past_key_values


-def _prepare_generate_args(self, inputs, generation_config, **sampling_kwargs):
+def _prepare_generate_args(self, inputs, generation_config, streamer=None, **sampling_kwargs):
    if generation_config is None:
        generation_config = self.generation_config
@@ -591,8 +592,8 @@ def _prepare_generate_args(self, inputs, generation_config, **sampling_kwargs):
    # 5. Prepare `input_ids` which will be used for auto-regressive generation
    input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs.pop("input_ids")

-    # if streamer is not None:
-    # streamer.put(input_ids.cpu())
+    if streamer is not None:
+        streamer.put(input_ids.cpu())

    input_ids_length = input_ids.shape[-1]
@@ -658,6 +659,7 @@ def speculative_generate(self,
                         min_step_draft=3,
                         generation_config: Optional[GenerationConfig] = None,
                         attention_mask=None,
+                         streamer: Optional["BaseStreamer"] = None,
                         **sampling_kwargs):
    invalidInputError(draft_model is not None,
                      "Draft model should be provided.")
@@ -666,7 +668,7 @@ def speculative_generate(self,
    min_step_draft = min_step_draft if min_step_draft >= 1 else 1

    input_ids, generation_config, logits_processor, stopping_criteria, \
-        model_kwargs = _prepare_generate_args(self, inputs, generation_config,
+        model_kwargs = _prepare_generate_args(self, inputs, generation_config, streamer,
                                               **sampling_kwargs)

    step = 0
@@ -1061,7 +1063,8 @@ def speculative_generate(self,

            generate_ids[:, step:step+output_ids.size(1)] = output_ids
            current_input_ids = output_ids[:, -1:]
+            if streamer is not None:
+                streamer.put(output_ids.cpu())
            step += output_ids.size(1)

            # remove one generated by the base model
@@ -1094,7 +1097,8 @@ def speculative_generate(self,
                idx = output_ids_list.index(generation_config.eos_token_id)
                step -= (len(output_ids_list) - idx - 1)
                break
+    if streamer is not None:
+        streamer.end()
    step = min(step, max_new_tokens)
    e2e_toc = time.time()
    self.n_token_generated = step
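With the streamer now threaded through `speculative_generate`, callers can stream partial results the same way as with regular `generate`. A minimal sketch using the standard transformers streaming pattern (here `model`, `tokenizer`, and `input_ids` are assumed to be prepared as in the worker code above):

```python
# Sketch: consume text as drafted-and-verified token chunks are pushed to the streamer.
from threading import Thread
from transformers import TextIteratorStreamer

streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
generation_kwargs = dict(
    inputs=input_ids,
    streamer=streamer,      # forwarded down into speculative_generate, as shown above
    max_new_tokens=128,
    do_sample=False,
)
Thread(target=model.generate, kwargs=generation_kwargs).start()

for new_text in streamer:   # partial text arrives as draft tokens are verified
    print(new_text, end="", flush=True)
```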