diff --git a/python/llm/src/ipex_llm/transformers/convert.py b/python/llm/src/ipex_llm/transformers/convert.py
index e3674c6c..05ac94a4 100644
--- a/python/llm/src/ipex_llm/transformers/convert.py
+++ b/python/llm/src/ipex_llm/transformers/convert.py
@@ -1049,6 +1049,10 @@ def _optimize_pre(model, qtype=None):
             model.llm.config.model_type = "qwen2"
         elif model.config.hidden_size == 4096 and model.config.vocab_size == 128256:
             model.llm.config.model_type = "llama"
+        elif model.config.hidden_size == 1536 and model.config.vocab_size == 73464:
+            from ipex_llm.transformers.models.minicpm3 import pre_compute_inv_freq
+            model.llm.apply(pre_compute_inv_freq)
+            model.llm.config.model_type = "minicpm"
         _optimize_pre(model.llm, qtype=qtype)
         model.llm.config.model_type = "minicpmv"
     elif model.config.model_type == "chatglm":
@@ -2137,6 +2141,9 @@ def _optimize_post(model, lightweight_bmm=False):
         elif model.config.hidden_size == 4096 and model.config.vocab_size == 128256:
             # MiniCPM-V 2.5
             model.llm.config.model_type = "llama"
+        elif model.config.hidden_size == 1536 and model.config.vocab_size == 73464:
+            # MiniCPM-V ?
+            model.llm.config.model_type = "minicpm"
         _optimize_post(model.llm, lightweight_bmm=lightweight_bmm)
         model.llm.config.model_type = "minicpmv"
diff --git a/python/llm/src/ipex_llm/transformers/models/minicpm.py b/python/llm/src/ipex_llm/transformers/models/minicpm.py
index 6e2cab0f..30a26277 100644
--- a/python/llm/src/ipex_llm/transformers/models/minicpm.py
+++ b/python/llm/src/ipex_llm/transformers/models/minicpm.py
@@ -100,8 +100,15 @@ def minicpm_attention_forward(
         kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
 
     if should_use_fuse_rope(hidden_states, position_ids, self.training):
+        if self.rotary_emb.__class__.__name__ == "MiniCPMLongRoPE":
+            if kv_seq_len > self.rotary_emb.original_max_position_embeddings:
+                inv_freq = self.rotary_emb.long_inv_freq
+            else:
+                inv_freq = self.rotary_emb.short_inv_freq
+        else:
+            inv_freq = self.rotary_emb.inv_freq
         import xe_addons
-        xe_addons.rotary_half_inplaced(self.rotary_emb.inv_freq, position_ids,
+        xe_addons.rotary_half_inplaced(inv_freq, position_ids,
                                        query_states, key_states)
     else:
         cos, sin = self.rotary_emb(value_states.to(torch.float32), seq_len=kv_seq_len)
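For context on what the new branch relies on: the `_optimize_pre` hunk imports `pre_compute_inv_freq` from `ipex_llm.transformers.models.minicpm3` and applies it over the language model, which is expected to attach `short_inv_freq`/`long_inv_freq` buffers to the `MiniCPMLongRoPE` rotary-embedding module; the attention hunk then picks one of them depending on whether `kv_seq_len` exceeds `original_max_position_embeddings`. The sketch below illustrates that idea only; the attribute names it assumes on the rotary module (`dim`, `base`, `short_factor`, `long_factor`) and the helper `select_inv_freq` are illustrative assumptions, not the actual ipex_llm implementation.

```python
import torch

def pre_compute_inv_freq_sketch(module: torch.nn.Module):
    # Minimal sketch: precompute the short/long inverse-frequency buffers that the
    # fused-RoPE path reads from self.rotary_emb. LongRoPE-style embeddings rescale
    # the base RoPE frequencies with per-dimension "short" and "long" factors.
    # Attribute names here are assumptions, not the exact ipex_llm layout.
    if module.__class__.__name__ == "MiniCPMLongRoPE":
        exponent = torch.arange(0, module.dim, 2, dtype=torch.float32) / module.dim
        base_inv_freq = 1.0 / (module.base ** exponent)
        module.short_inv_freq = base_inv_freq / torch.tensor(module.short_factor,
                                                             dtype=torch.float32)
        module.long_inv_freq = base_inv_freq / torch.tensor(module.long_factor,
                                                            dtype=torch.float32)

def select_inv_freq(rotary_emb, kv_seq_len):
    # Hypothetical helper mirroring the branch added in minicpm_attention_forward:
    # once the sequence grows past the original context window, switch to the
    # long-context frequencies; otherwise keep the short ones.
    if rotary_emb.__class__.__name__ == "MiniCPMLongRoPE":
        if kv_seq_len > rotary_emb.original_max_position_embeddings:
            return rotary_emb.long_inv_freq
        return rotary_emb.short_inv_freq
    return rotary_emb.inv_freq
```

Precomputing both buffers once (via `model.llm.apply(...)`) keeps the per-token fused-RoPE call a cheap lookup instead of recomputing scaled frequencies on every forward pass.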