LLM: Enable qwen target_model ipex (#10232)
* change order * enable qwen ipex * update qwen example * update * fix style * update
This commit is contained in:
		
							parent
							
								
									3e6d188553
								
							
						
					
					
						commit
						f9b75f900b
					
				
					 4 changed files with 74 additions and 6 deletions
				
			
		| 
						 | 
				
			
			@ -89,4 +89,31 @@ assistant
 | 
			
		|||
Tokens generated 128
 | 
			
		||||
E2E Generation time x.xxxxs
 | 
			
		||||
First token latency x.xxxxs
 | 
			
		||||
```
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
### 4. Accelerate with BIGDL_OPT_IPEX
 | 
			
		||||
 | 
			
		||||
To accelerate speculative decoding on CPU, you can install our validated version of [IPEX 2.3.0+git004cd72d](https://github.com/intel/intel-extension-for-pytorch/tree/004cd72db60e87bb0712d42e3120bac9854bd77e) by following steps: (Other versions of IPEX may have some conflicts and can not accelerate speculative decoding correctly.)
 | 
			
		||||
 | 
			
		||||
#### 4.1 Download IPEX installation script
 | 
			
		||||
```bash
 | 
			
		||||
# Depend on Conda and GCC 12.3
 | 
			
		||||
wget https://raw.githubusercontent.com/intel/intel-extension-for-pytorch/004cd72db60e87bb0712d42e3120bac9854bd77e/scripts/compile_bundle.sh
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
#### 4.2 Activate your conda environment
 | 
			
		||||
```bash
 | 
			
		||||
conda activate <your_conda_env>
 | 
			
		||||
```
 | 
			
		||||
#### 4.3 Set VER_IPEX in compile_bundle.sh to 004cd72db60e87bb0712d42e3120bac9854bd77e
 | 
			
		||||
```bash
 | 
			
		||||
sed -i 's/VER_IPEX=main/VER_IPEX=004cd72db60e87bb0712d42e3120bac9854bd77e/g' "compile_bundle.sh"
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
#### 4.4 Install IPEX and other dependencies
 | 
			
		||||
```bash
 | 
			
		||||
# Install IPEX 2.3.0+git004cd72d
 | 
			
		||||
bash compile_bundle.sh 
 | 
			
		||||
 | 
			
		||||
# Update transformers
 | 
			
		||||
pip install transformers==4.36.2
 | 
			
		||||
| 
						 | 
				
			
			@ -54,6 +54,8 @@ if __name__ == '__main__':
 | 
			
		|||
                        help='Max tokens to predict')
 | 
			
		||||
    parser.add_argument('--th_stop_draft', type=float, default=0.6,
 | 
			
		||||
                        help='draft stop probility')
 | 
			
		||||
    parser.add_argument('--min_step_draft', type=int, default=1,
 | 
			
		||||
                        help='min tokens per step draft')
 | 
			
		||||
 | 
			
		||||
    args = parser.parse_args()
 | 
			
		||||
    model_path = args.repo_id_or_model_path
 | 
			
		||||
| 
						 | 
				
			
			@ -67,13 +69,15 @@ if __name__ == '__main__':
 | 
			
		|||
                                                 speculative=True,
 | 
			
		||||
                                                 trust_remote_code=True,
 | 
			
		||||
                                                 use_cache=True)
 | 
			
		||||
    model = model.to('cpu')
 | 
			
		||||
    #model = model.to('cpu')
 | 
			
		||||
 | 
			
		||||
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
 | 
			
		||||
 | 
			
		||||
    with torch.inference_mode():
 | 
			
		||||
        prompt = QWEN_PROMPT_FORMAT.format(prompt=args.prompt)
 | 
			
		||||
        input_ids = tokenizer(prompt, return_tensors='pt').input_ids.to(model.device)
 | 
			
		||||
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
 | 
			
		||||
        input_ids = inputs.input_ids
 | 
			
		||||
        attention_mask = inputs.attention_mask.to(model.device)
 | 
			
		||||
        actual_in_len = input_ids.shape[1]
 | 
			
		||||
        print("actual input_ids length:" + str(actual_in_len))
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -81,6 +85,8 @@ if __name__ == '__main__':
 | 
			
		|||
        output = model.generate(input_ids,
 | 
			
		||||
                                max_new_tokens=args.n_predict,
 | 
			
		||||
                                th_stop_draft=args.th_stop_draft,
 | 
			
		||||
                                attention_mask=attention_mask,
 | 
			
		||||
                                min_step_draft=args.min_step_draft,
 | 
			
		||||
                                do_sample=False)
 | 
			
		||||
        output_str = tokenizer.decode(output[0])
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -89,6 +95,8 @@ if __name__ == '__main__':
 | 
			
		|||
        output = model.generate(input_ids,
 | 
			
		||||
                                max_new_tokens=args.n_predict,
 | 
			
		||||
                                th_stop_draft=args.th_stop_draft,
 | 
			
		||||
                                attention_mask=attention_mask,
 | 
			
		||||
                                min_step_draft=args.min_step_draft,
 | 
			
		||||
                                do_sample=False)
 | 
			
		||||
        output_str = tokenizer.decode(output[0], skip_special_tokens=True)
 | 
			
		||||
        end = time.perf_counter()
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -93,12 +93,10 @@ def _ipex_optimize_attention(model):
 | 
			
		|||
 | 
			
		||||
 | 
			
		||||
def _ipex_optimize_model(model, rms_classes):
 | 
			
		||||
    from intel_extension_for_pytorch.transformers.models.reference.models import output_hook
 | 
			
		||||
 | 
			
		||||
    _ipex_optimize_rmsnorm(model, rms_classes)
 | 
			
		||||
    _ipex_optimize_attention(model)
 | 
			
		||||
    _ipex_optimize_decoder(model)
 | 
			
		||||
    model.register_forward_hook(output_hook, with_kwargs=True)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _ipex_jit(model):
 | 
			
		||||
| 
						 | 
				
			
			@ -124,6 +122,8 @@ def _ipex_jit(model):
 | 
			
		|||
        model = _set_optimized_model_for_generation(
 | 
			
		||||
            model, optimized_model=trace_model
 | 
			
		||||
        )
 | 
			
		||||
    from intel_extension_for_pytorch.transformers.models.reference.models import output_hook
 | 
			
		||||
    model.register_forward_hook(output_hook, with_kwargs=True)
 | 
			
		||||
 | 
			
		||||
    return model
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -153,6 +153,12 @@ def _prepare_past_key_values_storage_cpu(self, past_key_values,
 | 
			
		|||
                                                len0, len3).permute(2, 0, 1, 3)
 | 
			
		||||
                list = [key[:cur_len, :, :, :], value[:cur_len, :, :, :]]
 | 
			
		||||
                ipex_past_key_values.append(list)
 | 
			
		||||
        elif self.config.model_type == "qwen":
 | 
			
		||||
            ipex_past_key_values = [
 | 
			
		||||
                [pkv[1].permute(1, 0, 2, 3)[:, :cur_len, :, :],
 | 
			
		||||
                    pkv[2].permute(1, 0, 2, 3)[:, :cur_len, :, :]]
 | 
			
		||||
                for pkv in past_key_values
 | 
			
		||||
            ]
 | 
			
		||||
        else:
 | 
			
		||||
            ipex_past_key_values = [
 | 
			
		||||
                [pkv[1].permute(1, 2, 0, 3)[:, :, :cur_len, :],
 | 
			
		||||
| 
						 | 
				
			
			@ -217,6 +223,18 @@ def _prepare_past_key_values_storage_cpu(self, past_key_values,
 | 
			
		|||
                    torch.float32)
 | 
			
		||||
                past_key_values_storage[i][1][:len2, :, :, :] = ipex_past_key_values[i][1].to(
 | 
			
		||||
                    torch.float32)
 | 
			
		||||
            elif self.config.model_type == "qwen":
 | 
			
		||||
                k0 = torch.ones(len0, len1, len2 + max_new_tokens, len3,
 | 
			
		||||
                                dtype=torch.float32)
 | 
			
		||||
                v0 = torch.ones(len0, len1, len2 + max_new_tokens, len3,
 | 
			
		||||
                                dtype=torch.float32)
 | 
			
		||||
                k0 = k0.permute(0, 2, 1, 3)
 | 
			
		||||
                v0 = v0.permute(0, 2, 1, 3)
 | 
			
		||||
                past_key_values_storage.append((k0, v0))
 | 
			
		||||
                past_key_values_storage[i][0][:, :len2, :, :] = ipex_past_key_values[i][0].to(
 | 
			
		||||
                    torch.float32)
 | 
			
		||||
                past_key_values_storage[i][1][:, :len2, :, :] = ipex_past_key_values[i][1].to(
 | 
			
		||||
                    torch.float32)
 | 
			
		||||
            else:
 | 
			
		||||
                k0 = torch.ones(len0, len1, len2 + max_new_tokens, len3,
 | 
			
		||||
                                dtype=torch.float32)
 | 
			
		||||
| 
						 | 
				
			
			@ -309,6 +327,16 @@ def _update_past_key_values_storage_cpu(self, past_key_values, past_key_values_s
 | 
			
		|||
                    key.to(torch.float32)
 | 
			
		||||
                past_key_values_storage[i][1][size:size1, :, :, :] = \
 | 
			
		||||
                    value.to(torch.float32)
 | 
			
		||||
            elif self.config.model_type == "qwen":
 | 
			
		||||
                size = original_draft_past_key_values[0][0].size(1)
 | 
			
		||||
                delta_past_key = \
 | 
			
		||||
                    past_key_values[i][1][size:size1, :, :, :].permute(1, 0, 2, 3)
 | 
			
		||||
                delta_past_value = \
 | 
			
		||||
                    past_key_values[i][2][size:size1, :, :, :].permute(1, 0, 2, 3)
 | 
			
		||||
                past_key_values_storage[i][0][:, size:size1, :, :] = \
 | 
			
		||||
                    delta_past_key.to(torch.float32)
 | 
			
		||||
                past_key_values_storage[i][1][:, size:size1, :, :] = \
 | 
			
		||||
                    delta_past_value.to(torch.float32)
 | 
			
		||||
            else:
 | 
			
		||||
                delta_past_key = \
 | 
			
		||||
                    past_key_values[i][1][size:size1, :, :, :].permute(1, 2, 0, 3)
 | 
			
		||||
| 
						 | 
				
			
			@ -444,9 +472,10 @@ def speculative_generate(self,
 | 
			
		|||
        if not ((self.config.model_type == 'baichuan') or
 | 
			
		||||
                ('llama' in self.config.model_type) or
 | 
			
		||||
                ("mistral" in self.config.model_type) or
 | 
			
		||||
                ("qwen" in self.config.model_type) or
 | 
			
		||||
                ("chatglm" in self.config.model_type)):
 | 
			
		||||
            invalidInputError(False, "BigDL Speculative Decoding with IPEX BF16 only supports \
 | 
			
		||||
                                      Llama, Baichuan2, Mistral and ChatGLM models currently.")
 | 
			
		||||
                                      Llama, Baichuan2, Mistral and ChatGLM and Qwen models currently.")
 | 
			
		||||
        if "chatglm" in self.config.model_type:
 | 
			
		||||
            global query_group_size
 | 
			
		||||
            query_group_size = draft_model.config.num_attention_heads // \
 | 
			
		||||
| 
						 | 
				
			
			@ -637,6 +666,10 @@ def speculative_generate(self,
 | 
			
		|||
                                              position_ids=position_ids,
 | 
			
		||||
                                              # return_last_logit=torch.tensor(False),
 | 
			
		||||
                                              past_key_values=past_key_values,)
 | 
			
		||||
                elif "qwen" in self.config.model_type:
 | 
			
		||||
                    output = self.trace_graph(input_ids=drafted_input_ids,
 | 
			
		||||
                                              attention_mask=cur_attention_mask,
 | 
			
		||||
                                              past_key_values=past_key_values)
 | 
			
		||||
                elif "mistral" in self.config.model_type:
 | 
			
		||||
                    past_key_value_len = past_key_values[0][0].shape[2]
 | 
			
		||||
                    seq_len = drafted_input_ids.shape[1]
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in a new issue