Use new sdp again (#11025)
This commit is contained in:

Parent: 7e29928865
Commit: 59df750326

5 changed files with 34 additions and 95 deletions
@@ -28,10 +28,11 @@ from ipex_llm.transformers.models.utils import init_fp8_kv_cache, append_fp8_kv_
     restore_fp8_kv_cache, use_quantize_kv_cache
 from ipex_llm.transformers.models.utils import init_kv_cache, extend_kv_cache, \
     append_kv_cache, is_enough_kv_cache_room_4_31
-from ipex_llm.transformers.models.utils import use_flash_attention, use_esimd_sdp
+from ipex_llm.transformers.models.utils import use_flash_attention, use_sdp
 from ipex_llm.transformers.models.utils import apply_rotary_pos_emb, SILU
 from ipex_llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
 from ipex_llm.transformers.models.utils import mlp_fusion_check
+from ipex_llm.utils.common.log4Error import invalidInputError
 from transformers.utils import logging
 logger = logging.get_logger(__name__)
 
@@ -166,9 +167,8 @@ def baichuan_attention_forward_7b_quantized(
 
     past_key_value = (key_states, value_states) if use_cache else None
 
-    if attention_mask is not None:
-        if attention_mask.dtype == torch.bool:
-            attention_mask.masked_fill_(attention_mask.logical_not(), float("-inf"))
+    invalidInputError(attention_mask is None or attention_mask.dtype != torch.bool,
+                      "attention_mask's dtype cannot be bool")
 
     scaling_factor = 1 / math.sqrt(query_states.size(-1))
     if query_states.size(2) != 1 or device.type != 'xpu':
@@ -279,6 +279,9 @@ def baichuan_attention_forward_7b_origin(
 
     past_key_value = (key_states, value_states) if use_cache else None
 
+    invalidInputError(attention_mask is None or attention_mask.dtype != torch.bool,
+                      "attention_mask's dtype cannot be bool")
+
     if xops is not None and self.training:
         attn_weights = None
         query_states = query_states.transpose(1, 2)
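The hunks above replace the silent in-place conversion of a boolean attention_mask with an explicit input check. A minimal sketch of the intended behaviour, assuming invalidInputError acts like an assertion that raises when its first argument is False; the stand-in below uses a plain ValueError instead of the real ipex_llm helper:

import torch

def check_attention_mask(attention_mask):
    # mirrors invalidInputError(attention_mask is None or attention_mask.dtype != torch.bool, ...)
    if not (attention_mask is None or attention_mask.dtype != torch.bool):
        raise ValueError("attention_mask's dtype cannot be bool")

check_attention_mask(None)                     # no mask: accepted
check_attention_mask(torch.zeros(1, 1, 1, 8))  # additive float mask: accepted
# check_attention_mask(torch.zeros(1, 1, 1, 8, dtype=torch.bool))  # would now raise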
@@ -296,17 +299,12 @@ def baichuan_attention_forward_7b_origin(
                                                          is_causal=True)
             attn_weights = None
         elif not self.training and not hidden_states.requires_grad and \
-                use_esimd_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
-            import linear_fp16_esimd
-            attn_output = linear_fp16_esimd.sdp_forward(query_states,
-                                                        key_states,
-                                                        value_states)
+                use_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
+            import linear_q4_0
+            attn_output = linear_q4_0.sdp(query_states, key_states, value_states, attention_mask)
             attn_output = attn_output.view(query_states.shape)
             attn_weights = None
         else:
-            if attention_mask is not None:
-                if attention_mask.dtype == torch.bool:
-                    attention_mask.masked_fill_(attention_mask.logical_not(), float("-inf"))
             if should_split_qkv_tensor(query_states, bsz, self.num_heads,
                                        q_len, kv_seq_len, output_attentions):
                 attn_output, attn_weights = native_sdp_split_qkv_tensor(query_states,
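This hunk swaps the ESIMD FP16 kernel (linear_fp16_esimd.sdp_forward) for the new fused SDP entry point (linear_q4_0.sdp), which takes the attention mask directly. A rough sketch of the decode-path dispatch, for illustration only: linear_q4_0 is ipex-llm's native XPU extension, so the snippet falls back to plain scaled-dot-product attention when it is not importable, and the helper name sdp_decode is made up.

import math
import torch

def sdp_decode(query, key, value, mask=None):
    # query: (bsz, n_heads, q_len, head_dim); key/value: (bsz, n_heads, kv_len, head_dim)
    try:
        import linear_q4_0  # only present in the XPU build of ipex-llm
        return linear_q4_0.sdp(query, key, value, mask).view(query.shape)
    except ImportError:
        # eager fallback computing standard scaled-dot-product attention
        attn = torch.matmul(query, key.transpose(2, 3)) / math.sqrt(query.size(-1))
        if mask is not None:
            attn = attn + mask
        attn = torch.softmax(attn, dim=-1, dtype=torch.float32).to(value.dtype)
        return torch.matmul(attn, value)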
@@ -25,7 +25,7 @@ from transformers.modeling_outputs import BaseModelOutputWithPast
 from ipex_llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
 from ipex_llm.transformers.models.utils import init_fp8_kv_cache, append_fp8_kv_cache, \
     restore_fp8_kv_cache, use_quantize_kv_cache
-from ipex_llm.transformers.models.utils import use_esimd_sdp
+from ipex_llm.transformers.models.utils import use_sdp
 
 
 import os
@@ -558,25 +558,28 @@ def core_attn_forward_8eb45c(query_layer, key_layer, value_layer, attention_mask
                                                                value_layer,
                                                                is_causal=True).to(key_layer.dtype)
         else:
-            if use_esimd_sdp(query_layer.shape[2], key_layer.shape[2],
-                             query_layer.shape[-1], query_layer):
-                import linear_fp16_esimd
-                attn_output = linear_fp16_esimd.sdp_forward(query_layer,
-                                                            key_layer,
-                                                            value_layer)
+            # attention_mask is not None only when past_key_value is not None and q_len > 1
+            if attention_mask is not None:
+                attn_bias = torch.zeros(attention_mask.shape, dtype=query_layer.dtype,
+                                        device=query_layer.device)
+                attention_mask = ~attention_mask
+                if attention_mask.dtype == torch.bool:
+                    attn_bias.masked_fill_(attention_mask.logical_not(), float("-inf"))
+                else:
+                    attn_bias += attention_mask
+            else:
+                attn_bias = None
+
+            if use_sdp(query_layer.shape[2], key_layer.shape[2],
+                       query_layer.shape[-1], query_layer):
+                import linear_q4_0
+                attn_output = linear_q4_0.sdp(query_layer, key_layer, value_layer, attn_bias)
                 context_layer = attn_output.view(query_layer.shape)
             else:
                 head_dim = query_layer.size(-1)
                 attn = torch.matmul(query_layer.to(key_layer.dtype),
                                     key_layer.transpose(2, 3)) / math.sqrt(head_dim)
-                if attention_mask is not None:
-                    attn_bias = torch.zeros(attention_mask.shape, dtype=query_layer.dtype,
-                                            device=query_layer.device)
-                    attention_mask = ~attention_mask
-                    if attention_mask.dtype == torch.bool:
-                        attn_bias.masked_fill_(attention_mask.logical_not(), float("-inf"))
-                    else:
-                        attn_bias += attention_mask
+                if attn_bias is not None:
                     attn += attn_bias
                 attn = F.softmax(attn, dim=-1,
                                  dtype=torch.float32).to(value_layer.dtype)
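In the hunk above, the mask-to-bias conversion now happens once, before choosing between the fused kernel and the eager matmul path, so both branches consume the same additive attn_bias. A standalone sketch of that conversion; the diff's double negation (~mask followed by logical_not()) resolves to filling -inf wherever the original boolean mask is True, which is what the helper below does directly:

import torch

def build_attn_bias(attention_mask, dtype, device):
    # Returns an additive bias tensor, or None when no mask is given.
    if attention_mask is None:
        return None
    attn_bias = torch.zeros(attention_mask.shape, dtype=dtype, device=device)
    if attention_mask.dtype == torch.bool:
        # boolean mask: -inf at the positions marked True
        attn_bias.masked_fill_(attention_mask, float("-inf"))
    else:
        # already an additive mask: accumulate it into the bias
        attn_bias += attention_mask
    return attn_bias

# e.g. block the last of three key positions for a single query token
mask = torch.tensor([[[[False, False, True]]]])
print(build_attn_bias(mask, torch.float16, "cpu"))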
@@ -41,8 +41,8 @@ from ipex_llm.transformers.models.utils import (
     apply_rotary_pos_emb_cache_freq_xpu
 )
 from ipex_llm.transformers.models.utils import mlp_fusion_check, SILU
-from ipex_llm.transformers.models.utils import use_sdp, use_sdp_causal, use_quantize_kv_cache
-from ipex_llm.transformers.models.utils import use_sdp_fp8, restore_fp8_kv_cache
+from ipex_llm.transformers.models.utils import use_sdp, use_sdp_causal
+from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache
 from ipex_llm.transformers.kv import DynamicNormalCache, DynamicFp8Cache
 
 from typing import Optional, Tuple, List
@@ -144,7 +144,7 @@ def attention_forward(
                                               attention_mask)
         else:
             attn_output = linear_q4_0.sdp(query_states, key_states, value_states, attention_mask)
-    elif use_sdp_causal(q_len, kv_seq_len, query_states, self.training):
+    elif use_sdp_causal(q_len, kv_seq_len, self.head_dim, query_states, self.training):
         import linear_q4_0
         if isinstance(past_key_value, DynamicFp8Cache):
             attn_output = linear_q4_0.sdp_fp8_causal(query_states, key_states, value_states)
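The second hunk above only changes the call site: self.head_dim is now passed into use_sdp_causal, matching the new signature in the utils diff at the end of this page, so the causal kernel can be skipped for unsupported head sizes. A tiny illustration of the head-size gate with made-up values; the predicate below is a stand-in, not the real helper:

SUPPORTED_CAUSAL_HEAD_DIMS = (64, 96, 128)   # per the utils diff below

def causal_head_dim_ok(head_dim: int) -> bool:
    return head_dim in SUPPORTED_CAUSAL_HEAD_DIMS

assert causal_head_dim_ok(96)        # e.g. a 96-dim head can take the causal kernel
assert not causal_head_dim_ok(80)    # an 80-dim head falls back to the generic path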
@@ -52,7 +52,7 @@ from ipex_llm.transformers.models.utils import is_enough_kv_cache_room_4_36
 from transformers.models.qwen2.modeling_qwen2 import apply_rotary_pos_emb
 from ipex_llm.utils.common import invalidInputError
 from ipex_llm.transformers.models.utils import decoding_fast_path_qtype_check
-from ipex_llm.transformers.models.utils import use_flash_attention, use_esimd_sdp
+from ipex_llm.transformers.models.utils import use_flash_attention
 from transformers.models.qwen2_moe.modeling_qwen2_moe import Qwen2MoeModel, apply_rotary_pos_emb
 from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache
 from ipex_llm.transformers.kv import DynamicFp8Cache
@@ -318,69 +318,6 @@ def use_flash_attention(query, key, attention_mask=None):
     return True
 
 
-def use_esimd_sdp(q_len, k_len, head_dim, query_states, attention_mask=None):
-    if head_dim != 128:
-        # esimd_sdp only support head_dim = 128 now
-        return False
-    elif q_len != 1:
-        # esimd_sdp only support rest token and q_len == 1 now
-        return False
-    elif k_len < 8:
-        # esimd_sdp will cause wrong output when k_len < 8
-        return False
-    elif query_states.device.type != "xpu":
-        # esimd_sdp only support GPU now
-        return False
-    elif query_states.dtype != torch.float16:
-        # esimd_sdp only has optimization for FP16 now
-        return False
-
-    device_name = torch.xpu.get_device_name(query_states.device.index)
-    if device_name.startswith("Intel(R) Arc(TM) A") or \
-       device_name.startswith("Intel(R) Data Center GPU Flex") or \
-       device_name.startswith("Intel(R) Data Center GPU Max"):
-        import linear_fp16_esimd
-        if not hasattr(linear_fp16_esimd, "sdp_forward"):
-            return False
-    else:
-        return False
-
-    if query_states.shape[0] > 1 and device_name.startswith("Intel(R) Data Center GPU Max"):
-        # esimd_sdp not support PVC GPU when batch size > 1 for now
-        return False
-    if query_states.shape[0] > 1 and device_name.startswith("Intel(R) Arc(TM) A") \
-            and is_deepspeed_available:
-        # esimd_sdp not support ARC GPU when batch size > 1 using DeepSpeed AutoTP for now
-        return False
-    if query_states.shape[0] > 1 and attention_mask is not None:
-        # for batched input, can't accept attention_mask
-        # TODO: this check needs some time
-        if not torch.all(attention_mask.eq(0)):
-            return False
-
-    return True
-
-
-def use_new_esimd_sdp_fp16(q_len, k_len, head_dim, query_states):
-    if query_states.device.type != "xpu":
-        # esimd_sdp only support GPU now
-        return False
-    elif query_states.dtype != torch.float16:
-        # esimd_sdp only has optimization for FP16 now
-        return False
-    elif head_dim not in [64, 96, 128]:
-        # esimd_sdp only support head_dim = 128 and 64 now
-        return False
-    elif q_len == k_len:
-        # new sdp_fp16 only support rest token now
-        return False
-    elif q_len > 32:
-        # Use new sdp_fp16 only when q_len <= 32
-        return False
-
-    return True
-
-
 def use_sdp(q_len, kv_len, head_dim, query_states):
     return (
         query_states.device.type == "xpu"
@@ -400,9 +337,10 @@ def use_sdp_fp8(q_len, kv_len, query_states):
     )
 
 
-def use_sdp_causal(q_len, kv_len, query_states, training):
+def use_sdp_causal(q_len, kv_len, head_dim, query_states, training):
     return (
         q_len == kv_len     # first token
+        and head_dim in [64, 96, 128]           # for now
         and query_states.device.type == "xpu"   # GPU
         and query_states.dtype in [torch.float, torch.half]     # fp32/fp16
         and not query_states.requires_grad and not training     # not training
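For reference, the updated gate assembled from the visible lines of the last hunk; the tail of the real function is cut off in this view, so treat this as a sketch rather than the complete definition:

import torch

def use_sdp_causal_sketch(q_len, kv_len, head_dim, query_states, training):
    return (
        q_len == kv_len                                        # first token (prefill)
        and head_dim in [64, 96, 128]                          # supported head sizes, for now
        and query_states.device.type == "xpu"                  # GPU only
        and query_states.dtype in [torch.float, torch.half]    # fp32/fp16
        and not query_states.requires_grad and not training    # inference only
    )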