fix minicpm-v 2.5 (#11780)
parent ec184af243
commit a184b120c9

2 changed files with 21 additions and 31 deletions
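Two hunks follow: the first reworks the `minicpmv` branch of `_optimize_post`, ipex-llm's post-load model optimizer; the second rewrites `minicpmv_generate_wrapper` in `ipex_llm.transformers.models.minicpmv`. Both run when a MiniCPM-V checkpoint is loaded through ipex-llm's transformers-style API. A load that reaches this code path might look like the sketch below; the checkpoint id and flags are illustrative, not taken from this commit.

# Illustrative only: loading MiniCPM-Llama3-V 2.5 through ipex-llm applies
# _optimize_post (patched in this commit) to the model after loading.
from ipex_llm.transformers import AutoModel

model = AutoModel.from_pretrained(
    "openbmb/MiniCPM-Llama3-V-2_5",  # assumed checkpoint id for MiniCPM-V 2.5
    load_in_4bit=True,               # ipex-llm low-bit optimization path
    trust_remote_code=True,          # MiniCPM-V ships custom modeling code
)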
@@ -1755,19 +1755,29 @@ def _optimize_post(model, lightweight_bmm=False):
                         module.MiniCPMModel,
                         minicpm_model_forward)
     elif model.config.model_type == "minicpmv":
-        if model.config.hidden_size == 3584 and model.config.vocab_size == 151666:
-            model.llm.config.model_type = "qwen2"
-            _optimize_post(model.llm, lightweight_bmm=lightweight_bmm)
-            model.llm.config.model_type = "minicpmv"
         modeling_module_name = model.__class__.__module__
         module = importlib.import_module(modeling_module_name)
         from ipex_llm.transformers.models.minicpmv import minicpmv_generate_wrapper
         minicpmv_generate = minicpmv_generate_wrapper(module.MiniCPMV.generate)
         model.generate = MethodType(minicpmv_generate, model)
 
-        modeling_module_name = model.vpm.__class__.__module__
-        module = importlib.import_module(modeling_module_name)
-        from ipex_llm.transformers.models.minicpmv import siglip_attention_forward
-        convert_forward(model, module.SiglipAttention, siglip_attention_forward)
+        if model.config.hidden_size == 3584 and model.config.vocab_size == 151666:
+            # MiniCPM-V 2.6
+            model.llm.config.model_type = "qwen2"
+            _optimize_post(model.llm, lightweight_bmm=lightweight_bmm)
+            model.llm.config.model_type = "minicpmv"
+        elif model.config.hidden_size == 4096 and model.config.vocab_size == 128256:
+            # MiniCPM-V 2.5
+            pass
+
+        vpm_modeling_module_name = model.vpm.__class__.__module__
+        vpm_module = importlib.import_module(vpm_modeling_module_name)
+        if model.vpm.config.model_type == "siglip":
+            # MiniCPM-V 2.6
+            from ipex_llm.transformers.models.minicpmv import siglip_attention_forward
+            convert_forward(model, vpm_module.SiglipAttention, siglip_attention_forward)
+        elif model.vpm.config.model_type == "idefics2":
+            # MiniCPM-V 2.5
+            pass
 
     return model
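The reworked branch above tells the two checkpoints apart by config values instead of assuming 2.6: MiniCPM-V 2.6 pairs a Qwen2-based LLM (hidden size 3584, vocab 151666) with a SigLIP vision tower, while MiniCPM-V 2.5 pairs a Llama-3-based LLM (hidden size 4096, vocab 128256) with an Idefics2 vision tower and currently needs no extra conversion (hence the `pass` branches). A minimal standalone sketch of that dispatch; `Cfg` and `detect_version` are hypothetical names, not ipex-llm APIs:

# Hypothetical helper mirroring the version dispatch in _optimize_post above.
from dataclasses import dataclass

@dataclass
class Cfg:
    hidden_size: int
    vocab_size: int

def detect_version(cfg: Cfg) -> str:
    if cfg.hidden_size == 3584 and cfg.vocab_size == 151666:
        # Qwen2-based LLM -> MiniCPM-V 2.6
        return "2.6"
    elif cfg.hidden_size == 4096 and cfg.vocab_size == 128256:
        # Llama-3-based LLM -> MiniCPM-V 2.5
        return "2.5"
    return "unknown"

assert detect_version(Cfg(hidden_size=3584, vocab_size=151666)) == "2.6"
assert detect_version(Cfg(hidden_size=4096, vocab_size=128256)) == "2.5"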
@@ -69,32 +69,12 @@ def patched_repetition_penalty_call(self, input_ids: torch.LongTensor, scores: t
 
 def minicpmv_generate_wrapper(origin_generate):
     def generate(
-        self,
-        input_ids=None,
-        pixel_values=None,
-        tgt_sizes=None,
-        image_bound=None,
-        attention_mask=None,
-        tokenizer=None,
-        vision_hidden_states=None,
-        return_vision_hidden_states=False,
-        stream=False,
-        decode_text=False,
+        *inputs,
         **kwargs
     ):
         RepetitionPenaltyLogitsProcessor.__call__ = patched_repetition_penalty_call
         return origin_generate(
-            self=self,
-            input_ids=input_ids,
-            pixel_values=pixel_values,
-            tgt_sizes=tgt_sizes,
-            image_bound=image_bound,
-            attention_mask=attention_mask,
-            tokenizer=tokenizer,
-            vision_hidden_states=vision_hidden_states,
-            return_vision_hidden_states=return_vision_hidden_states,
-            stream=stream,
-            decode_text=decode_text,
-            **kwargs
+            *inputs,
+            **kwargs,
         )
     return generate
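This hunk is the actual 2.5 fix named in the commit title: the old wrapper enumerated MiniCPM-V 2.6's keyword arguments (and passed `self` by keyword), so a version whose `generate` signature differs, such as MiniCPM-V 2.5's, broke on unexpected parameters. Forwarding `*inputs, **kwargs` verbatim makes the wrapper signature-agnostic. A self-contained sketch of this pass-through pattern, with hypothetical class and parameter names; only `MethodType` usage mirrors the diff above:

from types import MethodType

def generate_wrapper(origin_generate):
    # Forward all arguments verbatim; the wrapper no longer needs to know
    # the wrapped signature, so it works across model versions. A real
    # wrapper would also install temporary patches here (e.g. the
    # repetition-penalty patch in the diff above).
    def generate(*inputs, **kwargs):
        return origin_generate(*inputs, **kwargs)
    return generate

class ModelV25:
    def generate(self, msgs=None, sampling=False, **kwargs):
        return ("2.5", msgs, sampling)

class ModelV26:
    def generate(self, input_ids=None, stream=False, **kwargs):
        return ("2.6", input_ids, stream)

# Bind the wrapped unbound method onto instances, as the first hunk does
# with MethodType(minicpmv_generate, model); `self` arrives via *inputs.
m25 = ModelV25()
m25.generate = MethodType(generate_wrapper(ModelV25.generate), m25)
m26 = ModelV26()
m26.generate = MethodType(generate_wrapper(ModelV26.generate), m26)

print(m25.generate(msgs=[{"role": "user"}], sampling=True))  # ('2.5', [{'role': 'user'}], True)
print(m26.generate(input_ids=[1, 2], stream=True))           # ('2.6', [1, 2], True)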