use new sdp and fp32 sdp (#11007)

parent 8010af700f
commit 170e3d65e0

11 changed files with 62 additions and 62 deletions
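
Every hunk below makes the same substitution in the decode path of a model's attention forward: the old ESIMD fp16 gate (use_esimd_sdp / use_new_esimd_sdp_fp16) and its kernel call (linear_fp16_esimd.sdp_forward or linear_q4_0.sdp_fp16) are replaced by the single use_sdp gate and a linear_q4_0.sdp call that takes the attention mask and, per the reworked check in utils at the end of this diff, accepts fp32 as well as fp16 states (hence "fp32 sdp" in the title). A minimal sketch of the new branch follows; the wrapper function and its name are illustrative, not part of the change:

# Illustrative sketch only (not in this PR): the decode-time branch each hunk installs.
from ipex_llm.transformers.models.utils import use_sdp


def sdp_decode_branch(query_states, key_states, value_states, attention_mask, head_dim):
    q_len = query_states.shape[2]       # current query length
    kv_len = key_states.shape[2]        # cached key/value length
    if use_sdp(q_len, kv_len, head_dim, query_states):
        import linear_q4_0              # IPEX-LLM native kernels (XPU build)
        # unified SDP kernel: mask-aware and accepts fp16 or fp32 inputs
        attn_output = linear_q4_0.sdp(query_states, key_states, value_states, attention_mask)
        return attn_output.view(query_states.shape), None   # attn_weights are not computed
    return None, None                   # caller falls back to the regular attention path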
				
			
@@ -27,7 +27,7 @@ from torch import nn
 import torch.nn.functional as F
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from ipex_llm.utils.common import invalidInputError
-from ipex_llm.transformers.models.utils import use_flash_attention, use_esimd_sdp
+from ipex_llm.transformers.models.utils import use_flash_attention, use_sdp
 from ipex_llm.transformers.models.utils import init_kv_cache, extend_kv_cache, \
     append_kv_cache, is_enough_kv_cache_room_4_31
 from ipex_llm.transformers.models.utils import init_fp8_kv_cache, append_fp8_kv_cache, \
@@ -276,11 +276,9 @@ def baichuan_attention_forward_7b_origin(
                                                      is_causal=True)
         attn_weights = None
     elif not self.training and not hidden_states.requires_grad and \
-            use_esimd_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
-        import linear_fp16_esimd
-        attn_output = linear_fp16_esimd.sdp_forward(query_states,
-                                                    key_states,
-                                                    value_states)
+            use_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
+        import linear_q4_0
+        attn_output = linear_q4_0.sdp(query_states, key_states, value_states, attention_mask)
         attn_output = attn_output.view(query_states.shape)
         attn_weights = None
     else:

@@ -50,7 +50,7 @@ from ipex_llm.transformers.models.utils import extend_kv_cache, append_kv_cache
 from transformers.models.cohere.modeling_cohere import apply_rotary_pos_emb
 from ipex_llm.transformers.models.utils import is_enough_kv_cache_room_4_36
 from ipex_llm.transformers.models.utils import use_decoding_fast_path
-from ipex_llm.transformers.models.utils import use_flash_attention, use_esimd_sdp
+from ipex_llm.transformers.models.utils import use_flash_attention, use_sdp
 from transformers.models.cohere.modeling_cohere import apply_rotary_pos_emb
 from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache
 from ipex_llm.transformers.kv import DynamicFp8Cache
@@ -420,9 +420,13 @@ def cohere_attention_forward_origin(
                                                      is_causal=True)
         attn_weights = None
     elif not self.training and not hidden_states.requires_grad and \
-            use_esimd_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
+            use_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
         import linear_q4_0
-        attn_output = linear_q4_0.sdp_fp16(query_states, key_states, value_states, attention_mask)
+        if attention_mask is not None:
+            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+        else:
+            causal_mask = None
+        attn_output = linear_q4_0.sdp(query_states, key_states, value_states, causal_mask)
         attn_output = attn_output.view(query_states.shape)
         attn_weights = None
     else:
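
The Cohere path gets extra mask handling: the 4-D mask supplied by transformers may cover more positions than the cached keys, so it is sliced to key_states.shape[-2] before being handed to linear_q4_0.sdp, and None is forwarded when there is no mask. The same logic pulled out as a stand-alone sketch (the helper name is hypothetical):

# Hypothetical helper restating the mask handling added in the hunk above.
def slice_causal_mask(attention_mask, key_states):
    if attention_mask is None:
        return None
    # trim the mask's key dimension to the current kv-cache length
    return attention_mask[:, :, :, : key_states.shape[-2]]
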
@@ -46,8 +46,7 @@ from ipex_llm.transformers.models.utils import init_fp8_kv_cache, append_fp8_kv_
 from ipex_llm.transformers.models.utils import is_enough_kv_cache_room_4_31, \
     apply_rotary_pos_emb, is_enough_kv_cache_room_4_36
 from ipex_llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
-from ipex_llm.transformers.models.utils import use_flash_attention, use_new_esimd_sdp_fp16, \
-    use_sdp_fp8
+from ipex_llm.transformers.models.utils import use_flash_attention, use_sdp, use_sdp_fp8
 from ipex_llm.transformers.models.utils import mlp_fusion_check, fp16_fusion_check
 from ipex_llm.transformers.models.utils import use_decoding_fast_path
 from transformers.modeling_outputs import BaseModelOutputWithPast
@@ -673,9 +672,9 @@ def llama_attention_forward_4_31_original(
                                                      is_causal=True)
         attn_weights = None
     elif not self.training and not hidden_states.requires_grad and \
-            use_new_esimd_sdp_fp16(q_len, key_states.shape[2], self.head_dim, query_states):
+            use_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
         import linear_q4_0
-        attn_output = linear_q4_0.sdp_fp16(query_states, key_states, value_states, attention_mask)
+        attn_output = linear_q4_0.sdp(query_states, key_states, value_states, attention_mask)
         attn_output = attn_output.view(query_states.shape)
         attn_weights = None
     else:
@@ -1348,9 +1347,9 @@ def llama_attention_forward_4_36_original(
                                                      is_causal=True)
         attn_weights = None
     elif not self.training and not hidden_states.requires_grad and \
-            use_new_esimd_sdp_fp16(q_len, key_states.shape[2], self.head_dim, query_states):
+            use_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
         import linear_q4_0
-        attn_output = linear_q4_0.sdp_fp16(query_states, key_states, value_states, attention_mask)
+        attn_output = linear_q4_0.sdp(query_states, key_states, value_states, attention_mask)
         attn_output = attn_output.view(query_states.shape)
         attn_weights = None
     else:

@@ -52,8 +52,7 @@ from ipex_llm.transformers.models.utils import apply_rotary_pos_emb, \
 from ipex_llm.transformers.models.utils import is_enough_kv_cache_room_4_31, \
     is_enough_kv_cache_room_4_36
 from ipex_llm.transformers.low_bit_linear import SYM_INT4, FP8E5, IQ2_XXS
-from ipex_llm.transformers.models.utils import use_flash_attention, use_new_esimd_sdp_fp16, \
-    use_sdp_fp8
+from ipex_llm.transformers.models.utils import use_flash_attention, use_sdp, use_sdp_fp8
 from ipex_llm.transformers.models.utils import use_decoding_fast_path
 from ipex_llm.transformers.models.llama import llama_decoding_fast_path_qtype_check
 from ipex_llm.transformers.models.llama import should_use_xetla_mm_qkv
@@ -591,10 +590,10 @@ def mistral_attention_forward_original(
         attn_weights = None
         attn_output = attn_output.transpose(1, 2).contiguous()
         attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
-    elif use_new_esimd_sdp_fp16(q_len, key_states.shape[2], self.head_dim, query_states):
+    elif use_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
         # new fp16 sdp doesn't require repeat_kv
         import linear_q4_0
-        attn_output = linear_q4_0.sdp_fp16(query_states, key_states, value_states, attention_mask)
+        attn_output = linear_q4_0.sdp(query_states, key_states, value_states, attention_mask)
         attn_output = attn_output.view(query_states.shape)
         attn_weights = None
         attn_output = attn_output.transpose(1, 2).contiguous()
@@ -1032,10 +1031,10 @@ def mistral_attention_forward_4_36_original(
         attn_weights = None
         attn_output = attn_output.transpose(1, 2).contiguous()
         attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
-    elif use_new_esimd_sdp_fp16(q_len, key_states.shape[2], self.head_dim, query_states):
+    elif use_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
         # new fp16 sdp doesn't require repeat_kv
         import linear_q4_0
-        attn_output = linear_q4_0.sdp_fp16(query_states, key_states, value_states, attention_mask)
+        attn_output = linear_q4_0.sdp(query_states, key_states, value_states, attention_mask)
         attn_output = attn_output.view(query_states.shape)
         attn_weights = None
         attn_output = attn_output.transpose(1, 2).contiguous()

@@ -55,7 +55,7 @@ from ipex_llm.transformers.models.utils import apply_rotary_pos_emb,\
     apply_rotary_pos_emb_cache_freq_xpu, is_enough_kv_cache_room_4_36
 from ipex_llm.transformers.models.mistral import should_use_fuse_rope
 from ipex_llm.transformers.models.utils import use_decoding_fast_path
-from ipex_llm.transformers.models.utils import use_flash_attention, use_esimd_sdp
+from ipex_llm.transformers.models.utils import use_flash_attention, use_sdp
 from ipex_llm.transformers.models.utils import mlp_fusion_check, SILU
 from ipex_llm.transformers.low_bit_linear import IQ2_XXS
 
@@ -332,12 +332,9 @@ def mixtral_attention_forward(
                                                      value_states,
                                                      is_causal=True)
         attn_weights = None
-    elif use_esimd_sdp(query_states.shape[2], key_states.shape[2],
-                       self.head_dim, query_states):
-        import linear_fp16_esimd
-        attn_output = linear_fp16_esimd.sdp_forward(query_states,
-                                                    key_states,
-                                                    value_states)
+    elif use_sdp(query_states.shape[2], key_states.shape[2], self.head_dim, query_states):
+        import linear_q4_0
+        attn_output = linear_q4_0.sdp(query_states, key_states, value_states, attention_mask)
         attn_output = attn_output.view(query_states.shape)
         attn_weights = None
     else:

@@ -40,9 +40,8 @@ from ipex_llm.transformers.models.utils import (
     apply_rotary_pos_emb_cache_freq_xpu
 )
 from ipex_llm.transformers.models.utils import mlp_fusion_check, SILU
-from ipex_llm.transformers.models.utils import use_new_esimd_sdp_fp16, use_quantize_kv_cache
+from ipex_llm.transformers.models.utils import use_sdp, use_sdp_causal, use_quantize_kv_cache
 from ipex_llm.transformers.models.utils import use_sdp_fp8, restore_fp8_kv_cache
-from ipex_llm.transformers.models.utils import use_sdp_causal
 from ipex_llm.transformers.kv import DynamicNormalCache, DynamicFp8Cache
 
 from typing import Optional, Tuple, List
@@ -142,9 +141,9 @@ def attention_forward(
         import linear_q4_0
         attn_output = linear_q4_0.sdp_fp8(query_states, key_states, value_states, attention_mask)
     elif (isinstance(past_key_value, DynamicNormalCache) and
-            use_new_esimd_sdp_fp16(q_len, kv_seq_len, self.head_dim, query_states)):
+            use_sdp(q_len, kv_seq_len, self.head_dim, query_states)):
         import linear_q4_0
-        attn_output = linear_q4_0.sdp_fp16(query_states, key_states, value_states, attention_mask)
+        attn_output = linear_q4_0.sdp(query_states, key_states, value_states, attention_mask)
     else:
         if isinstance(past_key_value, DynamicFp8Cache):
             key_states, value_states = restore_fp8_kv_cache(key_states, value_states,

@@ -42,8 +42,7 @@ from ipex_llm.transformers.models.utils import init_fp8_kv_cache, append_fp8_kv_
 from ipex_llm.transformers.models.utils import rotate_half, SILU
 from ipex_llm.transformers.models.utils import mlp_fusion_check
 from ipex_llm.transformers.models.utils import apply_rotary_pos_emb_cache_freq_xpu
-from ipex_llm.transformers.models.utils import use_flash_attention, use_new_esimd_sdp_fp16, \
-    use_sdp_fp8
+from ipex_llm.transformers.models.utils import use_flash_attention, use_sdp, use_sdp_fp8
 from ipex_llm.transformers.models.utils import use_decoding_fast_path
 from ipex_llm.utils.common import invalidInputError, invalidOperationError
 from ipex_llm.ggml.quantize import ggml_tensor_qtype
@@ -291,9 +290,9 @@ def qwen_attention_forward_original(
         attn_output = attn_output.transpose(1, 2)
         attn_weights = None
     elif not self.training and not hidden_states.requires_grad and \
-            use_new_esimd_sdp_fp16(q_len, key.shape[2], self.head_dim, query):
+            use_sdp(q_len, key.shape[2], self.head_dim, query):
         import linear_q4_0
-        attn_output = linear_q4_0.sdp_fp16(query, key, value, attention_mask)
+        attn_output = linear_q4_0.sdp(query, key, value, attention_mask)
         attn_output = attn_output.view(query.shape)
         attn_output = attn_output.transpose(1, 2)
         attn_weight = None

@@ -52,7 +52,7 @@ from ipex_llm.transformers.models.utils import is_enough_kv_cache_room_4_36
 from ipex_llm.transformers.models.utils import apply_rotary_pos_emb_cache_freq_xpu
 from ipex_llm.transformers.kv import DynamicFp8Cache
 from ipex_llm.utils.common import invalidInputError
-from ipex_llm.transformers.models.utils import use_flash_attention, use_esimd_sdp
+from ipex_llm.transformers.models.utils import use_flash_attention, use_sdp
 from transformers.models.qwen2.modeling_qwen2 import Qwen2Model, apply_rotary_pos_emb
 from transformers.models.qwen2.modeling_qwen2 import _prepare_4d_causal_attention_mask_for_sdpa
 from transformers.models.qwen2.modeling_qwen2 import _prepare_4d_causal_attention_mask
@@ -565,11 +565,9 @@ def qwen2_attention_forward_origin(
                                                      is_causal=True)
         attn_weights = None
     elif not self.training and not hidden_states.requires_grad and \
-            use_esimd_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
-        import linear_fp16_esimd
-        attn_output = linear_fp16_esimd.sdp_forward(query_states,
-                                                    key_states,
-                                                    value_states)
+            use_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
+        import linear_q4_0
+        attn_output = linear_q4_0.sdp(query_states, key_states, value_states, attention_mask)
         attn_output = attn_output.view(query_states.shape)
         attn_weights = None
     else:

@@ -32,7 +32,7 @@ import torch.utils.checkpoint
 from transformers.utils import logging
 from ipex_llm.transformers.models.utils import extend_kv_cache, init_kv_cache, append_kv_cache
 from ipex_llm.transformers.models.utils import rotate_half
-from ipex_llm.transformers.models.utils import use_esimd_sdp
+from ipex_llm.transformers.models.utils import use_sdp
 from ipex_llm.transformers.models.utils import use_decoding_fast_path
 
 import os
@@ -207,11 +207,9 @@ def qwen_attention_forward_vl(
     query = query.permute(0, 2, 1, 3)
 
     if not self.training and not hidden_states.requires_grad and \
-            use_esimd_sdp(q_len, key.shape[2], self.head_dim, query):
-        import linear_fp16_esimd
-        attn_output = linear_fp16_esimd.sdp_forward(query,
-                                                    key,
-                                                    value)
+            use_sdp(q_len, key.shape[2], self.head_dim, query):
+        import linear_q4_0
+        attn_output = linear_q4_0.sdp(query, key, value, attention_mask)
         attn_output = attn_output.view(query.shape)
         attn_output = attn_output.transpose(1, 2)
         attn_weight = None

@@ -53,7 +53,7 @@ from ipex_llm.transformers.models.utils import apply_rotary_pos_emb, \
 from ipex_llm.transformers.models.utils import init_fp8_kv_cache, append_fp8_kv_cache, \
     restore_fp8_kv_cache, use_quantize_kv_cache
 from ipex_llm.transformers.models.utils import is_enough_kv_cache_room_4_36
-from ipex_llm.transformers.models.utils import use_flash_attention, use_esimd_sdp
+from ipex_llm.transformers.models.utils import use_flash_attention, use_sdp
 from ipex_llm.transformers.models.mistral import should_use_fuse_rope, repeat_kv
 try:
     from transformers.cache_utils import Cache
@@ -266,11 +266,9 @@ def stablelm_attention_forward_original(
                                                      is_causal=True)
         attn_weights = None
     elif not self.training and not hidden_states.requires_grad and \
-            use_esimd_sdp(q_len, key_states.shape[2], self.head_dim, query_states, attention_mask):
-        import linear_fp16_esimd
-        attn_output = linear_fp16_esimd.sdp_forward(query_states,
-                                                    key_states,
-                                                    value_states)
+            use_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
+        import linear_q4_0
+        attn_output = linear_q4_0.sdp(query_states, key_states, value_states, attention_mask)
         attn_output = attn_output.view(query_states.shape)
         attn_weights = None
     else:

@@ -375,20 +375,31 @@ def use_new_esimd_sdp_fp16(q_len, k_len, head_dim, query_states):
     return True
 
 
-def use_sdp_fp8(q_len, k_len, query_states):
-    if query_states.device.type != "xpu":
-        return False
-    if q_len == k_len:
-        # sdp_fp8 only support rest token now
-        return False
-    return True
+def use_sdp(q_len, kv_len, head_dim, query_states):
+    return (
+        query_states.device.type == "xpu"
+        and query_states.dtype in [torch.float, torch.half]     # fp32/fp16
+        and head_dim in [64, 96, 128]
+        and q_len != kv_len     # next token
+        and q_len <= 32         # lookup
+    )
+
+
+def use_sdp_fp8(q_len, kv_len, query_states):
+    return (
+        query_states.device.type == "xpu"
+        and query_states.dtype in [torch.float, torch.half]     # fp32/fp16
+        and q_len != kv_len     # next token
+        and q_len <= 32         # lookup
+    )
 
 
 def use_sdp_causal(q_len, kv_len, query_states, training):
     return (
         q_len == kv_len     # first token
         and query_states.device.type == "xpu"   # GPU
-        and not query_states.requires_grad and not training     # no training
         and query_states.dtype in [torch.float, torch.half]     # fp32/fp16
+        and not query_states.requires_grad and not training     # not training
     )
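
The rewritten helpers are plain predicates over device, dtype, and sequence lengths, so they can be exercised directly. A small self-contained check (illustrative, not part of this PR; every gate returns False for the CPU tensor below and only flips on an XPU device):

# Quick check of the new gates; assumes an ipex-llm install.
import torch
from ipex_llm.transformers.models.utils import use_sdp, use_sdp_fp8, use_sdp_causal

q_len, kv_len, head_dim = 1, 128, 128                        # one decode step against a 128-token cache
query_states = torch.randn(1, 32, q_len, head_dim, dtype=torch.half)

print(use_sdp(q_len, kv_len, head_dim, query_states))        # decode gate: xpu, fp16/fp32, head_dim 64/96/128, q_len <= 32
print(use_sdp_fp8(q_len, kv_len, query_states))              # same length rules, for the fp8 kv-cache kernel
print(use_sdp_causal(kv_len, kv_len, query_states, False))   # prefill gate: q_len == kv_len, inference only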