fix minicpm-v-2 fp16 (#11819)
This commit is contained in:
parent
6543321f04
commit
750d4ad5dc
2 changed files with 25 additions and 0 deletions
|
|
@ -1735,7 +1735,9 @@ def _optimize_post(model, lightweight_bmm=False):
|
|||
vpm_module = importlib.import_module(vpm_modeling_module_name)
|
||||
if not hasattr(model.vpm, "config"):
|
||||
# MiniCPM-V 2
|
||||
from ipex_llm.transformers.models.minicpmv import vision_transformer_attention_forward
|
||||
from ipex_llm.transformers.models.minicpmv import minicpmv_get_vision_embedding
|
||||
convert_forward(model.vpm, vpm_module.Attention, vision_transformer_attention_forward)
|
||||
model.get_vision_embedding = MethodType(minicpmv_get_vision_embedding, model)
|
||||
elif model.vpm.config.model_type == "siglip":
|
||||
# MiniCPM-V 2.6
|
||||
|
|
|
|||
|
|
@ -61,6 +61,29 @@ def siglip_attention_forward(
|
|||
return attn_output, attn_weights
|
||||
|
||||
|
||||
# MiniCPM-V-2
|
||||
# modified from timm.models.vision_transformer.Attention.forward
|
||||
def vision_transformer_attention_forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
bsz, q_len, hidden_size = x.size()
|
||||
|
||||
qkv = self.qkv(x)
|
||||
qkv = qkv.view(bsz, q_len, self.num_heads * 3, self.head_dim)
|
||||
qkv = qkv.transpose(1, 2)
|
||||
query_states, key_states, value_states = qkv.chunk(3, dim=1)
|
||||
|
||||
attn_weights = torch.matmul(query_states * self.scale, key_states.transpose(2, 3))
|
||||
attn_weights = attention_softmax(attn_weights, self.training)
|
||||
attn_weights = self.attn_drop(attn_weights)
|
||||
attn_output = torch.matmul(attn_weights, value_states)
|
||||
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
attn_output = attn_output.reshape(bsz, q_len, hidden_size)
|
||||
|
||||
attn_output = self.proj(attn_output)
|
||||
attn_output = self.proj_drop(attn_output)
|
||||
return attn_output
|
||||
|
||||
|
||||
# MiniCPM-V-2_5
|
||||
def minicpmv_chat_wrapper(origin_chat):
|
||||
def minicpmv_chat(
|
||||
|
|
|
|||
Loading…
Reference in a new issue