support Megrez-3B-Omni (#12582)
This commit is contained in:
parent
4e7e988f70
commit
3eeb02f1be
2 changed files with 32 additions and 2 deletions
|
|
@ -1055,6 +1055,12 @@ def _optimize_pre(model, qtype=None):
|
|||
model.llm.config.model_type = "minicpm"
|
||||
_optimize_pre(model.llm, qtype=qtype)
|
||||
model.llm.config.model_type = "minicpmv"
|
||||
elif model.config.model_type == "megrezo":
|
||||
from ipex_llm.transformers.models.minicpmv import merge_qkv
|
||||
model.vision.apply(merge_qkv)
|
||||
model.llm.config.model_type = "llama"
|
||||
_optimize_pre(model.llm, qtype=qtype)
|
||||
model.llm.config.model_type = "megrezo"
|
||||
elif model.config.model_type == "chatglm":
|
||||
if hasattr(model.config, 'padded_vocab_size') and model.config.padded_vocab_size == 65024:
|
||||
# chatglm2 and chatglm3
|
||||
|
|
@ -2202,5 +2208,29 @@ def _optimize_post(model, lightweight_bmm=False):
|
|||
convert_forward(model.vpm, vpm_module.Idefics2VisionAttention, siglip_attention_forward)
|
||||
minicpmv_chat = minicpmv_chat_wrapper(module.MiniCPMV.chat)
|
||||
model.chat = MethodType(minicpmv_chat, model)
|
||||
elif model.config.model_type == "megrezo":
|
||||
modeling_module_name = model.__class__.__module__
|
||||
module = importlib.import_module(modeling_module_name)
|
||||
from ipex_llm.transformers.models.minicpmv import minicpmv_generate_wrapper
|
||||
minicpmv_generate = minicpmv_generate_wrapper(module.MegrezO.generate)
|
||||
model.generate = MethodType(minicpmv_generate, model)
|
||||
|
||||
# vision
|
||||
vpm_modeling_module_name = model.vision.vpm.__class__.__module__
|
||||
vpm_module = importlib.import_module(vpm_modeling_module_name)
|
||||
from ipex_llm.transformers.models.minicpmv import siglip_attention_forward
|
||||
convert_forward(model.vision.vpm, vpm_module.SiglipAttention, siglip_attention_forward)
|
||||
|
||||
# resampler
|
||||
from ipex_llm.transformers.models.minicpmv import _in_projection_packed
|
||||
resampler_module_name = model.vision.resampler.__class__.__module__
|
||||
resampler_module = importlib.import_module(resampler_module_name)
|
||||
resampler_module._in_projection_packed = _in_projection_packed
|
||||
|
||||
# llm
|
||||
model.llm.config.model_type = "llama"
|
||||
model.llm.config.rope_scaling = {"rope_type": "default"}
|
||||
_optimize_post(model.llm, lightweight_bmm=lightweight_bmm)
|
||||
model.llm.config.model_type = "megrezo"
|
||||
|
||||
return model
|
||||
|
|
|
|||
|
|
@ -198,8 +198,8 @@ def prepare_mask(mask, bsz, n_heads, seq_length, kv_length, is_causal, dtype, de
|
|||
elif seq_length != kv_length and seq_length <= 32:
|
||||
mask = None
|
||||
else:
|
||||
mask = torch.zeros([1, 1, 1, padding_kv_length], torch.finfo(dtype).min,
|
||||
dtype=dtype, device=device)
|
||||
mask = torch.zeros([1, 1, 1, padding_kv_length], dtype=dtype, device=device)
|
||||
mask[:, :, kv_length:padding_kv_length] = torch.finfo(dtype).min
|
||||
mask = mask.expand([bsz, n_heads, seq_length, padding_kv_length])
|
||||
else:
|
||||
if seq_length != kv_length and seq_length <= 32:
|
||||
|
|
|
|||
Loading…
Reference in a new issue