Mistral apply_rotary_pos_emb_no_cache_xpu use rope_theta from config (#11747)

mistral-7B-instruct-v0.2 and mistral-7B-instruct-v0.1 use different rope_theta (0.2 is 1e, 0.1 is 1e5). Pass self.config.rope_theta to apply_rotary_pos_emb_no_cache_xpu to avoid output difference.
This commit is contained in:
Qiyuan Gong 2024-08-09 10:35:51 +08:00 committed by GitHub
parent 044e486480
commit d8808cc2e3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -318,7 +318,8 @@ def mistral_attention_forward_quantized(
query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states, query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
key_states, key_states,
position_ids, position_ids,
"mistral") "mistral",
self.config.rope_theta)
else: else:
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
@ -535,7 +536,8 @@ def mistral_attention_forward_original(
query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states, query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
key_states, key_states,
position_ids, position_ids,
"mistral") "mistral",
self.config.rope_theta)
else: else:
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
@ -753,7 +755,8 @@ def mistral_attention_forward_4_36_quantized(
query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states, query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
key_states, key_states,
position_ids, position_ids,
"mistral") "mistral",
self.config.rope_theta)
else: else:
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
@ -994,7 +997,8 @@ def mistral_attention_forward_4_36_original(
query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states, query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
key_states, key_states,
position_ids, position_ids,
"mistral") "mistral",
self.config.rope_theta)
else: else:
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
@ -1248,7 +1252,8 @@ def mistral_attention_forward_4_39_original(
query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states, query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
key_states, key_states,
position_ids, position_ids,
"mistral") "mistral",
self.config.rope_theta)
else: else:
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, query_states, key_states = apply_rotary_pos_emb(query_states, key_states,