fix minicpm-v-2 fp16 (#11819)
parent 6543321f04
commit 750d4ad5dc

2 changed files with 25 additions and 0 deletions
python/llm/src/ipex_llm/transformers/convert.py

@@ -1735,7 +1735,9 @@ def _optimize_post(model, lightweight_bmm=False):
         vpm_module = importlib.import_module(vpm_modeling_module_name)
         if not hasattr(model.vpm, "config"):
             # MiniCPM-V 2
+            from ipex_llm.transformers.models.minicpmv import vision_transformer_attention_forward
             from ipex_llm.transformers.models.minicpmv import minicpmv_get_vision_embedding
+            convert_forward(model.vpm, vpm_module.Attention, vision_transformer_attention_forward)
             model.get_vision_embedding = MethodType(minicpmv_get_vision_embedding, model)
         elif model.vpm.config.model_type == "siglip":
             # MiniCPM-V 2.6
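For context, convert_forward is ipex-llm's module-patching helper: as the call site above shows, it takes a model, a module class, and a replacement forward, then rebinds forward on every matching submodule, which is how the new attention gets installed on each timm Attention block inside model.vpm. A minimal sketch of that mechanism, not the actual ipex-llm implementation (whose details may differ):

from types import MethodType

import torch.nn as nn

def convert_forward(model: nn.Module, target_cls: type, new_forward) -> None:
    # Walk the module tree and rebind `forward` on every instance of
    # `target_cls`, so each matching submodule runs the patched code.
    for module in model.modules():
        if isinstance(module, target_cls):
            module.forward = MethodType(new_forward, module)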
python/llm/src/ipex_llm/transformers/models/minicpmv.py

@@ -61,6 +61,29 @@ def siglip_attention_forward(
     return attn_output, attn_weights
 
 
+# MiniCPM-V-2
+# modified from timm.models.vision_transformer.Attention.forward
+def vision_transformer_attention_forward(self, x: torch.Tensor) -> torch.Tensor:
+    bsz, q_len, hidden_size = x.size()
+
+    qkv = self.qkv(x)
+    qkv = qkv.view(bsz, q_len, self.num_heads * 3, self.head_dim)
+    qkv = qkv.transpose(1, 2)
+    query_states, key_states, value_states = qkv.chunk(3, dim=1)
+
+    attn_weights = torch.matmul(query_states * self.scale, key_states.transpose(2, 3))
+    attn_weights = attention_softmax(attn_weights, self.training)
+    attn_weights = self.attn_drop(attn_weights)
+    attn_output = torch.matmul(attn_weights, value_states)
+
+    attn_output = attn_output.transpose(1, 2).contiguous()
+    attn_output = attn_output.reshape(bsz, q_len, hidden_size)
+
+    attn_output = self.proj(attn_output)
+    attn_output = self.proj_drop(attn_output)
+    return attn_output
+
+
 # MiniCPM-V-2_5
 def minicpmv_chat_wrapper(origin_chat):
     def minicpmv_chat(
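The patched forward keeps timm's attention math but regroups the fused qkv projection into shape (bsz, num_heads * 3, q_len, head_dim), so query, key, and value fall out of a single chunk, and it routes the softmax through ipex-llm's attention_softmax (assumed here to be an fp32-stable softmax, presumably what fixes the fp16 path). Below is a hypothetical sanity check, not part of the commit: it transcribes the new forward next to a minimal timm-style Attention module and confirms both produce the same output; the attention_softmax stub is an assumption so the snippet runs without ipex-llm installed.

import torch
import torch.nn as nn


def attention_softmax(attn_weights: torch.Tensor, training: bool) -> torch.Tensor:
    # Stand-in for ipex-llm's attention_softmax (an assumption): softmax
    # computed in fp32 for stability, cast back to the input dtype.
    return torch.softmax(attn_weights, dim=-1, dtype=torch.float32).to(attn_weights.dtype)


# The new forward, transcribed from the diff above so the check is self-contained.
def vision_transformer_attention_forward(self, x: torch.Tensor) -> torch.Tensor:
    bsz, q_len, hidden_size = x.size()
    qkv = self.qkv(x)
    qkv = qkv.view(bsz, q_len, self.num_heads * 3, self.head_dim)
    qkv = qkv.transpose(1, 2)
    query_states, key_states, value_states = qkv.chunk(3, dim=1)
    attn_weights = torch.matmul(query_states * self.scale, key_states.transpose(2, 3))
    attn_weights = attention_softmax(attn_weights, self.training)
    attn_weights = self.attn_drop(attn_weights)
    attn_output = torch.matmul(attn_weights, value_states)
    attn_output = attn_output.transpose(1, 2).contiguous()
    attn_output = attn_output.reshape(bsz, q_len, hidden_size)
    attn_output = self.proj(attn_output)
    attn_output = self.proj_drop(attn_output)
    return attn_output


class Attention(nn.Module):
    # Minimal timm-style ViT attention: only the attributes the patched
    # forward expects (qkv, num_heads, head_dim, scale, drops, proj).
    def __init__(self, dim: int = 64, num_heads: int = 4):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.qkv = nn.Linear(dim, dim * 3)
        self.attn_drop = nn.Dropout(0.0)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(0.0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Reference math in timm's original qkv layout.
        bsz, q_len, dim = x.shape
        qkv = self.qkv(x).reshape(bsz, q_len, 3, self.num_heads, self.head_dim)
        q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)
        attn = ((q * self.scale) @ k.transpose(-2, -1)).softmax(dim=-1)
        attn = self.attn_drop(attn)
        out = (attn @ v).transpose(1, 2).reshape(bsz, q_len, dim)
        return self.proj_drop(self.proj(out))


module = Attention().eval()
x = torch.randn(2, 16, 64)
with torch.no_grad():
    ref = module(x)
    patched = vision_transformer_attention_forward(module, x)
print(torch.allclose(ref, patched, atol=1e-5))  # expected: True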