Use new sdp again (#11025)
This commit is contained in:
parent
7e29928865
commit
59df750326
5 changed files with 34 additions and 95 deletions
|
|
@ -28,10 +28,11 @@ from ipex_llm.transformers.models.utils import init_fp8_kv_cache, append_fp8_kv_
|
||||||
restore_fp8_kv_cache, use_quantize_kv_cache
|
restore_fp8_kv_cache, use_quantize_kv_cache
|
||||||
from ipex_llm.transformers.models.utils import init_kv_cache, extend_kv_cache, \
|
from ipex_llm.transformers.models.utils import init_kv_cache, extend_kv_cache, \
|
||||||
append_kv_cache, is_enough_kv_cache_room_4_31
|
append_kv_cache, is_enough_kv_cache_room_4_31
|
||||||
from ipex_llm.transformers.models.utils import use_flash_attention, use_esimd_sdp
|
from ipex_llm.transformers.models.utils import use_flash_attention, use_sdp
|
||||||
from ipex_llm.transformers.models.utils import apply_rotary_pos_emb, SILU
|
from ipex_llm.transformers.models.utils import apply_rotary_pos_emb, SILU
|
||||||
from ipex_llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
|
from ipex_llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
|
||||||
from ipex_llm.transformers.models.utils import mlp_fusion_check
|
from ipex_llm.transformers.models.utils import mlp_fusion_check
|
||||||
|
from ipex_llm.utils.common.log4Error import invalidInputError
|
||||||
from transformers.utils import logging
|
from transformers.utils import logging
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
@ -166,9 +167,8 @@ def baichuan_attention_forward_7b_quantized(
|
||||||
|
|
||||||
past_key_value = (key_states, value_states) if use_cache else None
|
past_key_value = (key_states, value_states) if use_cache else None
|
||||||
|
|
||||||
if attention_mask is not None:
|
invalidInputError(attention_mask is None or attention_mask.dtype != torch.bool,
|
||||||
if attention_mask.dtype == torch.bool:
|
"attention_mask's dtype cannot be bool")
|
||||||
attention_mask.masked_fill_(attention_mask.logical_not(), float("-inf"))
|
|
||||||
|
|
||||||
scaling_factor = 1 / math.sqrt(query_states.size(-1))
|
scaling_factor = 1 / math.sqrt(query_states.size(-1))
|
||||||
if query_states.size(2) != 1 or device.type != 'xpu':
|
if query_states.size(2) != 1 or device.type != 'xpu':
|
||||||
|
|
@ -279,6 +279,9 @@ def baichuan_attention_forward_7b_origin(
|
||||||
|
|
||||||
past_key_value = (key_states, value_states) if use_cache else None
|
past_key_value = (key_states, value_states) if use_cache else None
|
||||||
|
|
||||||
|
invalidInputError(attention_mask is None or attention_mask.dtype != torch.bool,
|
||||||
|
"attention_mask's dtype cannot be bool")
|
||||||
|
|
||||||
if xops is not None and self.training:
|
if xops is not None and self.training:
|
||||||
attn_weights = None
|
attn_weights = None
|
||||||
query_states = query_states.transpose(1, 2)
|
query_states = query_states.transpose(1, 2)
|
||||||
|
|
@ -296,17 +299,12 @@ def baichuan_attention_forward_7b_origin(
|
||||||
is_causal=True)
|
is_causal=True)
|
||||||
attn_weights = None
|
attn_weights = None
|
||||||
elif not self.training and not hidden_states.requires_grad and \
|
elif not self.training and not hidden_states.requires_grad and \
|
||||||
use_esimd_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
|
use_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
|
||||||
import linear_fp16_esimd
|
import linear_q4_0
|
||||||
attn_output = linear_fp16_esimd.sdp_forward(query_states,
|
attn_output = linear_q4_0.sdp(query_states, key_states, value_states, attention_mask)
|
||||||
key_states,
|
|
||||||
value_states)
|
|
||||||
attn_output = attn_output.view(query_states.shape)
|
attn_output = attn_output.view(query_states.shape)
|
||||||
attn_weights = None
|
attn_weights = None
|
||||||
else:
|
else:
|
||||||
if attention_mask is not None:
|
|
||||||
if attention_mask.dtype == torch.bool:
|
|
||||||
attention_mask.masked_fill_(attention_mask.logical_not(), float("-inf"))
|
|
||||||
if should_split_qkv_tensor(query_states, bsz, self.num_heads,
|
if should_split_qkv_tensor(query_states, bsz, self.num_heads,
|
||||||
q_len, kv_seq_len, output_attentions):
|
q_len, kv_seq_len, output_attentions):
|
||||||
attn_output, attn_weights = native_sdp_split_qkv_tensor(query_states,
|
attn_output, attn_weights = native_sdp_split_qkv_tensor(query_states,
|
||||||
|
|
|
||||||
|
|
@ -25,7 +25,7 @@ from transformers.modeling_outputs import BaseModelOutputWithPast
|
||||||
from ipex_llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
|
from ipex_llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
|
||||||
from ipex_llm.transformers.models.utils import init_fp8_kv_cache, append_fp8_kv_cache, \
|
from ipex_llm.transformers.models.utils import init_fp8_kv_cache, append_fp8_kv_cache, \
|
||||||
restore_fp8_kv_cache, use_quantize_kv_cache
|
restore_fp8_kv_cache, use_quantize_kv_cache
|
||||||
from ipex_llm.transformers.models.utils import use_esimd_sdp
|
from ipex_llm.transformers.models.utils import use_sdp
|
||||||
|
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
|
@ -558,25 +558,28 @@ def core_attn_forward_8eb45c(query_layer, key_layer, value_layer, attention_mask
|
||||||
value_layer,
|
value_layer,
|
||||||
is_causal=True).to(key_layer.dtype)
|
is_causal=True).to(key_layer.dtype)
|
||||||
else:
|
else:
|
||||||
if use_esimd_sdp(query_layer.shape[2], key_layer.shape[2],
|
# attention_mask is not None only when past_key_value is not None and q_len > 1
|
||||||
query_layer.shape[-1], query_layer):
|
if attention_mask is not None:
|
||||||
import linear_fp16_esimd
|
attn_bias = torch.zeros(attention_mask.shape, dtype=query_layer.dtype,
|
||||||
attn_output = linear_fp16_esimd.sdp_forward(query_layer,
|
device=query_layer.device)
|
||||||
key_layer,
|
attention_mask = ~attention_mask
|
||||||
value_layer)
|
if attention_mask.dtype == torch.bool:
|
||||||
|
attn_bias.masked_fill_(attention_mask.logical_not(), float("-inf"))
|
||||||
|
else:
|
||||||
|
attn_bias += attention_mask
|
||||||
|
else:
|
||||||
|
attn_bias = None
|
||||||
|
|
||||||
|
if use_sdp(query_layer.shape[2], key_layer.shape[2],
|
||||||
|
query_layer.shape[-1], query_layer):
|
||||||
|
import linear_q4_0
|
||||||
|
attn_output = linear_q4_0.sdp(query_layer, key_layer, value_layer, attn_bias)
|
||||||
context_layer = attn_output.view(query_layer.shape)
|
context_layer = attn_output.view(query_layer.shape)
|
||||||
else:
|
else:
|
||||||
head_dim = query_layer.size(-1)
|
head_dim = query_layer.size(-1)
|
||||||
attn = torch.matmul(query_layer.to(key_layer.dtype),
|
attn = torch.matmul(query_layer.to(key_layer.dtype),
|
||||||
key_layer.transpose(2, 3)) / math.sqrt(head_dim)
|
key_layer.transpose(2, 3)) / math.sqrt(head_dim)
|
||||||
if attention_mask is not None:
|
if attn_bias is not None:
|
||||||
attn_bias = torch.zeros(attention_mask.shape, dtype=query_layer.dtype,
|
|
||||||
device=query_layer.device)
|
|
||||||
attention_mask = ~attention_mask
|
|
||||||
if attention_mask.dtype == torch.bool:
|
|
||||||
attn_bias.masked_fill_(attention_mask.logical_not(), float("-inf"))
|
|
||||||
else:
|
|
||||||
attn_bias += attention_mask
|
|
||||||
attn += attn_bias
|
attn += attn_bias
|
||||||
attn = F.softmax(attn, dim=-1,
|
attn = F.softmax(attn, dim=-1,
|
||||||
dtype=torch.float32).to(value_layer.dtype)
|
dtype=torch.float32).to(value_layer.dtype)
|
||||||
|
|
|
||||||
|
|
@ -41,8 +41,8 @@ from ipex_llm.transformers.models.utils import (
|
||||||
apply_rotary_pos_emb_cache_freq_xpu
|
apply_rotary_pos_emb_cache_freq_xpu
|
||||||
)
|
)
|
||||||
from ipex_llm.transformers.models.utils import mlp_fusion_check, SILU
|
from ipex_llm.transformers.models.utils import mlp_fusion_check, SILU
|
||||||
from ipex_llm.transformers.models.utils import use_sdp, use_sdp_causal, use_quantize_kv_cache
|
from ipex_llm.transformers.models.utils import use_sdp, use_sdp_causal
|
||||||
from ipex_llm.transformers.models.utils import use_sdp_fp8, restore_fp8_kv_cache
|
from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache
|
||||||
from ipex_llm.transformers.kv import DynamicNormalCache, DynamicFp8Cache
|
from ipex_llm.transformers.kv import DynamicNormalCache, DynamicFp8Cache
|
||||||
|
|
||||||
from typing import Optional, Tuple, List
|
from typing import Optional, Tuple, List
|
||||||
|
|
@ -144,7 +144,7 @@ def attention_forward(
|
||||||
attention_mask)
|
attention_mask)
|
||||||
else:
|
else:
|
||||||
attn_output = linear_q4_0.sdp(query_states, key_states, value_states, attention_mask)
|
attn_output = linear_q4_0.sdp(query_states, key_states, value_states, attention_mask)
|
||||||
elif use_sdp_causal(q_len, kv_seq_len, query_states, self.training):
|
elif use_sdp_causal(q_len, kv_seq_len, self.head_dim, query_states, self.training):
|
||||||
import linear_q4_0
|
import linear_q4_0
|
||||||
if isinstance(past_key_value, DynamicFp8Cache):
|
if isinstance(past_key_value, DynamicFp8Cache):
|
||||||
attn_output = linear_q4_0.sdp_fp8_causal(query_states, key_states, value_states)
|
attn_output = linear_q4_0.sdp_fp8_causal(query_states, key_states, value_states)
|
||||||
|
|
|
||||||
|
|
@ -52,7 +52,7 @@ from ipex_llm.transformers.models.utils import is_enough_kv_cache_room_4_36
|
||||||
from transformers.models.qwen2.modeling_qwen2 import apply_rotary_pos_emb
|
from transformers.models.qwen2.modeling_qwen2 import apply_rotary_pos_emb
|
||||||
from ipex_llm.utils.common import invalidInputError
|
from ipex_llm.utils.common import invalidInputError
|
||||||
from ipex_llm.transformers.models.utils import decoding_fast_path_qtype_check
|
from ipex_llm.transformers.models.utils import decoding_fast_path_qtype_check
|
||||||
from ipex_llm.transformers.models.utils import use_flash_attention, use_esimd_sdp
|
from ipex_llm.transformers.models.utils import use_flash_attention
|
||||||
from transformers.models.qwen2_moe.modeling_qwen2_moe import Qwen2MoeModel, apply_rotary_pos_emb
|
from transformers.models.qwen2_moe.modeling_qwen2_moe import Qwen2MoeModel, apply_rotary_pos_emb
|
||||||
from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache
|
from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache
|
||||||
from ipex_llm.transformers.kv import DynamicFp8Cache
|
from ipex_llm.transformers.kv import DynamicFp8Cache
|
||||||
|
|
|
||||||
|
|
@ -318,69 +318,6 @@ def use_flash_attention(query, key, attention_mask=None):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
def use_esimd_sdp(q_len, k_len, head_dim, query_states, attention_mask=None):
|
|
||||||
if head_dim != 128:
|
|
||||||
# esimd_sdp only support head_dim = 128 now
|
|
||||||
return False
|
|
||||||
elif q_len != 1:
|
|
||||||
# esimd_sdp only support rest token and q_len == 1 now
|
|
||||||
return False
|
|
||||||
elif k_len < 8:
|
|
||||||
# esimd_sdp will cause wrong output when k_len < 8
|
|
||||||
return False
|
|
||||||
elif query_states.device.type != "xpu":
|
|
||||||
# esimd_sdp only support GPU now
|
|
||||||
return False
|
|
||||||
elif query_states.dtype != torch.float16:
|
|
||||||
# esimd_sdp only has optimization for FP16 now
|
|
||||||
return False
|
|
||||||
|
|
||||||
device_name = torch.xpu.get_device_name(query_states.device.index)
|
|
||||||
if device_name.startswith("Intel(R) Arc(TM) A") or \
|
|
||||||
device_name.startswith("Intel(R) Data Center GPU Flex") or \
|
|
||||||
device_name.startswith("Intel(R) Data Center GPU Max"):
|
|
||||||
import linear_fp16_esimd
|
|
||||||
if not hasattr(linear_fp16_esimd, "sdp_forward"):
|
|
||||||
return False
|
|
||||||
else:
|
|
||||||
return False
|
|
||||||
|
|
||||||
if query_states.shape[0] > 1 and device_name.startswith("Intel(R) Data Center GPU Max"):
|
|
||||||
# esimd_sdp not support PVC GPU when batch size > 1 for now
|
|
||||||
return False
|
|
||||||
if query_states.shape[0] > 1 and device_name.startswith("Intel(R) Arc(TM) A") \
|
|
||||||
and is_deepspeed_available:
|
|
||||||
# esimd_sdp not support ARC GPU when batch size > 1 using DeepSpeed AutoTP for now
|
|
||||||
return False
|
|
||||||
if query_states.shape[0] > 1 and attention_mask is not None:
|
|
||||||
# for batched input, can't accept attention_mask
|
|
||||||
# TODO: this check needs some time
|
|
||||||
if not torch.all(attention_mask.eq(0)):
|
|
||||||
return False
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
||||||
def use_new_esimd_sdp_fp16(q_len, k_len, head_dim, query_states):
|
|
||||||
if query_states.device.type != "xpu":
|
|
||||||
# esimd_sdp only support GPU now
|
|
||||||
return False
|
|
||||||
elif query_states.dtype != torch.float16:
|
|
||||||
# esimd_sdp only has optimization for FP16 now
|
|
||||||
return False
|
|
||||||
elif head_dim not in [64, 96, 128]:
|
|
||||||
# esimd_sdp only support head_dim = 128 and 64 now
|
|
||||||
return False
|
|
||||||
elif q_len == k_len:
|
|
||||||
# new sdp_fp16 only support rest token now
|
|
||||||
return False
|
|
||||||
elif q_len > 32:
|
|
||||||
# Use new sdp_fp16 only when q_len <= 32
|
|
||||||
return False
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
||||||
def use_sdp(q_len, kv_len, head_dim, query_states):
|
def use_sdp(q_len, kv_len, head_dim, query_states):
|
||||||
return (
|
return (
|
||||||
query_states.device.type == "xpu"
|
query_states.device.type == "xpu"
|
||||||
|
|
@ -400,9 +337,10 @@ def use_sdp_fp8(q_len, kv_len, query_states):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def use_sdp_causal(q_len, kv_len, query_states, training):
|
def use_sdp_causal(q_len, kv_len, head_dim, query_states, training):
|
||||||
return (
|
return (
|
||||||
q_len == kv_len # first token
|
q_len == kv_len # first token
|
||||||
|
and head_dim in [64, 96, 128] # for now
|
||||||
and query_states.device.type == "xpu" # GPU
|
and query_states.device.type == "xpu" # GPU
|
||||||
and query_states.dtype in [torch.float, torch.half] # fp32/fp16
|
and query_states.dtype in [torch.float, torch.half] # fp32/fp16
|
||||||
and not query_states.requires_grad and not training # not training
|
and not query_states.requires_grad and not training # not training
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue