support Megrez-3B-Omni (#12582)

Yishuo Wang 2024-12-19 17:23:01 +08:00 committed by GitHub
parent 4e7e988f70
commit 3eeb02f1be
2 changed files with 32 additions and 2 deletions


@@ -1055,6 +1055,12 @@ def _optimize_pre(model, qtype=None):
             model.llm.config.model_type = "minicpm"
         _optimize_pre(model.llm, qtype=qtype)
         model.llm.config.model_type = "minicpmv"
+    elif model.config.model_type == "megrezo":
+        from ipex_llm.transformers.models.minicpmv import merge_qkv
+        model.vision.apply(merge_qkv)
+        model.llm.config.model_type = "llama"
+        _optimize_pre(model.llm, qtype=qtype)
+        model.llm.config.model_type = "megrezo"
     elif model.config.model_type == "chatglm":
         if hasattr(model.config, 'padded_vocab_size') and model.config.padded_vocab_size == 65024:
             # chatglm2 and chatglm3
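
Both megrezo hunks in this file rely on the same trick: Megrez's language model is a LLaMA-style decoder, so its config.model_type is temporarily relabeled to "llama" to route it through the existing LLaMA optimizations, then restored. A minimal sketch of that swap as a reusable helper; model_type_as is hypothetical, not part of ipex_llm, and only illustrates the pattern:

    from contextlib import contextmanager
    from types import SimpleNamespace

    @contextmanager
    def model_type_as(config, temporary_type):
        # Hypothetical helper (not in ipex_llm): temporarily relabel
        # config.model_type so a dispatch-by-type optimizer takes the
        # desired code path, then restore the original label.
        original = config.model_type
        config.model_type = temporary_type
        try:
            yield config
        finally:
            config.model_type = original

    config = SimpleNamespace(model_type="megrezo")
    with model_type_as(config, "llama"):
        assert config.model_type == "llama"   # _optimize_pre(model.llm) would run here
    assert config.model_type == "megrezo"     # label restored afterwards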
@@ -2202,5 +2208,29 @@ def _optimize_post(model, lightweight_bmm=False):
         convert_forward(model.vpm, vpm_module.Idefics2VisionAttention, siglip_attention_forward)
         minicpmv_chat = minicpmv_chat_wrapper(module.MiniCPMV.chat)
         model.chat = MethodType(minicpmv_chat, model)
+    elif model.config.model_type == "megrezo":
+        modeling_module_name = model.__class__.__module__
+        module = importlib.import_module(modeling_module_name)
+        from ipex_llm.transformers.models.minicpmv import minicpmv_generate_wrapper
+        minicpmv_generate = minicpmv_generate_wrapper(module.MegrezO.generate)
+        model.generate = MethodType(minicpmv_generate, model)
+
+        # vision
+        vpm_modeling_module_name = model.vision.vpm.__class__.__module__
+        vpm_module = importlib.import_module(vpm_modeling_module_name)
+        from ipex_llm.transformers.models.minicpmv import siglip_attention_forward
+        convert_forward(model.vision.vpm, vpm_module.SiglipAttention, siglip_attention_forward)
+
+        # resampler
+        from ipex_llm.transformers.models.minicpmv import _in_projection_packed
+        resampler_module_name = model.vision.resampler.__class__.__module__
+        resampler_module = importlib.import_module(resampler_module_name)
+        resampler_module._in_projection_packed = _in_projection_packed
+
+        # llm
+        model.llm.config.model_type = "llama"
+        model.llm.config.rope_scaling = {"rope_type": "default"}
+        _optimize_post(model.llm, lightweight_bmm=lightweight_bmm)
+        model.llm.config.model_type = "megrezo"

     return model
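
For context, a hedged sketch of how these branches get exercised end to end, assuming the usual ipex-llm loading flow in which load_in_4bit drives the qtype passed to _optimize_pre/_optimize_post; the checkpoint id and device setup below are assumptions, not part of this commit:

    from ipex_llm.transformers import AutoModel

    # Loading with load_in_4bit routes through _optimize_pre/_optimize_post,
    # which hit the "megrezo" branches above.
    model = AutoModel.from_pretrained(
        "Infinigence/Megrez-3B-Omni",  # assumption: HF id of the Omni checkpoint
        load_in_4bit=True,
        trust_remote_code=True,        # required for the custom MegrezO class
    )
    model = model.half().to("xpu")     # typical ipex-llm target device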


@@ -198,8 +198,8 @@ def prepare_mask(mask, bsz, n_heads, seq_length, kv_length, is_causal, dtype, device):
         elif seq_length != kv_length and seq_length <= 32:
             mask = None
         else:
-            mask = torch.zeros([1, 1, 1, padding_kv_length], torch.finfo(dtype).min,
-                               dtype=dtype, device=device)
+            mask = torch.zeros([1, 1, 1, padding_kv_length], dtype=dtype, device=device)
+            mask[:, :, :, kv_length:padding_kv_length] = torch.finfo(dtype).min
             mask = mask.expand([bsz, n_heads, seq_length, padding_kv_length])
     else:
         if seq_length != kv_length and seq_length <= 32:
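
The removed call passed torch.finfo(dtype).min as a positional argument to torch.zeros, which accepts no fill value; the fix builds the zeros first and then writes dtype-min into the padded tail of the last dimension. A standalone sketch of the corrected construction, with made-up shapes:

    import torch

    dtype, device = torch.float16, "cpu"
    bsz, n_heads, seq_length = 2, 8, 1
    kv_length, padding_kv_length = 5, 8  # kv cache padded up to a block boundary

    # Corrected construction: zeros for the real kv positions, then
    # dtype-min written into the padding columns of the last dimension.
    mask = torch.zeros([1, 1, 1, padding_kv_length], dtype=dtype, device=device)
    mask[:, :, :, kv_length:padding_kv_length] = torch.finfo(dtype).min
    mask = mask.expand([bsz, n_heads, seq_length, padding_kv_length])

    assert (mask[..., :kv_length] == 0).all()    # real positions stay attendable
    assert (mask[..., kv_length:] < -1e4).all()  # padded positions masked out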