add sdp_causal for mistral 4.36 (#11686)
				
					
				
			* add sdp_causal for mistral * fix * update
This commit is contained in:
		
							parent
							
								
									45c730ff39
								
							
						
					
					
						commit
						736a7ef72e
					
				
					 2 changed files with 31 additions and 7 deletions
				
			
		| 
						 | 
				
			
			@ -1143,7 +1143,16 @@ def llama_attention_forward_4_41_quantized(
 | 
			
		|||
    if len(past_key_value.key_cache) <= self.layer_idx:
 | 
			
		||||
        repeated_key_states = repeat_kv(key_states, self.num_key_value_groups)
 | 
			
		||||
        repeated_value_states = repeat_kv(value_states, self.num_key_value_groups)
 | 
			
		||||
        if should_split_qkv_tensor(query_states, bsz, self.num_heads,
 | 
			
		||||
        if use_cache:
 | 
			
		||||
            cache_kwargs = None
 | 
			
		||||
            key_states, value_states = past_key_value.update(key_states, value_states,
 | 
			
		||||
                                                             self.layer_idx, cache_kwargs)
 | 
			
		||||
        if use_cache and use_sdp_causal(q_len, kv_seq_len, self.head_dim,
 | 
			
		||||
                                        query_states, self.training):
 | 
			
		||||
            import xe_addons
 | 
			
		||||
            attn_output = xe_addons.sdp_fp8_causal(query_states, key_states,
 | 
			
		||||
                                                   value_states, attention_mask)
 | 
			
		||||
        elif should_split_qkv_tensor(query_states, bsz, self.num_heads,
 | 
			
		||||
                                     q_len, kv_seq_len, output_attentions):
 | 
			
		||||
            attn_output, _ = native_sdp_split_qkv_tensor(query_states, repeated_key_states,
 | 
			
		||||
                                                         repeated_value_states,
 | 
			
		||||
| 
						 | 
				
			
			@ -1184,10 +1193,6 @@ def llama_attention_forward_4_41_quantized(
 | 
			
		|||
                attn_weights = nn.functional.softmax(attn_weights, dim=-1,
 | 
			
		||||
                                                     dtype=torch.float32).to(query_states.dtype)
 | 
			
		||||
            attn_output = torch.matmul(attn_weights, repeated_value_states)
 | 
			
		||||
        if use_cache:
 | 
			
		||||
            cache_kwargs = None
 | 
			
		||||
            key_states, value_states = past_key_value.update(key_states, value_states,
 | 
			
		||||
                                                             self.layer_idx, cache_kwargs)
 | 
			
		||||
    else:
 | 
			
		||||
        cache_kwargs = None  # Specific to RoPE models
 | 
			
		||||
        key_states, value_states = past_key_value.update(key_states, value_states,
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -52,7 +52,8 @@ from ipex_llm.transformers.models.utils import apply_rotary_pos_emb, \
 | 
			
		|||
from ipex_llm.transformers.models.utils import is_enough_kv_cache_room_4_31, \
 | 
			
		||||
    is_enough_kv_cache_room_4_36
 | 
			
		||||
from ipex_llm.transformers.low_bit_linear import SYM_INT4, FP8E5, IQ2_XXS
 | 
			
		||||
from ipex_llm.transformers.models.utils import use_flash_attention, use_sdp, use_sdp_fp8
 | 
			
		||||
from ipex_llm.transformers.models.utils import use_flash_attention, use_sdp, use_sdp_fp8, \
 | 
			
		||||
    use_sdp_causal
 | 
			
		||||
from ipex_llm.transformers.models.utils import use_decoding_fast_path
 | 
			
		||||
from ipex_llm.transformers.models.llama import llama_decoding_fast_path_qtype_check
 | 
			
		||||
from ipex_llm.transformers.models.llama import should_use_xetla_mm_qkv
 | 
			
		||||
| 
						 | 
				
			
			@ -599,6 +600,15 @@ def mistral_attention_forward_original(
 | 
			
		|||
        attn_weights = None
 | 
			
		||||
        attn_output = attn_output.transpose(1, 2).contiguous()
 | 
			
		||||
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
 | 
			
		||||
    elif use_sdp_causal(q_len, key_states.shape[2], self.head_dim,
 | 
			
		||||
                        query_states, self.training):
 | 
			
		||||
        import xe_addons
 | 
			
		||||
        attn_output = xe_addons.sdp_causal(query_states, key_states.contiguous(),
 | 
			
		||||
                                           value_states.contiguous(), attention_mask)
 | 
			
		||||
        attn_output = attn_output.view(query_states.shape)
 | 
			
		||||
        attn_weights = None
 | 
			
		||||
        attn_output = attn_output.transpose(1, 2).contiguous()
 | 
			
		||||
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
 | 
			
		||||
    elif use_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
 | 
			
		||||
        # new fp16 sdp doesn't require repeat_kv
 | 
			
		||||
        import xe_addons
 | 
			
		||||
| 
						 | 
				
			
			@ -1052,6 +1062,15 @@ def mistral_attention_forward_4_36_original(
 | 
			
		|||
        attn_weights = None
 | 
			
		||||
        attn_output = attn_output.transpose(1, 2).contiguous()
 | 
			
		||||
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
 | 
			
		||||
    elif use_sdp_causal(q_len, key_states.shape[2], self.head_dim,
 | 
			
		||||
                        query_states, self.training):
 | 
			
		||||
        import xe_addons
 | 
			
		||||
        attn_output = xe_addons.sdp_causal(query_states, key_states.contiguous(),
 | 
			
		||||
                                           value_states.contiguous(), attention_mask)
 | 
			
		||||
        attn_output = attn_output.view(query_states.shape)
 | 
			
		||||
        attn_weights = None
 | 
			
		||||
        attn_output = attn_output.transpose(1, 2).contiguous()
 | 
			
		||||
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
 | 
			
		||||
    elif use_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
 | 
			
		||||
        # new fp16 sdp doesn't require repeat_kv
 | 
			
		||||
        import xe_addons
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in a new issue