LLM: Enable Speculative on Fastchat (#10909)
* init * enable streamer * update * update * remove deprecated * update * update * add gpu example
This commit is contained in:
		
							parent
							
								
									8379f02a74
								
							
						
					
					
						commit
						0e0bd309e2
					
				
					 5 changed files with 66 additions and 41 deletions
				
			
		| 
						 | 
				
			
			@ -61,6 +61,23 @@ export SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=1
 | 
			
		|||
python3 -m ipex_llm.serving.fastchat.ipex_llm_worker --model-path REPO_ID_OR_YOUR_MODEL_PATH --low-bit "sym_int4" --trust-remote-code --device "xpu"
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
#### For self-speculative decoding example:
 | 
			
		||||
 | 
			
		||||
You can use IPEX-LLM to run `self-speculative decoding` example. Refer to [here](https://github.com/intel-analytics/ipex-llm/tree/c9fac8c26bf1e1e8f7376fa9a62b32951dd9e85d/python/llm/example/GPU/Speculative-Decoding) for more details on intel MAX GPUs. Refer to [here](https://github.com/intel-analytics/ipex-llm/tree/c9fac8c26bf1e1e8f7376fa9a62b32951dd9e85d/python/llm/example/GPU/Speculative-Decoding) for more details on intel CPUs.
 | 
			
		||||
 | 
			
		||||
```bash
 | 
			
		||||
# Available low_bit format only including bf16 on CPU.
 | 
			
		||||
source ipex-llm-init -t
 | 
			
		||||
python3 -m ipex_llm.serving.fastchat.ipex_llm_worker --model-path lmsys/vicuna-7b-v1.5 --low-bit "bf16" --trust-remote-code --device "cpu" --speculative
 | 
			
		||||
 | 
			
		||||
# Available low_bit format only including fp16 on GPU.
 | 
			
		||||
source /opt/intel/oneapi/setvars.sh
 | 
			
		||||
export ENABLE_SDP_FUSION=1
 | 
			
		||||
export SYCL_CACHE_PERSISTENT=1
 | 
			
		||||
export SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=1
 | 
			
		||||
python3 -m ipex_llm.serving.fastchat.ipex_llm_worker --model-path lmsys/vicuna-7b-v1.5 --low-bit "fp16" --trust-remote-code --device "xpu" --speculative
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
You can get output like this:
 | 
			
		||||
 | 
			
		||||
```bash
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -25,9 +25,11 @@ You may install **`ipex-llm`** with `FastChat` as follows:
 | 
			
		|||
 | 
			
		||||
```bash
 | 
			
		||||
pip install --pre --upgrade ipex-llm[serving]
 | 
			
		||||
pip install transformers==4.36.0
 | 
			
		||||
 | 
			
		||||
# Or
 | 
			
		||||
pip install --pre --upgrade ipex-llm[all]
 | 
			
		||||
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
To add GPU support for FastChat, you may install **`ipex-llm`** as follows:
 | 
			
		||||
| 
						 | 
				
			
			@ -51,39 +53,6 @@ python3 -m fastchat.serve.controller
 | 
			
		|||
 | 
			
		||||
Using IPEX-LLM in FastChat does not impose any new limitations on model usage. Therefore, all Hugging Face Transformer models can be utilized in FastChat.
 | 
			
		||||
 | 
			
		||||
#### IPEX-LLM model worker (deprecated)
 | 
			
		||||
 | 
			
		||||
> Warning: This method has been deprecated, please change to use `IPEX-LLM` [worker](#ipex-llm-worker) instead.
 | 
			
		||||
 | 
			
		||||
FastChat determines the Model adapter to use through path matching. Therefore, in order to load models using IPEX-LLM, you need to make some modifications to the model's name.
 | 
			
		||||
 | 
			
		||||
For instance, assuming you have downloaded the `llama-7b-hf` from [HuggingFace](https://huggingface.co/decapoda-research/llama-7b-hf).  Then, to use the `IPEX-LLM` as backend, you need to change name from `llama-7b-hf` to `ipex-llm-7b`.The key point here is that the model's path should include "ipex" and **should not include paths matched by other model adapters**.
 | 
			
		||||
 | 
			
		||||
Then we will use `ipex-llm-7b` as model-path.
 | 
			
		||||
 | 
			
		||||
> note: This is caused by the priority of name matching list. The new added `IPEX-LLM` adapter is at the tail of the name-matching list so that it has the lowest priority. If model path contains other keywords like `vicuna` which matches to another adapter with higher priority, then the `IPEX-LLM` adapter will not work.
 | 
			
		||||
 | 
			
		||||
A special case is `ChatGLM` models. For these models, you do not need to do any changes after downloading the model and the `IPEX-LLM` backend will be used automatically.
 | 
			
		||||
 | 
			
		||||
Then we can run model workers
 | 
			
		||||
 | 
			
		||||
```bash
 | 
			
		||||
# On CPU
 | 
			
		||||
python3 -m ipex_llm.serving.fastchat.model_worker --model-path PATH/TO/ipex-llm-7b --device cpu
 | 
			
		||||
 | 
			
		||||
# On GPU
 | 
			
		||||
python3 -m ipex_llm.serving.fastchat.model_worker --model-path PATH/TO/ipex-llm-7b --device xpu
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
If you run successfully using `ipex_llm` backend, you can see the output in log like this:
 | 
			
		||||
 | 
			
		||||
```bash
 | 
			
		||||
INFO - Converting the current model to sym_int4 format......
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
> note: We currently only support int4 quantization for this method.
 | 
			
		||||
</details>
 | 
			
		||||
 | 
			
		||||
#### IPEX-LLM worker
 | 
			
		||||
 | 
			
		||||
To integrate IPEX-LLM with `FastChat` efficiently, we have provided a new model_worker implementation named `ipex_llm_worker.py`.
 | 
			
		||||
| 
						 | 
				
			
			@ -104,6 +73,23 @@ For GPU example:
 | 
			
		|||
python3 -m ipex_llm.serving.fastchat.ipex_llm_worker --model-path lmsys/vicuna-7b-v1.5 --low-bit "sym_int4" --trust-remote-code --device "xpu"
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
#### For self-speculative decoding example:
 | 
			
		||||
 | 
			
		||||
You can use IPEX-LLM to run `self-speculative decoding` example. Refer to [here](https://github.com/intel-analytics/ipex-llm/tree/c9fac8c26bf1e1e8f7376fa9a62b32951dd9e85d/python/llm/example/GPU/Speculative-Decoding) for more details on intel MAX GPUs. Refer to [here](https://github.com/intel-analytics/ipex-llm/tree/c9fac8c26bf1e1e8f7376fa9a62b32951dd9e85d/python/llm/example/GPU/Speculative-Decoding) for more details on intel CPUs.
 | 
			
		||||
 | 
			
		||||
```bash
 | 
			
		||||
# Available low_bit format only including bf16 on CPU.
 | 
			
		||||
source ipex-llm-init -t
 | 
			
		||||
python3 -m ipex_llm.serving.fastchat.ipex_llm_worker --model-path lmsys/vicuna-7b-v1.5 --low-bit "bf16" --trust-remote-code --device "cpu" --speculative
 | 
			
		||||
 | 
			
		||||
# Available low_bit format only including fp16 on GPU.
 | 
			
		||||
source /opt/intel/oneapi/setvars.sh
 | 
			
		||||
export ENABLE_SDP_FUSION=1
 | 
			
		||||
export SYCL_CACHE_PERSISTENT=1
 | 
			
		||||
export SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=1
 | 
			
		||||
python3 -m ipex_llm.serving.fastchat.ipex_llm_worker --model-path lmsys/vicuna-7b-v1.5 --low-bit "fp16" --trust-remote-code --device "xpu" --speculative
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
For a full list of accepted arguments, you can refer to the main method of the `ipex_llm_worker.py`
 | 
			
		||||
 | 
			
		||||
#### IPEX-LLM vLLM worker
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -63,6 +63,7 @@ class BigDLLLMWorker(BaseModelWorker):
 | 
			
		|||
        device: str = "cpu",
 | 
			
		||||
        no_register: bool = False,
 | 
			
		||||
        trust_remote_code: bool = False,
 | 
			
		||||
        speculative: bool = False,
 | 
			
		||||
        stream_interval: int = 4,
 | 
			
		||||
    ):
 | 
			
		||||
        super().__init__(
 | 
			
		||||
| 
						 | 
				
			
			@ -82,11 +83,13 @@ class BigDLLLMWorker(BaseModelWorker):
 | 
			
		|||
        )
 | 
			
		||||
 | 
			
		||||
        logger.info(f"Using low bit format: {self.load_in_low_bit}, device: {device}")
 | 
			
		||||
        if speculative:
 | 
			
		||||
            logger.info(f"Using Self-Speculative decoding to generate")
 | 
			
		||||
 | 
			
		||||
        self.device = device
 | 
			
		||||
 | 
			
		||||
        self.speculative = speculative
 | 
			
		||||
        self.model, self.tokenizer = load_model(
 | 
			
		||||
            model_path, device, self.load_in_low_bit, trust_remote_code
 | 
			
		||||
            model_path, device, self.load_in_low_bit, trust_remote_code, speculative
 | 
			
		||||
        )
 | 
			
		||||
        self.stream_interval = stream_interval
 | 
			
		||||
        self.context_len = get_context_length(self.model.config)
 | 
			
		||||
| 
						 | 
				
			
			@ -98,6 +101,7 @@ class BigDLLLMWorker(BaseModelWorker):
 | 
			
		|||
        # context length is self.context_length
 | 
			
		||||
        prompt = params["prompt"]
 | 
			
		||||
        temperature = float(params.get("temperature", 1.0))
 | 
			
		||||
        do_sample = bool(params.get("do_sample", False))
 | 
			
		||||
        repetition_penalty = float(params.get("repetition_penalty", 1.0))
 | 
			
		||||
        top_p = float(params.get("top_p", 1.0))
 | 
			
		||||
        top_k = int(params.get("top_k", 1))
 | 
			
		||||
| 
						 | 
				
			
			@ -165,6 +169,7 @@ class BigDLLLMWorker(BaseModelWorker):
 | 
			
		|||
            streamer=streamer,
 | 
			
		||||
            temperature=temperature,
 | 
			
		||||
            repetition_penalty=repetition_penalty,
 | 
			
		||||
            do_sample=do_sample,
 | 
			
		||||
            top_p=top_p,
 | 
			
		||||
            top_k=top_k,
 | 
			
		||||
        )
 | 
			
		||||
| 
						 | 
				
			
			@ -314,6 +319,12 @@ if __name__ == "__main__":
 | 
			
		|||
    parser.add_argument(
 | 
			
		||||
        "--device", type=str, default="cpu", help="Device for executing model, cpu/xpu"
 | 
			
		||||
    )
 | 
			
		||||
    parser.add_argument(
 | 
			
		||||
        "--speculative",
 | 
			
		||||
        action="store_true",
 | 
			
		||||
        default=False,
 | 
			
		||||
        help="To use self-speculative or not",
 | 
			
		||||
    )
 | 
			
		||||
    parser.add_argument(
 | 
			
		||||
        "--trust-remote-code",
 | 
			
		||||
        action="store_true",
 | 
			
		||||
| 
						 | 
				
			
			@ -335,5 +346,6 @@ if __name__ == "__main__":
 | 
			
		|||
        args.device,
 | 
			
		||||
        args.no_register,
 | 
			
		||||
        args.trust_remote_code,
 | 
			
		||||
        args.speculative,
 | 
			
		||||
    )
 | 
			
		||||
    uvicorn.run(app, host=args.host, port=args.port, log_level="info")
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -45,6 +45,7 @@ def load_model(
 | 
			
		|||
    device: str = "cpu",
 | 
			
		||||
    low_bit: str = 'sym_int4',
 | 
			
		||||
    trust_remote_code: bool = True,
 | 
			
		||||
    speculative: bool = False,
 | 
			
		||||
):
 | 
			
		||||
    """Load a model using BigDL LLM backend."""
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -64,6 +65,11 @@ def load_model(
 | 
			
		|||
    else:
 | 
			
		||||
        model_kwargs.update({"load_in_low_bit": low_bit, "torch_dtype": 'auto'})
 | 
			
		||||
 | 
			
		||||
    if speculative:
 | 
			
		||||
        invalidInputError(low_bit == "fp16" or low_bit == "bf16",
 | 
			
		||||
                          "Self-Speculative only supports low_bit fp16 or bf16")
 | 
			
		||||
        model_kwargs["speculative"] = True
 | 
			
		||||
 | 
			
		||||
    # Load tokenizer
 | 
			
		||||
    tokenizer = tokenizer_cls.from_pretrained(model_path, trust_remote_code=True)
 | 
			
		||||
    model = model_cls.from_pretrained(model_path, **model_kwargs)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -97,6 +97,7 @@ def generate(
 | 
			
		|||
                new_speculative_kwargs[var] = value
 | 
			
		||||
        return self.speculative_generate(inputs=inputs,
 | 
			
		||||
                                         draft_model=self.draft_model,
 | 
			
		||||
                                         streamer=streamer,
 | 
			
		||||
                                         **new_speculative_kwargs)
 | 
			
		||||
    else:
 | 
			
		||||
        # When `draft_model` is false, these attributes
 | 
			
		||||
| 
						 | 
				
			
			@ -512,7 +513,7 @@ def _crop_past_key_values(self, past_key_values, new_cache_size, _enable_ipex=Fa
 | 
			
		|||
    return past_key_values
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _prepare_generate_args(self, inputs, generation_config, **sampling_kwargs):
 | 
			
		||||
def _prepare_generate_args(self, inputs, generation_config, streamer=None, **sampling_kwargs):
 | 
			
		||||
    if generation_config is None:
 | 
			
		||||
        generation_config = self.generation_config
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -591,8 +592,8 @@ def _prepare_generate_args(self, inputs, generation_config, **sampling_kwargs):
 | 
			
		|||
    # 5. Prepare `input_ids` which will be used for auto-regressive generation
 | 
			
		||||
    input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs.pop("input_ids")
 | 
			
		||||
 | 
			
		||||
    # if streamer is not None:
 | 
			
		||||
    #     streamer.put(input_ids.cpu())
 | 
			
		||||
    if streamer is not None:
 | 
			
		||||
        streamer.put(input_ids.cpu())
 | 
			
		||||
 | 
			
		||||
    input_ids_length = input_ids.shape[-1]
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -658,6 +659,7 @@ def speculative_generate(self,
 | 
			
		|||
                         min_step_draft=3,
 | 
			
		||||
                         generation_config: Optional[GenerationConfig] = None,
 | 
			
		||||
                         attention_mask=None,
 | 
			
		||||
                         streamer: Optional["BaseStreamer"] = None,
 | 
			
		||||
                         **sampling_kwargs):
 | 
			
		||||
    invalidInputError(draft_model is not None,
 | 
			
		||||
                      "Draft model should be provided.")
 | 
			
		||||
| 
						 | 
				
			
			@ -666,7 +668,7 @@ def speculative_generate(self,
 | 
			
		|||
    min_step_draft = min_step_draft if min_step_draft >= 1 else 1
 | 
			
		||||
 | 
			
		||||
    input_ids, generation_config, logits_processor, stopping_criteria, \
 | 
			
		||||
        model_kwargs = _prepare_generate_args(self, inputs, generation_config,
 | 
			
		||||
        model_kwargs = _prepare_generate_args(self, inputs, generation_config, streamer,
 | 
			
		||||
                                              **sampling_kwargs)
 | 
			
		||||
 | 
			
		||||
    step = 0
 | 
			
		||||
| 
						 | 
				
			
			@ -1061,7 +1063,8 @@ def speculative_generate(self,
 | 
			
		|||
 | 
			
		||||
            generate_ids[:, step:step+output_ids.size(1)] = output_ids
 | 
			
		||||
            current_input_ids = output_ids[:, -1:]
 | 
			
		||||
 | 
			
		||||
            if streamer is not None:
 | 
			
		||||
                streamer.put(output_ids.cpu())
 | 
			
		||||
            step += output_ids.size(1)
 | 
			
		||||
 | 
			
		||||
            # remove one generated by the base model
 | 
			
		||||
| 
						 | 
				
			
			@ -1094,7 +1097,8 @@ def speculative_generate(self,
 | 
			
		|||
            idx = output_ids_list.index(generation_config.eos_token_id)
 | 
			
		||||
            step -= (len(output_ids_list) - idx - 1)
 | 
			
		||||
            break
 | 
			
		||||
 | 
			
		||||
    if streamer is not None:
 | 
			
		||||
        streamer.end()
 | 
			
		||||
    step = min(step, max_new_tokens)
 | 
			
		||||
    e2e_toc = time.time()
 | 
			
		||||
    self.n_token_generated = step
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in a new issue