Use new sdp again (#11025)
parent 7e29928865
commit 59df750326
5 changed files with 34 additions and 95 deletions
@@ -28,10 +28,11 @@ from ipex_llm.transformers.models.utils import init_fp8_kv_cache, append_fp8_kv_
     restore_fp8_kv_cache, use_quantize_kv_cache
 from ipex_llm.transformers.models.utils import init_kv_cache, extend_kv_cache, \
     append_kv_cache, is_enough_kv_cache_room_4_31
-from ipex_llm.transformers.models.utils import use_flash_attention, use_esimd_sdp
+from ipex_llm.transformers.models.utils import use_flash_attention, use_sdp
 from ipex_llm.transformers.models.utils import apply_rotary_pos_emb, SILU
 from ipex_llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
 from ipex_llm.transformers.models.utils import mlp_fusion_check
 from ipex_llm.utils.common.log4Error import invalidInputError
 from transformers.utils import logging

 logger = logging.get_logger(__name__)

@@ -166,9 +167,8 @@ def baichuan_attention_forward_7b_quantized(

     past_key_value = (key_states, value_states) if use_cache else None

-    if attention_mask is not None:
-        if attention_mask.dtype == torch.bool:
-            attention_mask.masked_fill_(attention_mask.logical_not(), float("-inf"))
+    invalidInputError(attention_mask is None or attention_mask.dtype != torch.bool,
+                      "attention_mask's dtype cannot be bool")

     scaling_factor = 1 / math.sqrt(query_states.size(-1))
     if query_states.size(2) != 1 or device.type != 'xpu':
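Note on the change above: both the quantized and the origin forward now reject boolean masks up front instead of patching them in place. A minimal caller-side sketch, assuming the caller owns the mask, of converting a True/False attention mask into an additive float mask before invoking the forward (to_additive_mask is a hypothetical helper, not part of ipex_llm):

# Hypothetical helper: convert a boolean mask (True = attend) into an
# additive float mask; masked positions become -inf.
import torch

def to_additive_mask(bool_mask: torch.Tensor, dtype=torch.float16) -> torch.Tensor:
    float_mask = torch.zeros(bool_mask.shape, dtype=dtype)
    float_mask.masked_fill_(bool_mask.logical_not(), float("-inf"))
    return float_mask

mask = torch.tensor([[True, True, False]])
print(to_additive_mask(mask))   # tensor([[0., 0., -inf]], dtype=torch.float16)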
@@ -279,6 +279,9 @@ def baichuan_attention_forward_7b_origin(

     past_key_value = (key_states, value_states) if use_cache else None

+    invalidInputError(attention_mask is None or attention_mask.dtype != torch.bool,
+                      "attention_mask's dtype cannot be bool")
+
     if xops is not None and self.training:
         attn_weights = None
         query_states = query_states.transpose(1, 2)
@@ -296,17 +299,12 @@ def baichuan_attention_forward_7b_origin(
                                                  is_causal=True)
         attn_weights = None
     elif not self.training and not hidden_states.requires_grad and \
-            use_esimd_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
-        import linear_fp16_esimd
-        attn_output = linear_fp16_esimd.sdp_forward(query_states,
-                                                    key_states,
-                                                    value_states)
+            use_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
+        import linear_q4_0
+        attn_output = linear_q4_0.sdp(query_states, key_states, value_states, attention_mask)
         attn_output = attn_output.view(query_states.shape)
         attn_weights = None
     else:
-        if attention_mask is not None:
-            if attention_mask.dtype == torch.bool:
-                attention_mask.masked_fill_(attention_mask.logical_not(), float("-inf"))
         if should_split_qkv_tensor(query_states, bsz, self.num_heads,
                                    q_len, kv_seq_len, output_attentions):
             attn_output, attn_weights = native_sdp_split_qkv_tensor(query_states,

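The hunk above swaps the ESIMD FP16 kernel (linear_fp16_esimd.sdp_forward) for the newer fused kernel (linear_q4_0.sdp), which also receives the attention mask. A minimal sketch of the same gate-then-dispatch shape in plain PyTorch, with the native math standing in for the XPU kernel; run_decode_sdp and native_sdp are illustrative names, not ipex_llm APIs:

# Illustrative only: mirrors the decode-path dispatch of the patched forward,
# using plain PyTorch in place of the linear_q4_0 XPU kernel.
import math
import torch

def native_sdp(q, k, v, mask):
    # q: [batch, heads, q_len, head_dim]; k, v: [batch, heads, kv_len, head_dim]
    attn = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(q.size(-1))
    if mask is not None:
        attn = attn + mask
    attn = torch.softmax(attn, dim=-1, dtype=torch.float32).to(v.dtype)
    return torch.matmul(attn, v)

def run_decode_sdp(q, k, v, mask, use_fused_kernel: bool):
    if use_fused_kernel:
        # On XPU the patched code would call linear_q4_0.sdp(q, k, v, mask) here.
        pass
    return native_sdp(q, k, v, mask)

q = torch.randn(1, 8, 1, 128)        # one decoded token
k = torch.randn(1, 8, 32, 128)       # 32 cached key positions
v = torch.randn(1, 8, 32, 128)
print(run_decode_sdp(q, k, v, None, use_fused_kernel=False).shape)  # torch.Size([1, 8, 1, 128])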
@@ -25,7 +25,7 @@ from transformers.modeling_outputs import BaseModelOutputWithPast
 from ipex_llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
 from ipex_llm.transformers.models.utils import init_fp8_kv_cache, append_fp8_kv_cache, \
     restore_fp8_kv_cache, use_quantize_kv_cache
-from ipex_llm.transformers.models.utils import use_esimd_sdp
+from ipex_llm.transformers.models.utils import use_sdp


 import os
@@ -558,17 +558,7 @@ def core_attn_forward_8eb45c(query_layer, key_layer, value_layer, attention_mask
                                                                  value_layer,
                                                                  is_causal=True).to(key_layer.dtype)
     else:
-        if use_esimd_sdp(query_layer.shape[2], key_layer.shape[2],
-                         query_layer.shape[-1], query_layer):
-            import linear_fp16_esimd
-            attn_output = linear_fp16_esimd.sdp_forward(query_layer,
-                                                        key_layer,
-                                                        value_layer)
-            context_layer = attn_output.view(query_layer.shape)
-        else:
-            head_dim = query_layer.size(-1)
-            attn = torch.matmul(query_layer.to(key_layer.dtype),
-                                key_layer.transpose(2, 3)) / math.sqrt(head_dim)
         # attention_mask is not None only when past_key_value is not None and q_len > 1
         if attention_mask is not None:
             attn_bias = torch.zeros(attention_mask.shape, dtype=query_layer.dtype,
                                     device=query_layer.device)
@@ -577,6 +567,19 @@ def core_attn_forward_8eb45c(query_layer, key_layer, value_layer, attention_mask
                 attn_bias.masked_fill_(attention_mask.logical_not(), float("-inf"))
             else:
                 attn_bias += attention_mask
         else:
             attn_bias = None

+        if use_sdp(query_layer.shape[2], key_layer.shape[2],
+                   query_layer.shape[-1], query_layer):
+            import linear_q4_0
+            attn_output = linear_q4_0.sdp(query_layer, key_layer, value_layer, attn_bias)
+            context_layer = attn_output.view(query_layer.shape)
+        else:
+            head_dim = query_layer.size(-1)
+            attn = torch.matmul(query_layer.to(key_layer.dtype),
+                                key_layer.transpose(2, 3)) / math.sqrt(head_dim)
+            if attn_bias is not None:
+                attn += attn_bias
             attn = F.softmax(attn, dim=-1,
                              dtype=torch.float32).to(value_layer.dtype)

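With this restructuring the additive bias is built once and then either handed to linear_q4_0.sdp or folded into the manual score/softmax path. A quick illustrative cross-check that the manual path matches PyTorch's built-in SDP when given the same additive bias (shapes here are made up):

# Illustrative only: the manual score/bias/softmax path above agrees with
# F.scaled_dot_product_attention for the same additive bias.
import math
import torch
import torch.nn.functional as F

q = torch.randn(1, 2, 4, 64)
k = torch.randn(1, 2, 6, 64)
v = torch.randn(1, 2, 6, 64)
attn_bias = torch.zeros(1, 1, 4, 6)
attn_bias[..., -1] = float("-inf")    # mask out the last cached position

attn = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(q.size(-1))
attn = attn + attn_bias
manual = torch.matmul(F.softmax(attn, dim=-1, dtype=torch.float32).to(v.dtype), v)
builtin = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_bias)
print(torch.allclose(manual, builtin, atol=1e-5))   # True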
@@ -41,8 +41,8 @@ from ipex_llm.transformers.models.utils import (
     apply_rotary_pos_emb_cache_freq_xpu
 )
 from ipex_llm.transformers.models.utils import mlp_fusion_check, SILU
-from ipex_llm.transformers.models.utils import use_sdp, use_sdp_causal, use_quantize_kv_cache
-from ipex_llm.transformers.models.utils import use_sdp_fp8, restore_fp8_kv_cache
+from ipex_llm.transformers.models.utils import use_sdp, use_sdp_causal
+from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache
 from ipex_llm.transformers.kv import DynamicNormalCache, DynamicFp8Cache

 from typing import Optional, Tuple, List
@@ -144,7 +144,7 @@ def attention_forward(
                                                  attention_mask)
         else:
             attn_output = linear_q4_0.sdp(query_states, key_states, value_states, attention_mask)
-    elif use_sdp_causal(q_len, kv_seq_len, query_states, self.training):
+    elif use_sdp_causal(q_len, kv_seq_len, self.head_dim, query_states, self.training):
         import linear_q4_0
         if isinstance(past_key_value, DynamicFp8Cache):
             attn_output = linear_q4_0.sdp_fp8_causal(query_states, key_states, value_states)

@@ -52,7 +52,7 @@ from ipex_llm.transformers.models.utils import is_enough_kv_cache_room_4_36
 from transformers.models.qwen2.modeling_qwen2 import apply_rotary_pos_emb
 from ipex_llm.utils.common import invalidInputError
 from ipex_llm.transformers.models.utils import decoding_fast_path_qtype_check
-from ipex_llm.transformers.models.utils import use_flash_attention, use_esimd_sdp
+from ipex_llm.transformers.models.utils import use_flash_attention
 from transformers.models.qwen2_moe.modeling_qwen2_moe import Qwen2MoeModel, apply_rotary_pos_emb
 from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache
 from ipex_llm.transformers.kv import DynamicFp8Cache

@@ -318,69 +318,6 @@ def use_flash_attention(query, key, attention_mask=None):
     return True


-def use_esimd_sdp(q_len, k_len, head_dim, query_states, attention_mask=None):
-    if head_dim != 128:
-        # esimd_sdp only support head_dim = 128 now
-        return False
-    elif q_len != 1:
-        # esimd_sdp only support rest token and q_len == 1 now
-        return False
-    elif k_len < 8:
-        # esimd_sdp will cause wrong output when k_len < 8
-        return False
-    elif query_states.device.type != "xpu":
-        # esimd_sdp only support GPU now
-        return False
-    elif query_states.dtype != torch.float16:
-        # esimd_sdp only has optimization for FP16 now
-        return False
-
-    device_name = torch.xpu.get_device_name(query_states.device.index)
-    if device_name.startswith("Intel(R) Arc(TM) A") or \
-            device_name.startswith("Intel(R) Data Center GPU Flex") or \
-            device_name.startswith("Intel(R) Data Center GPU Max"):
-        import linear_fp16_esimd
-        if not hasattr(linear_fp16_esimd, "sdp_forward"):
-            return False
-    else:
-        return False
-
-    if query_states.shape[0] > 1 and device_name.startswith("Intel(R) Data Center GPU Max"):
-        # esimd_sdp not support PVC GPU when batch size > 1 for now
-        return False
-    if query_states.shape[0] > 1 and device_name.startswith("Intel(R) Arc(TM) A") \
-            and is_deepspeed_available:
-        # esimd_sdp not support ARC GPU when batch size > 1 using DeepSpeed AutoTP for now
-        return False
-    if query_states.shape[0] > 1 and attention_mask is not None:
-        # for batched input, can't accept attention_mask
-        # TODO: this check needs some time
-        if not torch.all(attention_mask.eq(0)):
-            return False
-
-    return True
-
-
-def use_new_esimd_sdp_fp16(q_len, k_len, head_dim, query_states):
-    if query_states.device.type != "xpu":
-        # esimd_sdp only support GPU now
-        return False
-    elif query_states.dtype != torch.float16:
-        # esimd_sdp only has optimization for FP16 now
-        return False
-    elif head_dim not in [64, 96, 128]:
-        # esimd_sdp only support head_dim = 128 and 64 now
-        return False
-    elif q_len == k_len:
-        # new sdp_fp16 only support rest token now
-        return False
-    elif q_len > 32:
-        # Use new sdp_fp16 only when q_len <= 32
-        return False
-
-    return True
-
-
 def use_sdp(q_len, kv_len, head_dim, query_states):
     return (
         query_states.device.type == "xpu"

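With both ESIMD gates deleted, use_sdp(q_len, kv_len, head_dim, query_states) is the gate the patched forwards call before importing linear_q4_0; only its first condition (the XPU device check) is visible in this hunk. A hedged usage sketch, on the assumption that the gate short-circuits and returns False for a CPU tensor:

# Illustrative usage of the use_sdp gate; requires ipex_llm to be installed.
import torch
from ipex_llm.transformers.models.utils import use_sdp

q = torch.randn(1, 8, 1, 128)      # [batch, heads, q_len, head_dim]
kv_len = 32
if use_sdp(q.size(2), kv_len, q.size(-1), q):
    import linear_q4_0             # XPU-only fused SDP kernel
    # attn_output = linear_q4_0.sdp(query_states, key_states, value_states, attention_mask)
else:
    print("falling back to the native attention path")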
@@ -400,9 +337,10 @@ def use_sdp_fp8(q_len, kv_len, query_states):
     )


-def use_sdp_causal(q_len, kv_len, query_states, training):
+def use_sdp_causal(q_len, kv_len, head_dim, query_states, training):
     return (
         q_len == kv_len  # first token
+        and head_dim in [64, 96, 128]  # for now
         and query_states.device.type == "xpu"  # GPU
         and query_states.dtype in [torch.float, torch.half]  # fp32/fp16
         and not query_states.requires_grad and not training  # not training
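use_sdp_causal now also screens on head_dim, and its caller in the attention_forward hunk above passes self.head_dim accordingly. A minimal sketch of the prefill-versus-decode split these gates encode, using PyTorch's built-in causal SDP in place of the fused XPU kernels (illustrative only; sdp_step is not an ipex_llm function):

# Illustrative only: prefill (q_len == kv_len, the "first token" case) uses a
# causal SDP; decode (q_len == 1) attends over every cached position.
import torch
import torch.nn.functional as F

def sdp_step(q, k, v):
    if q.size(2) == k.size(2):
        return F.scaled_dot_product_attention(q, k, v, is_causal=True)
    return F.scaled_dot_product_attention(q, k, v)

prefill = sdp_step(torch.randn(1, 4, 16, 96), torch.randn(1, 4, 16, 96), torch.randn(1, 4, 16, 96))
decode = sdp_step(torch.randn(1, 4, 1, 96), torch.randn(1, 4, 17, 96), torch.randn(1, 4, 17, 96))
print(prefill.shape, decode.shape)   # torch.Size([1, 4, 16, 96]) torch.Size([1, 4, 1, 96])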