Use sdp when rest token seq_len > 1 in llama & mistral (for lookup & spec) (#10790)
* update sdp condition
* update
* fix
* update & test llama
* mistral
* fix style
* update
* fix style
* remove pvc constraint
* update ds on arc
* fix style
parent 844e18b1db
commit dc27b3bc35

3 changed files with 49 additions and 11 deletions
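During lookup and speculative decoding, the target model verifies several draft tokens in a single rest-token forward pass, so q_len is greater than 1 while the KV cache already holds the earlier context (q_len != k_len). The old esimd SDP gate rejected anything but q_len == 1; this change switches the gates so the fused SDP kernels also cover these short multi-token rest passes. A minimal, self-contained sketch of the shapes and the gating idea (the sizes and variable names below are illustrative, not taken from this diff):

import torch

# Hypothetical speculative-verification step: the prompt (128 tokens) is already
# cached and the draft proposed 4 candidate tokens to verify at once.
bsz, n_heads, head_dim = 1, 32, 128
q_len, kv_len = 4, 128 + 4                     # rest-token seq_len > 1, q_len != k_len

query_states = torch.randn(bsz, n_heads, q_len, head_dim, dtype=torch.float16)
key_states = torch.randn(bsz, n_heads, kv_len, head_dim, dtype=torch.float16)

# Old gate: the fused ESIMD SDP path fired only for a single rest token.
old_eligible = q_len == 1
# New gate (simplified): any short rest-token batch qualifies, not just q_len == 1.
new_eligible = q_len != key_states.shape[2] and q_len <= 32

print(old_eligible, new_eligible)              # False True -> fused SDP now covers this case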
ipex_llm/transformers/models/llama.py

@@ -46,7 +46,8 @@ from ipex_llm.transformers.models.utils import init_fp8_kv_cache, append_fp8_kv_
 from ipex_llm.transformers.models.utils import is_enough_kv_cache_room_4_31, \
     apply_rotary_pos_emb, is_enough_kv_cache_room_4_36
 from ipex_llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
-from ipex_llm.transformers.models.utils import use_flash_attention, use_esimd_sdp
+from ipex_llm.transformers.models.utils import use_flash_attention, use_new_esimd_sdp_fp16, \
+    use_sdp_fp8
 from ipex_llm.transformers.models.utils import mlp_fusion_check, fp16_fusion_check
 from ipex_llm.transformers.models.utils import use_decoding_fast_path
 from transformers.modeling_outputs import BaseModelOutputWithPast
@@ -449,7 +450,7 @@ def llama_attention_forward_4_31_quantized(
         kv_seq_len = key_states.shape[-2]
         past_key_value = (key_states, value_states)

-        if query_states.size(2) != 1 or query_states.device.type != 'xpu':
+        if not use_sdp_fp8(q_len, key_states.shape[2], query_states):
             key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
                                                             query_states.dtype)
             # repeat k/v heads if n_kv_heads < n_heads
@@ -666,7 +667,7 @@ def llama_attention_forward_4_31_original(
                                                      is_causal=True)
         attn_weights = None
     elif not self.training and not hidden_states.requires_grad and \
-            use_esimd_sdp(q_len, key_states.shape[2], self.head_dim, query_states, attention_mask):
+            use_new_esimd_sdp_fp16(q_len, key_states.shape[2], self.head_dim, query_states):
         import linear_q4_0
         attn_output = linear_q4_0.sdp_fp16(query_states, key_states, value_states, attention_mask)
         attn_output = attn_output.view(query_states.shape)
@@ -1074,7 +1075,7 @@ def llama_attention_forward_4_36_quantized(
                                                          self.layer_idx, cache_kwargs,
                                                          new_layout=True)
         kv_seq_len = key_states.shape[-2]
-        if query_states.size(2) != 1 or query_states.device.type != 'xpu':
+        if not use_sdp_fp8(q_len, key_states.shape[2], query_states):
             key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
                                                             query_states.dtype)
             key_states = repeat_kv(key_states, self.num_key_value_groups)\
@@ -1342,7 +1343,7 @@ def llama_attention_forward_4_36_original(
                                                      is_causal=True)
         attn_weights = None
     elif not self.training and not hidden_states.requires_grad and \
-            use_esimd_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
+            use_new_esimd_sdp_fp16(q_len, key_states.shape[2], self.head_dim, query_states):
         import linear_q4_0
         attn_output = linear_q4_0.sdp_fp16(query_states, key_states, value_states, attention_mask)
         attn_output = attn_output.view(query_states.shape)
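The llama changes above keep the dispatch shape unchanged: when the new gate passes, attention goes through the fused linear_q4_0.sdp_fp16 kernel; otherwise the existing attention code runs. Below is a hedged sketch of that pattern only; the fallback uses PyTorch's built-in scaled_dot_product_attention as a stand-in because the repo's actual fallback path is not part of this diff, and linear_q4_0 is an XPU-only extension that may not be installed:

import torch
import torch.nn.functional as F

def sdp_dispatch(query_states, key_states, value_states, attention_mask, gate_passes):
    # Illustrative dispatch only: mirrors the branch shape in this diff.
    if gate_passes:
        try:
            import linear_q4_0  # XPU-only extension shipped with ipex-llm
            return linear_q4_0.sdp_fp16(query_states, key_states,
                                        value_states, attention_mask)
        except ImportError:
            pass  # extension unavailable on this machine; fall through
    # Stand-in for the regular attention path (not the repo's actual code).
    return F.scaled_dot_product_attention(query_states, key_states, value_states,
                                          attn_mask=attention_mask)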
ipex_llm/transformers/models/mistral.py

@@ -52,7 +52,8 @@ from ipex_llm.transformers.models.utils import apply_rotary_pos_emb, \
 from ipex_llm.transformers.models.utils import is_enough_kv_cache_room_4_31, \
     is_enough_kv_cache_room_4_36
 from ipex_llm.transformers.low_bit_linear import SYM_INT4, FP8E5, IQ2_XXS
-from ipex_llm.transformers.models.utils import use_flash_attention, use_esimd_sdp
+from ipex_llm.transformers.models.utils import use_flash_attention, use_new_esimd_sdp_fp16, \
+    use_sdp_fp8
 from ipex_llm.transformers.models.utils import use_decoding_fast_path
 from ipex_llm.transformers.models.llama import llama_decoding_fast_path_qtype_check
 from ipex_llm.transformers.models.llama import should_use_xetla_mm_qkv
@@ -310,7 +311,7 @@ def mistral_attention_forward_quantized(
         kv_seq_len = key_states.shape[-2]
         past_key_value = (key_states, value_states)

-        if query_states.size(2) != 1 or query_states.device.type != 'xpu':
+        if not use_sdp_fp8(q_len, key_states.shape[2], query_states):
             key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
                                                             query_states.dtype)
             attn_weights = torch.matmul(query_states, key_states.transpose(2, 3))
@@ -503,7 +504,7 @@ def mistral_attention_forward_original(
         attn_weights = None
         attn_output = attn_output.transpose(1, 2).contiguous()
         attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
-    elif use_esimd_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
+    elif use_new_esimd_sdp_fp16(q_len, key_states.shape[2], self.head_dim, query_states):
         # new fp16 sdp doesn't require repeat_kv
         import linear_q4_0
         attn_output = linear_q4_0.sdp_fp16(query_states, key_states, value_states, attention_mask)
@@ -687,7 +688,7 @@ def mistral_attention_forward_4_36_quantized(
                                                          self.layer_idx, cache_kwargs,
                                                          new_layout=True)
         kv_seq_len = key_states.shape[-2]
-        if query_states.size(2) != 1 or query_states.device.type != 'xpu':
+        if not use_sdp_fp8(q_len, key_states.shape[2], query_states):
             key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
                                                             query_states.dtype)
             attn_weights = torch.matmul(query_states, key_states.transpose(2, 3))
@@ -896,7 +897,7 @@ def mistral_attention_forward_4_36_original(
         attn_weights = None
         attn_output = attn_output.transpose(1, 2).contiguous()
         attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
-    elif use_esimd_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
+    elif use_new_esimd_sdp_fp16(q_len, key_states.shape[2], self.head_dim, query_states):
         # new fp16 sdp doesn't require repeat_kv
         import linear_q4_0
         attn_output = linear_q4_0.sdp_fp16(query_states, key_states, value_states, attention_mask)
ipex_llm/transformers/models/utils.py

@@ -325,7 +325,7 @@ def use_esimd_sdp(q_len, k_len, head_dim, query_states, attention_mask=None):
         # esimd_sdp only support head_dim = 128 now
         return False
     elif q_len != 1:
-        # esimd_sdp only support rest token now
+        # esimd_sdp only support rest token and q_len == 1 now
         return False
     elif k_len < 8:
         # esimd_sdp will cause wrong output when k_len < 8
@@ -363,6 +363,42 @@ def use_esimd_sdp(q_len, k_len, head_dim, query_states, attention_mask=None):
     return True


+def use_new_esimd_sdp_fp16(q_len, k_len, head_dim, query_states):
+    if query_states.device.type != "xpu":
+        # esimd_sdp only support GPU now
+        return False
+    elif query_states.dtype != torch.float16:
+        # esimd_sdp only has optimization for FP16 now
+        return False
+    elif head_dim != 128 and head_dim != 64:
+        # esimd_sdp only support head_dim = 128 and 64 now
+        return False
+    elif q_len == k_len:
+        # new sdp_fp16 only support rest token now
+        return False
+    elif q_len > 32:
+        # Use new sdp_fp16 only when q_len <= 32
+        return False
+
+    device_name = torch.xpu.get_device_name(query_states.device.index)
+    if query_states.shape[0] > 1 and device_name.startswith("Intel(R) Arc(TM) A") \
+            and is_deepspeed_available:
+        # It seems there is an issue in DeepSpeed AutoTP when multi-card inference,
+        # Disable new sdp_fp16 for now
+        return False
+
+    return True
+
+
+def use_sdp_fp8(q_len, k_len, query_states):
+    if query_states.device.type != "xpu":
+        return False
+    if q_len == k_len:
+        # sdp_fp8 only support rest token now
+        return False
+    return True
+
+
 def mlp_fusion_check(x, qtype, training):
     invalidInputError(x.dim() == 2,
                       "Here input x's dim should be 2.")
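The two new predicates above carry the whole decision: use_sdp_fp8 gates the quantized (FP8 KV cache) attention paths and use_new_esimd_sdp_fp16 gates the FP16 paths. A small usage sketch, assuming the helpers are importable from ipex_llm.transformers.models.utils after this change; on a machine without an XPU both checks short-circuit to False at the device test, and the expected True results in the comments apply to a single-card XPU run:

import torch
from ipex_llm.transformers.models.utils import use_new_esimd_sdp_fp16, use_sdp_fp8

q = torch.randn(1, 32, 4, 128, dtype=torch.float16)     # 4 rest tokens, head_dim 128
if hasattr(torch, "xpu") and torch.xpu.is_available():  # needs intel_extension_for_pytorch
    q = q.to("xpu")

q_len, k_len = q.shape[2], 132                           # KV cache longer than the query
print(use_sdp_fp8(q_len, k_len, q))                      # True on XPU (rest tokens), else False
print(use_new_esimd_sdp_fp16(q_len, k_len, q.shape[-1], q))  # True on XPU when q_len <= 32
print(use_sdp_fp8(k_len, k_len, q))                      # False: q_len == k_len is the first-token pass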