diff --git a/python/llm/src/ipex_llm/transformers/models/minicpm.py b/python/llm/src/ipex_llm/transformers/models/minicpm.py
index 23a394e1..1df3d7f5 100644
--- a/python/llm/src/ipex_llm/transformers/models/minicpm.py
+++ b/python/llm/src/ipex_llm/transformers/models/minicpm.py
@@ -241,12 +241,9 @@ def minicpm_attention_forward_original(
         kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
 
     if use_fuse_rope:
-        rope_theta = self.rotary_emb.base
-        query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
-                                                                     key_states,
-                                                                     position_ids,
-                                                                     "llama",
-                                                                     rope_theta=rope_theta)
+        import xe_addons
+        xe_addons.rotary_half_inplaced(self.rotary_emb.inv_freq, position_ids,
+                                       query_states, key_states)
     else:
         if cache_position is not None:
             # for transformers 4.38.0
@@ -313,7 +310,6 @@ def minicpm_attention_forward_original(
                                                      is_causal=True)
         attn_weights = None
     elif not self.training and not hidden_states.requires_grad and \
-            self.layer_idx > 0 and \
            use_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
         import xe_addons
         attn_output = xe_addons.sdp(query_states, key_states, value_states,
@@ -450,12 +446,9 @@ def minicpm_attention_forward_quantized(
            )
         kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
     if use_fuse_rope:
-        rope_theta = self.rotary_emb.base
-        query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
-                                                                     key_states,
-                                                                     position_ids,
-                                                                     "llama",
-                                                                     rope_theta=rope_theta)
+        import xe_addons
+        xe_addons.rotary_half_inplaced(self.rotary_emb.inv_freq, position_ids,
+                                       query_states, key_states)
     else:
         if cache_position is not None:
             # for transformers 4.38.0