optimize minicpm v 2_6 firs token perf (#11770)
This commit is contained in:
		
							parent
							
								
									841dbcdf3a
								
							
						
					
					
						commit
						a1eb793f70
					
				
					 3 changed files with 47 additions and 1 deletions
				
			
		| 
						 | 
					@ -748,6 +748,8 @@ def _optimize_pre(model, qtype=None):
 | 
				
			||||||
        from ipex_llm.transformers.models.llama import merge_qkv
 | 
					        from ipex_llm.transformers.models.llama import merge_qkv
 | 
				
			||||||
        model.apply(merge_qkv)
 | 
					        model.apply(merge_qkv)
 | 
				
			||||||
    if model.config.model_type == "minicpmv":
 | 
					    if model.config.model_type == "minicpmv":
 | 
				
			||||||
 | 
					        from ipex_llm.transformers.models.minicpmv import merge_qkv
 | 
				
			||||||
 | 
					        model.apply(merge_qkv)
 | 
				
			||||||
        if model.config.hidden_size == 3584 and model.config.vocab_size == 151666:
 | 
					        if model.config.hidden_size == 3584 and model.config.vocab_size == 151666:
 | 
				
			||||||
            model.llm.config.model_type = "qwen2"
 | 
					            model.llm.config.model_type = "qwen2"
 | 
				
			||||||
            _optimize_pre(model.llm, qtype=qtype)
 | 
					            _optimize_pre(model.llm, qtype=qtype)
 | 
				
			||||||
| 
						 | 
					@ -1763,4 +1765,9 @@ def _optimize_post(model, lightweight_bmm=False):
 | 
				
			||||||
        minicpmv_generate = minicpmv_generate_wrapper(module.MiniCPMV.generate)
 | 
					        minicpmv_generate = minicpmv_generate_wrapper(module.MiniCPMV.generate)
 | 
				
			||||||
        model.generate = MethodType(minicpmv_generate, model)
 | 
					        model.generate = MethodType(minicpmv_generate, model)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        modeling_module_name = model.vpm.__class__.__module__
 | 
				
			||||||
 | 
					        module = importlib.import_module(modeling_module_name)
 | 
				
			||||||
 | 
					        from ipex_llm.transformers.models.minicpmv import siglip_attention_forward
 | 
				
			||||||
 | 
					        convert_forward(model, module.SiglipAttention, siglip_attention_forward)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    return model
 | 
					    return model
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -37,7 +37,10 @@ def merge_linear(linears: List[torch.nn.Linear]) -> torch.nn.Linear:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def merge_qkv_base(module: torch.nn.Module, attention_class):
 | 
					def merge_qkv_base(module: torch.nn.Module, attention_class):
 | 
				
			||||||
    if isinstance(module, attention_class):
 | 
					    if (
 | 
				
			||||||
 | 
					        isinstance(attention_class, str) and module.__class__.__name__ == attention_class
 | 
				
			||||||
 | 
					        or not isinstance(attention_class, str) and isinstance(module, attention_class)
 | 
				
			||||||
 | 
					    ):
 | 
				
			||||||
        qkv_proj = merge_linear([
 | 
					        qkv_proj = merge_linear([
 | 
				
			||||||
            module.q_proj,
 | 
					            module.q_proj,
 | 
				
			||||||
            module.k_proj,
 | 
					            module.k_proj,
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -16,9 +16,45 @@
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import torch
 | 
					import torch
 | 
				
			||||||
 | 
					from typing import Optional
 | 
				
			||||||
 | 
					from ipex_llm.transformers.models.common import merge_qkv_base
 | 
				
			||||||
from transformers.generation.logits_process import RepetitionPenaltyLogitsProcessor
 | 
					from transformers.generation.logits_process import RepetitionPenaltyLogitsProcessor
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def merge_qkv(module: torch.nn.Module):
 | 
				
			||||||
 | 
					    return merge_qkv_base(module, "SiglipAttention")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def siglip_attention_forward(
 | 
				
			||||||
 | 
					    self,
 | 
				
			||||||
 | 
					    hidden_states: torch.Tensor,
 | 
				
			||||||
 | 
					    attention_mask: Optional[torch.Tensor] = None,
 | 
				
			||||||
 | 
					    output_attentions: Optional[bool] = False,
 | 
				
			||||||
 | 
					):
 | 
				
			||||||
 | 
					    bsz, q_len, _ = hidden_states.size()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    qkv = self.qkv_proj(hidden_states)
 | 
				
			||||||
 | 
					    qkv = qkv.view(bsz, q_len, self.num_heads * 3, self.head_dim)
 | 
				
			||||||
 | 
					    qkv = qkv.transpose(1, 2)
 | 
				
			||||||
 | 
					    query_states, key_states, value_states = qkv.chunk(3, dim=1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    attn_weights = torch.matmul(query_states * self.scale, key_states.transpose(2, 3))
 | 
				
			||||||
 | 
					    if attention_mask is not None:
 | 
				
			||||||
 | 
					        attn_weights = attn_weights + attention_mask
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # upcast attention to fp32
 | 
				
			||||||
 | 
					    attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
 | 
				
			||||||
 | 
					    attn_weights = torch.nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
 | 
				
			||||||
 | 
					    attn_output = torch.matmul(attn_weights, value_states)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    attn_output = attn_output.transpose(1, 2).contiguous()
 | 
				
			||||||
 | 
					    attn_output = attn_output.reshape(bsz, q_len, self.embed_dim)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    attn_output = self.out_proj(attn_output)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    return attn_output, attn_weights
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def patched_repetition_penalty_call(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
 | 
					def patched_repetition_penalty_call(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
 | 
				
			||||||
    if scores.device.type == "xpu":
 | 
					    if scores.device.type == "xpu":
 | 
				
			||||||
        import xe_addons
 | 
					        import xe_addons
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in a new issue