support new model (#12523)
commit 77404d2a63
parent 922958c018

2 changed files with 15 additions and 1 deletion
@@ -1049,6 +1049,10 @@ def _optimize_pre(model, qtype=None):
             model.llm.config.model_type = "qwen2"
         elif model.config.hidden_size == 4096 and model.config.vocab_size == 128256:
             model.llm.config.model_type = "llama"
+        elif model.config.hidden_size == 1536 and model.config.vocab_size == 73464:
+            from ipex_llm.transformers.models.minicpm3 import pre_compute_inv_freq
+            model.llm.apply(pre_compute_inv_freq)
+            model.llm.config.model_type = "minicpm"
         _optimize_pre(model.llm, qtype=qtype)
         model.llm.config.model_type = "minicpmv"
     elif model.config.model_type == "chatglm":
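Note on the `_optimize_pre` hunk above: the new branch routes a MiniCPM-V checkpoint whose language model reports hidden_size 1536 and vocab_size 73464 through the "minicpm" optimization path, after applying `pre_compute_inv_freq` from the MiniCPM3 module to every submodule. The sketch below illustrates what such a precompute hook typically does for a LongRoPE-style rotary embedding; the attribute names (`base`, `dim`, `short_factor`, `long_factor`) are assumptions for illustration and are not taken from this diff.

# Illustrative sketch only: a pre_compute_inv_freq-style hook that caches the
# two inverse-frequency tables a LongRoPE rotary embedding needs. Attribute
# names (base, dim, short_factor, long_factor) are assumed, not from this diff.
import torch

def precompute_inv_freq_sketch(module: torch.nn.Module):
    if module.__class__.__name__ == "MiniCPMLongRoPE":
        inv_freq = 1.0 / (module.base
                          ** (torch.arange(0, module.dim, 2, dtype=torch.float32) / module.dim))
        # One table for positions inside the original context window,
        # one rescaled table for extrapolated (long-context) positions.
        module.short_inv_freq = inv_freq / torch.tensor(module.short_factor, dtype=torch.float32)
        module.long_inv_freq = inv_freq / torch.tensor(module.long_factor, dtype=torch.float32)

# model.llm.apply(...) walks every submodule, so the hook fires once per rotary layer:
# model.llm.apply(precompute_inv_freq_sketch)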
@@ -2137,6 +2141,9 @@ def _optimize_post(model, lightweight_bmm=False):
         elif model.config.hidden_size == 4096 and model.config.vocab_size == 128256:
             # MiniCPM-V 2.5
             model.llm.config.model_type = "llama"
+        elif model.config.hidden_size == 1536 and model.config.vocab_size == 73464:
+            # MiniCPM-V ?
+            model.llm.config.model_type = "minicpm"
         _optimize_post(model.llm, lightweight_bmm=lightweight_bmm)
         model.llm.config.model_type = "minicpmv"
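The `_optimize_post` hunk follows the same pattern as `_optimize_pre`: every MiniCPM-V checkpoint reports model_type "minicpmv", so the inner language model is temporarily relabeled according to its hidden_size/vocab_size, optimized through the matching path, and then tagged back as "minicpmv". A minimal sketch of that pattern, assuming the size-to-type pairs visible in this diff are the only ones of interest:

# Sketch of the relabel -> optimize -> restore pattern used for MiniCPM-V.
# The (hidden_size, vocab_size) -> model_type table only reflects the pairs
# visible in this diff; the real code handles more variants.
def optimize_minicpmv_llm(model, optimize_fn):
    size_to_type = {
        (4096, 128256): "llama",    # MiniCPM-V 2.5
        (1536, 73464): "minicpm",   # variant newly supported by this commit
    }
    key = (model.config.hidden_size, model.config.vocab_size)
    if key in size_to_type:
        model.llm.config.model_type = size_to_type[key]
    optimize_fn(model.llm)                       # run the type-specific optimizations
    model.llm.config.model_type = "minicpmv"     # restore the multimodal tag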
@@ -100,8 +100,15 @@ def minicpm_attention_forward(
         kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
 
     if should_use_fuse_rope(hidden_states, position_ids, self.training):
+        if self.rotary_emb.__class__.__name__ == "MiniCPMLongRoPE":
+            if kv_seq_len > self.rotary_emb.original_max_position_embeddings:
+                inv_freq = self.rotary_emb.long_inv_freq
+            else:
+                inv_freq = self.rotary_emb.short_inv_freq
+        else:
+            inv_freq = self.rotary_emb.inv_freq
         import xe_addons
-        xe_addons.rotary_half_inplaced(self.rotary_emb.inv_freq, position_ids,
+        xe_addons.rotary_half_inplaced(inv_freq, position_ids,
                                        query_states, key_states)
     else:
         cos, sin = self.rotary_emb(value_states.to(torch.float32), seq_len=kv_seq_len)
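The attention change selects which inverse-frequency table feeds the fused rotary kernel: a MiniCPMLongRoPE embedding keeps separate short and long tables, and the long one is used only once the cached key/value length exceeds original_max_position_embeddings; any other rotary class keeps its single inv_freq. Pulled out as a standalone helper, this is a sketch mirroring the diff, not code from the repository:

# Sketch of the inv_freq selection added above, extracted as a helper.
def select_inv_freq(rotary_emb, kv_seq_len: int):
    if rotary_emb.__class__.__name__ == "MiniCPMLongRoPE":
        # LongRoPE precomputes two tables; switch to the "long" one only when
        # the key/value length exceeds the original training context.
        if kv_seq_len > rotary_emb.original_max_position_embeddings:
            return rotary_emb.long_inv_freq
        return rotary_emb.short_inv_freq
    # Plain rotary embeddings keep a single table.
    return rotary_emb.inv_freq

# The chosen table then replaces the unconditional self.rotary_emb.inv_freq:
# xe_addons.rotary_half_inplaced(select_inv_freq(self.rotary_emb, kv_seq_len),
#                                position_ids, query_states, key_states)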