support new model (#12523)

This commit is contained in:
Yishuo Wang 2024-12-11 13:41:15 +08:00 committed by GitHub
parent 922958c018
commit 77404d2a63
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 15 additions and 1 deletions

View file

@ -1049,6 +1049,10 @@ def _optimize_pre(model, qtype=None):
model.llm.config.model_type = "qwen2"
elif model.config.hidden_size == 4096 and model.config.vocab_size == 128256:
model.llm.config.model_type = "llama"
elif model.config.hidden_size == 1536 and model.config.vocab_size == 73464:
from ipex_llm.transformers.models.minicpm3 import pre_compute_inv_freq
model.llm.apply(pre_compute_inv_freq)
model.llm.config.model_type = "minicpm"
_optimize_pre(model.llm, qtype=qtype)
model.llm.config.model_type = "minicpmv"
elif model.config.model_type == "chatglm":
@ -2137,6 +2141,9 @@ def _optimize_post(model, lightweight_bmm=False):
elif model.config.hidden_size == 4096 and model.config.vocab_size == 128256:
# MiniCPM-V 2.5
model.llm.config.model_type = "llama"
elif model.config.hidden_size == 1536 and model.config.vocab_size == 73464:
# MiniCPM-V ?
model.llm.config.model_type = "minicpm"
_optimize_post(model.llm, lightweight_bmm=lightweight_bmm)
model.llm.config.model_type = "minicpmv"

View file

@ -100,8 +100,15 @@ def minicpm_attention_forward(
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
if should_use_fuse_rope(hidden_states, position_ids, self.training):
if self.rotary_emb.__class__.__name__ == "MiniCPMLongRoPE":
if kv_seq_len > self.rotary_emb.original_max_position_embeddings:
inv_freq = self.rotary_emb.long_inv_freq
else:
inv_freq = self.rotary_emb.short_inv_freq
else:
inv_freq = self.rotary_emb.inv_freq
import xe_addons
xe_addons.rotary_half_inplaced(self.rotary_emb.inv_freq, position_ids,
xe_addons.rotary_half_inplaced(inv_freq, position_ids,
query_states, key_states)
else:
cos, sin = self.rotary_emb(value_states.to(torch.float32), seq_len=kv_seq_len)