From a1eb793f70e2954483d9ded8e2606ac52bc0de91 Mon Sep 17 00:00:00 2001
From: Yishuo Wang
Date: Tue, 13 Aug 2024 09:51:18 +0800
Subject: [PATCH] optimize minicpm v 2_6 first token perf (#11770)

---
 .../llm/src/ipex_llm/transformers/convert.py  |  7 ++++
 .../ipex_llm/transformers/models/common.py    |  5 ++-
 .../ipex_llm/transformers/models/minicpmv.py  | 36 +++++++++++++++++++
 3 files changed, 47 insertions(+), 1 deletion(-)

diff --git a/python/llm/src/ipex_llm/transformers/convert.py b/python/llm/src/ipex_llm/transformers/convert.py
index 33c6b83d..7edb2b7a 100644
--- a/python/llm/src/ipex_llm/transformers/convert.py
+++ b/python/llm/src/ipex_llm/transformers/convert.py
@@ -748,6 +748,8 @@ def _optimize_pre(model, qtype=None):
         from ipex_llm.transformers.models.llama import merge_qkv
         model.apply(merge_qkv)
     if model.config.model_type == "minicpmv":
+        from ipex_llm.transformers.models.minicpmv import merge_qkv
+        model.apply(merge_qkv)
         if model.config.hidden_size == 3584 and model.config.vocab_size == 151666:
             model.llm.config.model_type = "qwen2"
             _optimize_pre(model.llm, qtype=qtype)
@@ -1763,4 +1765,9 @@ def _optimize_post(model, lightweight_bmm=False):
         minicpmv_generate = minicpmv_generate_wrapper(module.MiniCPMV.generate)
         model.generate = MethodType(minicpmv_generate, model)
 
+        modeling_module_name = model.vpm.__class__.__module__
+        module = importlib.import_module(modeling_module_name)
+        from ipex_llm.transformers.models.minicpmv import siglip_attention_forward
+        convert_forward(model, module.SiglipAttention, siglip_attention_forward)
+
     return model
diff --git a/python/llm/src/ipex_llm/transformers/models/common.py b/python/llm/src/ipex_llm/transformers/models/common.py
index 86b0d46b..f3dab652 100644
--- a/python/llm/src/ipex_llm/transformers/models/common.py
+++ b/python/llm/src/ipex_llm/transformers/models/common.py
@@ -37,7 +37,10 @@ def merge_linear(linears: List[torch.nn.Linear]) -> torch.nn.Linear:
 
 
 def merge_qkv_base(module: torch.nn.Module, attention_class):
-    if isinstance(module, attention_class):
+    if (
+        isinstance(attention_class, str) and module.__class__.__name__ == attention_class
+        or not isinstance(attention_class, str) and isinstance(module, attention_class)
+    ):
         qkv_proj = merge_linear([
             module.q_proj,
             module.k_proj,
diff --git a/python/llm/src/ipex_llm/transformers/models/minicpmv.py b/python/llm/src/ipex_llm/transformers/models/minicpmv.py
index ebde9407..c92226a0 100644
--- a/python/llm/src/ipex_llm/transformers/models/minicpmv.py
+++ b/python/llm/src/ipex_llm/transformers/models/minicpmv.py
@@ -16,9 +16,45 @@
 
 import torch
 
+from typing import Optional
+from ipex_llm.transformers.models.common import merge_qkv_base
 from transformers.generation.logits_process import RepetitionPenaltyLogitsProcessor
 
 
+def merge_qkv(module: torch.nn.Module):
+    return merge_qkv_base(module, "SiglipAttention")
+
+
+def siglip_attention_forward(
+    self,
+    hidden_states: torch.Tensor,
+    attention_mask: Optional[torch.Tensor] = None,
+    output_attentions: Optional[bool] = False,
+):
+    bsz, q_len, _ = hidden_states.size()
+
+    qkv = self.qkv_proj(hidden_states)
+    qkv = qkv.view(bsz, q_len, self.num_heads * 3, self.head_dim)
+    qkv = qkv.transpose(1, 2)
+    query_states, key_states, value_states = qkv.chunk(3, dim=1)
+
+    attn_weights = torch.matmul(query_states * self.scale, key_states.transpose(2, 3))
+    if attention_mask is not None:
+        attn_weights = attn_weights + attention_mask
+
+    # upcast attention to fp32
+    attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
+    attn_weights = torch.nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
+    attn_output = torch.matmul(attn_weights, value_states)
+
+    attn_output = attn_output.transpose(1, 2).contiguous()
+    attn_output = attn_output.reshape(bsz, q_len, self.embed_dim)
+
+    attn_output = self.out_proj(attn_output)
+
+    return attn_output, attn_weights
+
+
 def patched_repetition_penalty_call(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
     if scores.device.type == "xpu":
         import xe_addons
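Note: the q/k/v fusion that merge_qkv_base performs for SiglipAttention can be illustrated outside the patch. The minimal sketch below assumes plain torch.nn.Linear projections of equal shape; fuse_qkv is a hypothetical helper written for this note, not the ipex-llm API. It shows why the patched siglip_attention_forward can recover query/key/value with a single view/transpose/chunk after one fused projection.

import torch


def fuse_qkv(q, k, v):
    # Concatenate the three projection weights (and biases) along the output
    # dimension so a single GEMM replaces three during the first-token pass.
    fused = torch.nn.Linear(q.in_features,
                            q.out_features + k.out_features + v.out_features,
                            bias=q.bias is not None)
    with torch.no_grad():
        fused.weight.copy_(torch.cat([q.weight, k.weight, v.weight], dim=0))
        if q.bias is not None:
            fused.bias.copy_(torch.cat([q.bias, k.bias, v.bias], dim=0))
    return fused


hidden, heads, head_dim = 64, 4, 16
q = torch.nn.Linear(hidden, hidden)
k = torch.nn.Linear(hidden, hidden)
v = torch.nn.Linear(hidden, hidden)
qkv_proj = fuse_qkv(q, k, v)

x = torch.randn(2, 10, hidden)                      # (batch, seq_len, hidden)
bsz, q_len, _ = x.size()
# Same split pattern as the patched siglip_attention_forward:
qkv = qkv_proj(x).view(bsz, q_len, heads * 3, head_dim).transpose(1, 2)
q_states, k_states, v_states = qkv.chunk(3, dim=1)

ref_q = q(x).view(bsz, q_len, heads, head_dim).transpose(1, 2)
print(torch.allclose(q_states, ref_q, atol=1e-5))   # True: fused matches separate

The fused layer computes exactly the same dot products as the three separate projections; the benefit is launching one larger matmul instead of three smaller ones, which is what improves first-token latency for the vision encoder here.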