optimize minicpm v 2_6 firs token perf (#11770)
This commit is contained in:
parent
841dbcdf3a
commit
a1eb793f70
3 changed files with 47 additions and 1 deletions
|
|
@ -748,6 +748,8 @@ def _optimize_pre(model, qtype=None):
|
|||
from ipex_llm.transformers.models.llama import merge_qkv
|
||||
model.apply(merge_qkv)
|
||||
if model.config.model_type == "minicpmv":
|
||||
from ipex_llm.transformers.models.minicpmv import merge_qkv
|
||||
model.apply(merge_qkv)
|
||||
if model.config.hidden_size == 3584 and model.config.vocab_size == 151666:
|
||||
model.llm.config.model_type = "qwen2"
|
||||
_optimize_pre(model.llm, qtype=qtype)
|
||||
|
|
@ -1763,4 +1765,9 @@ def _optimize_post(model, lightweight_bmm=False):
|
|||
minicpmv_generate = minicpmv_generate_wrapper(module.MiniCPMV.generate)
|
||||
model.generate = MethodType(minicpmv_generate, model)
|
||||
|
||||
modeling_module_name = model.vpm.__class__.__module__
|
||||
module = importlib.import_module(modeling_module_name)
|
||||
from ipex_llm.transformers.models.minicpmv import siglip_attention_forward
|
||||
convert_forward(model, module.SiglipAttention, siglip_attention_forward)
|
||||
|
||||
return model
|
||||
|
|
|
|||
|
|
@ -37,7 +37,10 @@ def merge_linear(linears: List[torch.nn.Linear]) -> torch.nn.Linear:
|
|||
|
||||
|
||||
def merge_qkv_base(module: torch.nn.Module, attention_class):
|
||||
if isinstance(module, attention_class):
|
||||
if (
|
||||
isinstance(attention_class, str) and module.__class__.__name__ == attention_class
|
||||
or not isinstance(attention_class, str) and isinstance(module, attention_class)
|
||||
):
|
||||
qkv_proj = merge_linear([
|
||||
module.q_proj,
|
||||
module.k_proj,
|
||||
|
|
|
|||
|
|
@ -16,9 +16,45 @@
|
|||
|
||||
|
||||
import torch
|
||||
from typing import Optional
|
||||
from ipex_llm.transformers.models.common import merge_qkv_base
|
||||
from transformers.generation.logits_process import RepetitionPenaltyLogitsProcessor
|
||||
|
||||
|
||||
def merge_qkv(module: torch.nn.Module):
|
||||
return merge_qkv_base(module, "SiglipAttention")
|
||||
|
||||
|
||||
def siglip_attention_forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
):
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
qkv = self.qkv_proj(hidden_states)
|
||||
qkv = qkv.view(bsz, q_len, self.num_heads * 3, self.head_dim)
|
||||
qkv = qkv.transpose(1, 2)
|
||||
query_states, key_states, value_states = qkv.chunk(3, dim=1)
|
||||
|
||||
attn_weights = torch.matmul(query_states * self.scale, key_states.transpose(2, 3))
|
||||
if attention_mask is not None:
|
||||
attn_weights = attn_weights + attention_mask
|
||||
|
||||
# upcast attention to fp32
|
||||
attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
|
||||
attn_weights = torch.nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
|
||||
attn_output = torch.matmul(attn_weights, value_states)
|
||||
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
attn_output = attn_output.reshape(bsz, q_len, self.embed_dim)
|
||||
|
||||
attn_output = self.out_proj(attn_output)
|
||||
|
||||
return attn_output, attn_weights
|
||||
|
||||
|
||||
def patched_repetition_penalty_call(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
|
||||
if scores.device.type == "xpu":
|
||||
import xe_addons
|
||||
|
|
|
|||
Loading…
Reference in a new issue