Optimize MiniCPM-V-2_6 first token perf (#11770)

Yishuo Wang 2024-08-13 09:51:18 +08:00 committed by GitHub
parent 841dbcdf3a
commit a1eb793f70
3 changed files with 47 additions and 1 deletion

File 1 of 3

@@ -748,6 +748,8 @@ def _optimize_pre(model, qtype=None):
         from ipex_llm.transformers.models.llama import merge_qkv
         model.apply(merge_qkv)
     if model.config.model_type == "minicpmv":
+        from ipex_llm.transformers.models.minicpmv import merge_qkv
+        model.apply(merge_qkv)
         if model.config.hidden_size == 3584 and model.config.vocab_size == 151666:
             model.llm.config.model_type = "qwen2"
         _optimize_pre(model.llm, qtype=qtype)
@@ -1763,4 +1765,9 @@ def _optimize_post(model, lightweight_bmm=False):
         minicpmv_generate = minicpmv_generate_wrapper(module.MiniCPMV.generate)
         model.generate = MethodType(minicpmv_generate, model)
+
+        modeling_module_name = model.vpm.__class__.__module__
+        module = importlib.import_module(modeling_module_name)
+        from ipex_llm.transformers.models.minicpmv import siglip_attention_forward
+        convert_forward(model, module.SiglipAttention, siglip_attention_forward)
     return model
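
Note: `_optimize_pre` merges the vision tower's separate q/k/v projections, and `_optimize_post` swaps `SiglipAttention.forward` for the optimized forward added in this commit, so the MiniCPM-V-2_6 image prefill path (which feeds into first-token latency) runs through the fused attention. The class-level forward swap that `convert_forward` performs follows the usual instance-rebinding pattern; a minimal sketch of that pattern (illustrative names, not the actual ipex-llm helper):

import torch

def replace_forward(model: torch.nn.Module, target_class, new_forward):
    # Illustrative only: rebind new_forward onto every instance of target_class
    # so `self` resolves to that module when the model runs.
    for submodule in model.modules():
        if isinstance(submodule, target_class):
            submodule.forward = new_forward.__get__(submodule, submodule.__class__)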

File 2 of 3

@@ -37,7 +37,10 @@ def merge_linear(linears: List[torch.nn.Linear]) -> torch.nn.Linear:
 def merge_qkv_base(module: torch.nn.Module, attention_class):
-    if isinstance(module, attention_class):
+    if (
+        isinstance(attention_class, str) and module.__class__.__name__ == attention_class
+        or not isinstance(attention_class, str) and isinstance(module, attention_class)
+    ):
         qkv_proj = merge_linear([
             module.q_proj,
             module.k_proj,
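
`merge_linear` (the hunk's context function in this file) concatenates the q/k/v weight matrices into a single `Linear`, which is why the new `siglip_attention_forward` below can run one `qkv_proj` matmul and `chunk` the result into query, key and value states. A minimal sketch of that fusion, assuming standard `torch.nn.Linear` projections (illustrative, not the ipex-llm implementation):

import torch

def fuse_qkv(q: torch.nn.Linear, k: torch.nn.Linear, v: torch.nn.Linear) -> torch.nn.Linear:
    # Illustrative fusion: stack the three projection weights so a single matmul
    # produces Q, K and V, which the caller then splits with chunk(3).
    fused = torch.nn.Linear(q.in_features,
                            q.out_features + k.out_features + v.out_features,
                            bias=q.bias is not None)
    with torch.no_grad():
        fused.weight.copy_(torch.cat([q.weight, k.weight, v.weight], dim=0))
        if q.bias is not None:
            fused.bias.copy_(torch.cat([q.bias, k.bias, v.bias], dim=0))
    return fused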

File 3 of 3

@@ -16,9 +16,45 @@
 import torch
+from typing import Optional
+from ipex_llm.transformers.models.common import merge_qkv_base
 from transformers.generation.logits_process import RepetitionPenaltyLogitsProcessor
+
+
+def merge_qkv(module: torch.nn.Module):
+    return merge_qkv_base(module, "SiglipAttention")
+
+
+def siglip_attention_forward(
+    self,
+    hidden_states: torch.Tensor,
+    attention_mask: Optional[torch.Tensor] = None,
+    output_attentions: Optional[bool] = False,
+):
+    bsz, q_len, _ = hidden_states.size()
+
+    qkv = self.qkv_proj(hidden_states)
+    qkv = qkv.view(bsz, q_len, self.num_heads * 3, self.head_dim)
+    qkv = qkv.transpose(1, 2)
+    query_states, key_states, value_states = qkv.chunk(3, dim=1)
+
+    attn_weights = torch.matmul(query_states * self.scale, key_states.transpose(2, 3))
+
+    if attention_mask is not None:
+        attn_weights = attn_weights + attention_mask
+
+    # upcast attention to fp32
+    attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
+    attn_weights = torch.nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
+    attn_output = torch.matmul(attn_weights, value_states)
+
+    attn_output = attn_output.transpose(1, 2).contiguous()
+    attn_output = attn_output.reshape(bsz, q_len, self.embed_dim)
+
+    attn_output = self.out_proj(attn_output)
+    return attn_output, attn_weights


 def patched_repetition_penalty_call(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
     if scores.device.type == "xpu":
         import xe_addons
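
For reference, these conversions are applied automatically when the model is loaded through ipex-llm with optimization enabled; a rough usage sketch (the model id and loading flags here are assumptions based on common ipex-llm examples, not part of this commit):

from transformers import AutoTokenizer
from ipex_llm.transformers import AutoModel

model_path = "openbmb/MiniCPM-V-2_6"  # assumed model id
model = AutoModel.from_pretrained(model_path,
                                  load_in_4bit=True,    # low-bit weights for the LLM part
                                  optimize_model=True,  # runs the _optimize_pre/_optimize_post hooks
                                  trust_remote_code=True)
model = model.half().to("xpu")        # assumed target: an Intel GPU (XPU) device
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)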