optimize minicpm v 2_6 firs token perf (#11770)
This commit is contained in:
parent
841dbcdf3a
commit
a1eb793f70
3 changed files with 47 additions and 1 deletions
|
|
@ -748,6 +748,8 @@ def _optimize_pre(model, qtype=None):
|
||||||
from ipex_llm.transformers.models.llama import merge_qkv
|
from ipex_llm.transformers.models.llama import merge_qkv
|
||||||
model.apply(merge_qkv)
|
model.apply(merge_qkv)
|
||||||
if model.config.model_type == "minicpmv":
|
if model.config.model_type == "minicpmv":
|
||||||
|
from ipex_llm.transformers.models.minicpmv import merge_qkv
|
||||||
|
model.apply(merge_qkv)
|
||||||
if model.config.hidden_size == 3584 and model.config.vocab_size == 151666:
|
if model.config.hidden_size == 3584 and model.config.vocab_size == 151666:
|
||||||
model.llm.config.model_type = "qwen2"
|
model.llm.config.model_type = "qwen2"
|
||||||
_optimize_pre(model.llm, qtype=qtype)
|
_optimize_pre(model.llm, qtype=qtype)
|
||||||
|
|
@ -1763,4 +1765,9 @@ def _optimize_post(model, lightweight_bmm=False):
|
||||||
minicpmv_generate = minicpmv_generate_wrapper(module.MiniCPMV.generate)
|
minicpmv_generate = minicpmv_generate_wrapper(module.MiniCPMV.generate)
|
||||||
model.generate = MethodType(minicpmv_generate, model)
|
model.generate = MethodType(minicpmv_generate, model)
|
||||||
|
|
||||||
|
modeling_module_name = model.vpm.__class__.__module__
|
||||||
|
module = importlib.import_module(modeling_module_name)
|
||||||
|
from ipex_llm.transformers.models.minicpmv import siglip_attention_forward
|
||||||
|
convert_forward(model, module.SiglipAttention, siglip_attention_forward)
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
|
||||||
|
|
@ -37,7 +37,10 @@ def merge_linear(linears: List[torch.nn.Linear]) -> torch.nn.Linear:
|
||||||
|
|
||||||
|
|
||||||
def merge_qkv_base(module: torch.nn.Module, attention_class):
|
def merge_qkv_base(module: torch.nn.Module, attention_class):
|
||||||
if isinstance(module, attention_class):
|
if (
|
||||||
|
isinstance(attention_class, str) and module.__class__.__name__ == attention_class
|
||||||
|
or not isinstance(attention_class, str) and isinstance(module, attention_class)
|
||||||
|
):
|
||||||
qkv_proj = merge_linear([
|
qkv_proj = merge_linear([
|
||||||
module.q_proj,
|
module.q_proj,
|
||||||
module.k_proj,
|
module.k_proj,
|
||||||
|
|
|
||||||
|
|
@ -16,9 +16,45 @@
|
||||||
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from typing import Optional
|
||||||
|
from ipex_llm.transformers.models.common import merge_qkv_base
|
||||||
from transformers.generation.logits_process import RepetitionPenaltyLogitsProcessor
|
from transformers.generation.logits_process import RepetitionPenaltyLogitsProcessor
|
||||||
|
|
||||||
|
|
||||||
|
def merge_qkv(module: torch.nn.Module):
|
||||||
|
return merge_qkv_base(module, "SiglipAttention")
|
||||||
|
|
||||||
|
|
||||||
|
def siglip_attention_forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
output_attentions: Optional[bool] = False,
|
||||||
|
):
|
||||||
|
bsz, q_len, _ = hidden_states.size()
|
||||||
|
|
||||||
|
qkv = self.qkv_proj(hidden_states)
|
||||||
|
qkv = qkv.view(bsz, q_len, self.num_heads * 3, self.head_dim)
|
||||||
|
qkv = qkv.transpose(1, 2)
|
||||||
|
query_states, key_states, value_states = qkv.chunk(3, dim=1)
|
||||||
|
|
||||||
|
attn_weights = torch.matmul(query_states * self.scale, key_states.transpose(2, 3))
|
||||||
|
if attention_mask is not None:
|
||||||
|
attn_weights = attn_weights + attention_mask
|
||||||
|
|
||||||
|
# upcast attention to fp32
|
||||||
|
attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
|
||||||
|
attn_weights = torch.nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
|
||||||
|
attn_output = torch.matmul(attn_weights, value_states)
|
||||||
|
|
||||||
|
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||||
|
attn_output = attn_output.reshape(bsz, q_len, self.embed_dim)
|
||||||
|
|
||||||
|
attn_output = self.out_proj(attn_output)
|
||||||
|
|
||||||
|
return attn_output, attn_weights
|
||||||
|
|
||||||
|
|
||||||
def patched_repetition_penalty_call(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
|
def patched_repetition_penalty_call(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
|
||||||
if scores.device.type == "xpu":
|
if scores.device.type == "xpu":
|
||||||
import xe_addons
|
import xe_addons
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue