Use new fp16 sdp in Qwen and modify the constraint (#10882)
This commit is contained in:
parent
0213c1c1da
commit
8811f268ff
2 changed files with 6 additions and 14 deletions
|
|
@ -42,7 +42,8 @@ from ipex_llm.transformers.models.utils import init_fp8_kv_cache, append_fp8_kv_
|
|||
from ipex_llm.transformers.models.utils import rotate_half, SILU
|
||||
from ipex_llm.transformers.models.utils import mlp_fusion_check
|
||||
from ipex_llm.transformers.models.utils import apply_rotary_pos_emb_cache_freq_xpu
|
||||
from ipex_llm.transformers.models.utils import use_flash_attention, use_esimd_sdp
|
||||
from ipex_llm.transformers.models.utils import use_flash_attention, use_new_esimd_sdp_fp16, \
|
||||
use_sdp_fp8
|
||||
from ipex_llm.transformers.models.utils import use_decoding_fast_path
|
||||
from ipex_llm.utils.common import invalidInputError, invalidOperationError
|
||||
from ipex_llm.ggml.quantize import ggml_tensor_qtype
|
||||
|
|
@ -290,11 +291,9 @@ def qwen_attention_forward_original(
|
|||
attn_output = attn_output.transpose(1, 2)
|
||||
attn_weights = None
|
||||
elif not self.training and not hidden_states.requires_grad and \
|
||||
use_esimd_sdp(q_len, key.shape[2], self.head_dim, query):
|
||||
import linear_fp16_esimd
|
||||
attn_output = linear_fp16_esimd.sdp_forward(query,
|
||||
key,
|
||||
value)
|
||||
use_new_esimd_sdp_fp16(q_len, key.shape[2], self.head_dim, query):
|
||||
import linear_q4_0
|
||||
attn_output = linear_q4_0.sdp_fp16(query, key, value, attention_mask)
|
||||
attn_output = attn_output.view(query.shape)
|
||||
attn_output = attn_output.transpose(1, 2)
|
||||
attn_weight = None
|
||||
|
|
@ -485,7 +484,7 @@ def qwen_attention_forward_quantized(
|
|||
|
||||
|
||||
def core_attn(self, query, key, value, causal_mask=None, attention_mask=None, head_mask=None):
|
||||
if query.size(2) != 1 or query.device.type != 'xpu':
|
||||
if not use_sdp_fp8(query.size(2), key.size(2), query):
|
||||
# We have no CPU fp8 matmul implementation for now, so just upscale to fp32
|
||||
key, value = restore_fp8_kv_cache(key, value, query.dtype)
|
||||
attn_weights = torch.matmul(query, key.transpose(-1, -2))
|
||||
|
|
|
|||
|
|
@ -380,13 +380,6 @@ def use_new_esimd_sdp_fp16(q_len, k_len, head_dim, query_states):
|
|||
# Use new sdp_fp16 only when q_len <= 32
|
||||
return False
|
||||
|
||||
device_name = torch.xpu.get_device_name(query_states.device.index)
|
||||
if query_states.shape[0] > 1 and device_name.startswith("Intel(R) Arc(TM) A") \
|
||||
and is_deepspeed_available:
|
||||
# It seems there is an issue in DeepSpeed AutoTP when multi-card inference,
|
||||
# Disable new sdp_fp16 for now
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue