diff --git a/python/llm/src/ipex_llm/transformers/models/llama.py b/python/llm/src/ipex_llm/transformers/models/llama.py
index cd4054b1..a98f02a4 100644
--- a/python/llm/src/ipex_llm/transformers/models/llama.py
+++ b/python/llm/src/ipex_llm/transformers/models/llama.py
@@ -414,10 +414,12 @@ def llama_attention_forward_4_31_quantized(
         kv_seq_len += past_key_value[0].shape[-2]

     if use_fuse_rope:
+        rope_theta = self.rotary_emb.base
         query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
                                                                      key_states,
                                                                      position_ids,
-                                                                     "llama")
+                                                                     "llama",
+                                                                     rope_theta=rope_theta)
     else:
         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
@@ -603,10 +605,12 @@ def llama_attention_forward_4_31_original(
         kv_seq_len += past_key_value[0].shape[-2]

     if use_fuse_rope:
+        rope_theta = self.rotary_emb.base
         query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
                                                                      key_states,
                                                                      position_ids,
-                                                                     "llama")
+                                                                     "llama",
+                                                                     rope_theta=rope_theta)
     else:
         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
@@ -795,10 +799,12 @@ def llama_attention_selective_batching_forward_4_31(
         kv_seq_len += max(kv_pair[0].shape[-2] for kv_pair in past_key_value)

     if use_fuse_rope:
+        rope_theta = self.rotary_emb.base
         query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
                                                                      key_states,
                                                                      position_ids,
-                                                                     "llama")
+                                                                     "llama",
+                                                                     rope_theta=rope_theta)
     else:
         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
@@ -1006,10 +1012,12 @@ def llama_attention_forward_4_36_quantized(
             )
         kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
     if use_fuse_rope:
+        rope_theta = self.rotary_emb.base
         query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
                                                                      key_states,
                                                                      position_ids,
-                                                                     "llama")
+                                                                     "llama",
+                                                                     rope_theta=rope_theta)
     else:
         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
@@ -1266,10 +1274,12 @@ def llama_attention_forward_4_36_original(
         kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)

     if use_fuse_rope:
+        rope_theta = self.rotary_emb.base
         query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
                                                                      key_states,
                                                                      position_ids,
-                                                                     "llama")
+                                                                     "llama",
+                                                                     rope_theta=rope_theta)
     else:
         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
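
Note: the change above threads the model's configured rotary base (self.rotary_emb.base) into the fused XPU RoPE call as rope_theta, rather than letting that path fall back to a default base. The sketch below is only an illustration of why the base matters: it enters RoPE solely through the inverse frequencies, so a model with a non-default rope_theta (for example CodeLlama's 1e6) needs different cos/sin tables than the stock base of 10000. It is a plain PyTorch reimplementation of rotary tables, not the apply_rotary_pos_emb_no_cache_xpu kernel; the function rotary_cos_sin and its arguments are made up for the example.

import torch

def rotary_cos_sin(position_ids, head_dim, rope_theta=10000.0, dtype=torch.float32):
    # Standard RoPE frequency table: inv_freq[i] = 1 / rope_theta**(2i / head_dim)
    inv_freq = 1.0 / (rope_theta ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim))
    # Outer product of positions and frequencies gives one rotation angle per
    # (position, frequency) pair; duplicate to cover the full head dimension.
    freqs = position_ids.to(torch.float32)[:, None] * inv_freq[None, :]
    emb = torch.cat((freqs, freqs), dim=-1)  # shape: [num_positions, head_dim]
    return emb.cos().to(dtype), emb.sin().to(dtype)

# Same positions, different rope_theta -> different rotation tables, which is why
# the fused path has to receive the model's base instead of assuming 10000.
pos = torch.arange(8)
cos_default, _ = rotary_cos_sin(pos, head_dim=128, rope_theta=10000.0)
cos_large_base, _ = rotary_cos_sin(pos, head_dim=128, rope_theta=1e6)
assert not torch.allclose(cos_default, cos_large_base)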