Mistral apply_rotary_pos_emb_no_cache_xpu use rope_theta from config (#11747)
mistral-7B-instruct-v0.2 and mistral-7B-instruct-v0.1 use different rope_theta values (v0.2 uses 1e6, v0.1 uses 1e4). Pass self.config.rope_theta to apply_rotary_pos_emb_no_cache_xpu to avoid an output difference between the two models.
This commit is contained in:
		
							parent
							
								
									044e486480
								
							
						
					
					
						commit
						d8808cc2e3
					
				
					 1 changed files with 10 additions and 5 deletions
				
			
		| 
						 | 
				
			
			@ -318,7 +318,8 @@ def mistral_attention_forward_quantized(
 | 
			
		|||
            query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
 | 
			
		||||
                                                                         key_states,
 | 
			
		||||
                                                                         position_ids,
 | 
			
		||||
                                                                         "mistral")
 | 
			
		||||
                                                                         "mistral",
 | 
			
		||||
                                                                         self.config.rope_theta)
 | 
			
		||||
        else:
 | 
			
		||||
            cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
 | 
			
		||||
            query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
 | 
			
		||||
| 
						 | 
				
			
			@ -535,7 +536,8 @@ def mistral_attention_forward_original(
 | 
			
		|||
            query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
 | 
			
		||||
                                                                         key_states,
 | 
			
		||||
                                                                         position_ids,
 | 
			
		||||
                                                                         "mistral")
 | 
			
		||||
                                                                         "mistral",
 | 
			
		||||
                                                                         self.config.rope_theta)
 | 
			
		||||
        else:
 | 
			
		||||
            cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
 | 
			
		||||
            query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
 | 
			
		||||
| 
						 | 
				
			
			@ -753,7 +755,8 @@ def mistral_attention_forward_4_36_quantized(
 | 
			
		|||
            query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
 | 
			
		||||
                                                                         key_states,
 | 
			
		||||
                                                                         position_ids,
 | 
			
		||||
                                                                         "mistral")
 | 
			
		||||
                                                                         "mistral",
 | 
			
		||||
                                                                         self.config.rope_theta)
 | 
			
		||||
        else:
 | 
			
		||||
            cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
 | 
			
		||||
            query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
 | 
			
		||||
| 
						 | 
				
			
			@ -994,7 +997,8 @@ def mistral_attention_forward_4_36_original(
 | 
			
		|||
            query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
 | 
			
		||||
                                                                         key_states,
 | 
			
		||||
                                                                         position_ids,
 | 
			
		||||
                                                                         "mistral")
 | 
			
		||||
                                                                         "mistral",
 | 
			
		||||
                                                                         self.config.rope_theta)
 | 
			
		||||
        else:
 | 
			
		||||
            cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
 | 
			
		||||
            query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
 | 
			
		||||
| 
						 | 
				
			
			@ -1248,7 +1252,8 @@ def mistral_attention_forward_4_39_original(
 | 
			
		|||
            query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
 | 
			
		||||
                                                                         key_states,
 | 
			
		||||
                                                                         position_ids,
 | 
			
		||||
                                                                         "mistral")
 | 
			
		||||
                                                                         "mistral",
 | 
			
		||||
                                                                         self.config.rope_theta)
 | 
			
		||||
        else:
 | 
			
		||||
            cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
 | 
			
		||||
            query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in a new issue