Fix and optimize MiniCPM-V 2 (#11799)

parent d8d887edd2
commit 9a93808fc5
2 changed files with 45 additions and 12 deletions
@@ -1726,6 +1726,11 @@ def _optimize_post(model, lightweight_bmm=False):
         minicpmv_generate = minicpmv_generate_wrapper(module.MiniCPMV.generate)
         model.generate = MethodType(minicpmv_generate, model)

+        if model.config.hidden_size == 2304 and model.config.vocab_size == 122753:
+            # MiniCPM-V 2
+            model.llm.config.model_type = "minicpm"
+            _optimize_post(model.llm, lightweight_bmm=lightweight_bmm)
+            model.llm.config.model_type = "minicpmv"
         if model.config.hidden_size == 3584 and model.config.vocab_size == 151666:
             # MiniCPM-V 2.6
             model.llm.config.model_type = "qwen2"
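The added branch detects MiniCPM-V 2 by its config values (hidden_size 2304, vocab_size 122753), temporarily relabels the wrapped language model as "minicpm" so the recursive _optimize_post call applies the existing MiniCPM optimizations, then restores the "minicpmv" label. A minimal sketch of that temporary-relabel pattern, using hypothetical stand-in objects rather than the actual ipex-llm code path:

from types import SimpleNamespace

def optimize_llm(llm):
    # Stand-in for the recursive _optimize_post(model.llm, ...) call: it only
    # recognizes plain "minicpm" language models.
    assert llm.config.model_type == "minicpm"
    llm.optimized = True

llm = SimpleNamespace(config=SimpleNamespace(model_type="minicpmv"), optimized=False)

original_type = llm.config.model_type
llm.config.model_type = "minicpm"      # relabel so the LLM-only optimizer matches
optimize_llm(llm)
llm.config.model_type = original_type  # restore the multimodal label afterwards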
@@ -1739,7 +1744,11 @@ def _optimize_post(model, lightweight_bmm=False):

         vpm_modeling_module_name = model.vpm.__class__.__module__
         vpm_module = importlib.import_module(vpm_modeling_module_name)
-        if model.vpm.config.model_type == "siglip":
+        if not hasattr(model.vpm, "config"):
+            # MiniCPM-V 2
+            from ipex_llm.transformers.models.minicpmv import minicpmv_get_vision_embedding
+            model.get_vision_embedding = MethodType(minicpmv_get_vision_embedding, model)
+        elif model.vpm.config.model_type == "siglip":
             # MiniCPM-V 2.6
             from ipex_llm.transformers.models.minicpmv import siglip_attention_forward
             convert_forward(model.vpm, vpm_module.SiglipAttention, siglip_attention_forward)
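For MiniCPM-V 2 the visual backbone (model.vpm) has no config attribute, which is what the new hasattr check keys on; instead of converting a SigLIP attention class, the commit rebinds get_vision_embedding on the model instance with types.MethodType. A self-contained sketch of that rebinding pattern, using a hypothetical class and replacement function:

from types import MethodType

class Model:
    def get_vision_embedding(self, pixel_values):
        return "original path"

def optimized_get_vision_embedding(self, pixel_values):
    # The replacement receives the same `self`, so on the real model it can
    # still reach self.vpm, self.resampler, self.config, and so on.
    return "optimized path"

m = Model()
m.get_vision_embedding = MethodType(optimized_get_vision_embedding, m)
print(m.get_vision_embedding(None))  # -> optimized path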
@@ -15,6 +15,7 @@
 #

+import math
 import torch
 from typing import Optional
 from ipex_llm.transformers.models.common import merge_qkv_base
@@ -22,11 +23,13 @@ from transformers import AutoProcessor
 from transformers.generation.logits_process import RepetitionPenaltyLogitsProcessor


+# MiniCPM-V-2_5 and MiniCPM-V-2_6
 def merge_qkv(module: torch.nn.Module):
     merge_qkv_base(module, "SiglipAttention")
     merge_qkv_base(module, "Idefics2VisionAttention")


+# MiniCPM-V-2_5 and MiniCPM-V-2_6
 def siglip_attention_forward(
     self,
     hidden_states: torch.Tensor,
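The new comments mark merge_qkv and siglip_attention_forward as shared by MiniCPM-V-2_5 and MiniCPM-V-2_6; merge_qkv_base covers both SiglipAttention and Idefics2VisionAttention. Below is a minimal sketch of the general QKV-merging technique, assuming an attention module that exposes separate q_proj/k_proj/v_proj linear layers; it illustrates the idea only and is not the library's actual implementation:

import torch
from torch import nn

def merge_qkv_sketch(attn: nn.Module):
    # Assumes attn has q_proj / k_proj / v_proj nn.Linear layers.
    q, k, v = attn.q_proj, attn.k_proj, attn.v_proj
    qkv = nn.Linear(q.in_features,
                    q.out_features + k.out_features + v.out_features,
                    bias=q.bias is not None)
    qkv.weight.data = torch.cat([q.weight.data, k.weight.data, v.weight.data], dim=0)
    if q.bias is not None:
        qkv.bias.data = torch.cat([q.bias.data, k.bias.data, v.bias.data], dim=0)
    attn.qkv_proj = qkv  # one matmul instead of three in the forward pass
    del attn.q_proj, attn.k_proj, attn.v_proj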
@@ -58,17 +61,7 @@ def siglip_attention_forward(
     return attn_output, attn_weights


-def patched_repetition_penalty_call(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
-    if scores.device.type == "xpu":
-        import xe_addons
-        xe_addons.repetition_penalty_logits_process_inplaced(scores, input_ids, self.penalty)
-    else:
-        score = torch.gather(scores, 1, input_ids)
-        score = torch.where(score < 0, score * self.penalty, score / self.penalty)
-        scores.scatter_(1, input_ids, score)
-    return scores
-
-
+# MiniCPM-V-2_5
 def minicpmv_chat_wrapper(origin_chat):
     def minicpmv_chat(
         self,
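This hunk only relocates patched_repetition_penalty_call (its body reappears unchanged in the last hunk below) and tags minicpmv_chat_wrapper as MiniCPM-V-2_5 specific. The wrapper follows the same closure pattern as minicpmv_generate_wrapper in the first hunk: wrap the original method, adjust its inputs, delegate, and rebind the result with MethodType. A generic sketch of that pattern; the pre/post steps are placeholders, not the repo's actual adjustments:

def example_wrapper(origin_fn):
    def wrapped(self, *args, **kwargs):
        # ...normalize arguments (dtypes, devices, processor kwargs) here...
        result = origin_fn(self, *args, **kwargs)
        # ...post-process the result here...
        return result
    return wrapped

# Applied the same way as minicpmv_generate in the first hunk:
#     model.generate = MethodType(example_wrapper(module.MiniCPMV.generate), model)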
@@ -106,6 +99,37 @@ def minicpmv_chat_wrapper(origin_chat):
     return minicpmv_chat


+# MiniCPM-V-2
+def minicpmv_get_vision_embedding(self, pixel_values):
+    res = []
+    dtype = self.dtype
+
+    def process_each_pixel(pixel_value, dtype, config, vpm, resampler):
+        H, W = pixel_value.shape[-2:]
+        target_size = (math.ceil(H / config.patch_size), math.ceil(W / config.patch_size))
+        vision_embedding = self.vpm_forward_features(pixel_value.unsqueeze(0).type(dtype))
+
+        if hasattr(vpm, 'num_prefix_tokens') and vpm.num_prefix_tokens > 0:
+            vision_embedding = vision_embedding[:, vpm.num_prefix_tokens:]
+        return resampler(vision_embedding, target_size)
+
+    for pixel_value in pixel_values:
+        result = process_each_pixel(pixel_value, dtype, self.config, self.vpm, self.resampler)
+        res.append(result)
+    return torch.vstack(res)
+
+
+def patched_repetition_penalty_call(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
+    if scores.device.type == "xpu":
+        import xe_addons
+        xe_addons.repetition_penalty_logits_process_inplaced(scores, input_ids, self.penalty)
+    else:
+        score = torch.gather(scores, 1, input_ids)
+        score = torch.where(score < 0, score * self.penalty, score / self.penalty)
+        scores.scatter_(1, input_ids, score)
+    return scores
+
+
 def minicpmv_generate_wrapper(origin_generate):
     def generate(
         *inputs,
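The relocated patched_repetition_penalty_call keeps its two branches: an in-place XPU kernel via xe_addons, and the standard gather/where/scatter fallback otherwise. A small worked example of the fallback branch on CPU, with made-up logits and a penalty of 1.5 (no xe_addons required):

import torch

penalty = 1.5
scores = torch.tensor([[2.0, -1.0, 0.5]])
input_ids = torch.tensor([[0, 1]])          # tokens 0 and 1 were already generated

score = torch.gather(scores, 1, input_ids)  # [[2.0, -1.0]]
# Positive logits are divided by the penalty, negative ones multiplied,
# so previously generated tokens always become less likely.
score = torch.where(score < 0, score * penalty, score / penalty)
scores.scatter_(1, input_ids, score)
print(scores)                               # tensor([[ 1.3333, -1.5000,  0.5000]])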