support Megrez-3B-Omni (#12582)
parent 4e7e988f70
commit 3eeb02f1be
2 changed files with 32 additions and 2 deletions
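
For context, a minimal loading sketch of what this commit enables. The model id, the low-bit setting, and the device move below are illustrative assumptions, not part of this change:

# Hypothetical example: load Megrez-3B-Omni with ipex-llm low-bit optimizations.
# The Hugging Face model id and the chosen settings are assumptions for illustration.
from ipex_llm.transformers import AutoModel
from transformers import AutoTokenizer

model_path = "Infinigence/Megrez-3B-Omni"  # assumed model id
model = AutoModel.from_pretrained(model_path,
                                  load_in_low_bit="sym_int4",  # low-bit weight quantization
                                  trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = model.half().to("xpu")  # requires an Intel GPU with the XPU backend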
@@ -1055,6 +1055,12 @@ def _optimize_pre(model, qtype=None):
             model.llm.config.model_type = "minicpm"
         _optimize_pre(model.llm, qtype=qtype)
         model.llm.config.model_type = "minicpmv"
+    elif model.config.model_type == "megrezo":
+        from ipex_llm.transformers.models.minicpmv import merge_qkv
+        model.vision.apply(merge_qkv)
+        model.llm.config.model_type = "llama"
+        _optimize_pre(model.llm, qtype=qtype)
+        model.llm.config.model_type = "megrezo"
     elif model.config.model_type == "chatglm":
         if hasattr(model.config, 'padded_vocab_size') and model.config.padded_vocab_size == 65024:
             # chatglm2 and chatglm3

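The new _optimize_pre branch reuses the existing llama path for Megrez's language tower by temporarily relabeling its config and restoring it afterwards. A self-contained sketch of that relabel-recurse-restore pattern, with stub objects and a stub dispatch table standing in for the real convert.py internals:

# Simplified sketch of the "relabel, recurse, restore" pattern used above.
# The dispatch dict and optimizer functions are illustrative stand-ins.
from types import SimpleNamespace

def optimize_llama_pre(model):
    print("applying llama-style pre-optimizations")

PRE_OPTIMIZERS = {"llama": optimize_llama_pre}

def optimize_pre(model):
    handler = PRE_OPTIMIZERS.get(model.config.model_type)
    if handler is not None:
        handler(model)

def optimize_megrezo_pre(model):
    # The Megrez-3B-Omni language tower is llama-like, so it can reuse the
    # existing llama path by temporarily relabeling its config.
    model.llm.config.model_type = "llama"     # pretend to be llama
    optimize_pre(model.llm)                   # dispatch into the llama branch
    model.llm.config.model_type = "megrezo"   # restore the real type

# Usage with a minimal stub model:
llm = SimpleNamespace(config=SimpleNamespace(model_type="megrezo"))
model = SimpleNamespace(config=SimpleNamespace(model_type="megrezo"), llm=llm)
optimize_megrezo_pre(model)   # prints: applying llama-style pre-optimizations
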
@@ -2202,5 +2208,29 @@ def _optimize_post(model, lightweight_bmm=False):
             convert_forward(model.vpm, vpm_module.Idefics2VisionAttention, siglip_attention_forward)
         minicpmv_chat = minicpmv_chat_wrapper(module.MiniCPMV.chat)
         model.chat = MethodType(minicpmv_chat, model)
+    elif model.config.model_type == "megrezo":
+        modeling_module_name = model.__class__.__module__
+        module = importlib.import_module(modeling_module_name)
+        from ipex_llm.transformers.models.minicpmv import minicpmv_generate_wrapper
+        minicpmv_generate = minicpmv_generate_wrapper(module.MegrezO.generate)
+        model.generate = MethodType(minicpmv_generate, model)
+
+        # vision
+        vpm_modeling_module_name = model.vision.vpm.__class__.__module__
+        vpm_module = importlib.import_module(vpm_modeling_module_name)
+        from ipex_llm.transformers.models.minicpmv import siglip_attention_forward
+        convert_forward(model.vision.vpm, vpm_module.SiglipAttention, siglip_attention_forward)
+
+        # resampler
+        from ipex_llm.transformers.models.minicpmv import _in_projection_packed
+        resampler_module_name = model.vision.resampler.__class__.__module__
+        resampler_module = importlib.import_module(resampler_module_name)
+        resampler_module._in_projection_packed = _in_projection_packed
+
+        # llm
+        model.llm.config.model_type = "llama"
+        model.llm.config.rope_scaling = {"rope_type": "default"}
+        _optimize_post(model.llm, lightweight_bmm=lightweight_bmm)
+        model.llm.config.model_type = "megrezo"

     return model

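The _optimize_post branch relies on two runtime-patching techniques: rebinding a model's generate method through a wrapper with types.MethodType, and swapping the forward of every matching submodule (as convert_forward does for SiglipAttention above). A small self-contained sketch of both, using illustrative stand-in classes rather than the ipex-llm or Megrez modules:

# Sketch of the two runtime-patching techniques used above. GeneratorModel,
# VisionBlock, the wrapper, and this convert_forward are simplified stand-ins.
from types import MethodType

class GeneratorModel:
    def generate(self, prompt):
        return f"generated: {prompt}"

def generate_wrapper(original_generate):
    # Mirrors the wrapper idea: adjust inputs, then call the original method.
    def wrapped(self, prompt):
        return original_generate(self, prompt.strip())
    return wrapped

class VisionBlock:
    def forward(self, x):
        return x  # original attention forward

def optimized_forward(self, x):
    return x * 2  # stand-in for an optimized attention forward

def convert_forward(modules, target_cls, new_forward):
    # Rebind forward on every live instance of target_cls.
    for module in modules:
        if isinstance(module, target_cls):
            module.forward = MethodType(new_forward, module)

model = GeneratorModel()
model.generate = MethodType(generate_wrapper(GeneratorModel.generate), model)
print(model.generate("  hello "))   # generated: hello

blocks = [VisionBlock(), VisionBlock()]
convert_forward(blocks, VisionBlock, optimized_forward)
print(blocks[0].forward(3))         # 6
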
@@ -198,8 +198,8 @@ def prepare_mask(mask, bsz, n_heads, seq_length, kv_length, is_causal, dtype, de
         elif seq_length != kv_length and seq_length <= 32:
             mask = None
         else:
-            mask = torch.zeros([1, 1, 1, padding_kv_length], torch.finfo(dtype).min,
-                               dtype=dtype, device=device)
+            mask = torch.zeros([1, 1, 1, padding_kv_length], dtype=dtype, device=device)
+            mask[:, :, kv_length:padding_kv_length] = torch.finfo(dtype).min
             mask = mask.expand([bsz, n_heads, seq_length, padding_kv_length])
     else:
         if seq_length != kv_length and seq_length <= 32:
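
The prepare_mask fix removes a torch.zeros call that was passed torch.finfo(dtype).min positionally (torch.zeros takes no fill value); the mask is now created as zeros and the padded key/value tail is filled with the dtype's minimum before being expanded. A toy reproduction with made-up sizes, filling the last (key/value) dimension explicitly:

# Toy reproduction of the fixed mask construction, with hypothetical sizes.
import torch

bsz, n_heads, seq_length = 2, 4, 1
kv_length, padding_kv_length = 5, 8
dtype, device = torch.float16, "cpu"

mask = torch.zeros([1, 1, 1, padding_kv_length], dtype=dtype, device=device)
mask[..., kv_length:padding_kv_length] = torch.finfo(dtype).min  # mask out padded positions
mask = mask.expand([bsz, n_heads, seq_length, padding_kv_length])

print(mask[0, 0, 0])
# tensor([0., 0., 0., 0., 0., -65504., -65504., -65504.], dtype=torch.float16)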