From d8808cc2e37b587048d9605e71ca5711a52b6d62 Mon Sep 17 00:00:00 2001 From: Qiyuan Gong Date: Fri, 9 Aug 2024 10:35:51 +0800 Subject: [PATCH] Mistral apply_rotary_pos_emb_no_cache_xpu use rope_theta from config (#11747) mistral-7B-instruct-v0.2 and mistral-7B-instruct-v0.1 use different rope_theta values (v0.2 uses 1e6, v0.1 uses 1e4). Pass self.config.rope_theta to apply_rotary_pos_emb_no_cache_xpu to avoid output differences. --- .../src/ipex_llm/transformers/models/mistral.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/python/llm/src/ipex_llm/transformers/models/mistral.py b/python/llm/src/ipex_llm/transformers/models/mistral.py index 35d7abae..61e507b2 100644 --- a/python/llm/src/ipex_llm/transformers/models/mistral.py +++ b/python/llm/src/ipex_llm/transformers/models/mistral.py @@ -318,7 +318,8 @@ def mistral_attention_forward_quantized( query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states, key_states, position_ids, - "mistral") + "mistral", + self.config.rope_theta) else: cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, @@ -535,7 +536,8 @@ def mistral_attention_forward_original( query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states, key_states, position_ids, - "mistral") + "mistral", + self.config.rope_theta) else: cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, @@ -753,7 +755,8 @@ def mistral_attention_forward_4_36_quantized( query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states, key_states, position_ids, - "mistral") + "mistral", + self.config.rope_theta) else: cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, @@ -994,7 +997,8 @@ def mistral_attention_forward_4_36_original( query_states, key_states = 
apply_rotary_pos_emb_no_cache_xpu(query_states, key_states, position_ids, - "mistral") + "mistral", + self.config.rope_theta) else: cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, @@ -1248,7 +1252,8 @@ def mistral_attention_forward_4_39_original( query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states, key_states, position_ids, - "mistral") + "mistral", + self.config.rope_theta) else: cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) query_states, key_states = apply_rotary_pos_emb(query_states, key_states,