Use new fp16 sdp in Qwen and modify the constraint (#10882)
parent 0213c1c1da
commit 8811f268ff

2 changed files with 6 additions and 14 deletions
@@ -42,7 +42,8 @@ from ipex_llm.transformers.models.utils import init_fp8_kv_cache, append_fp8_kv_
 from ipex_llm.transformers.models.utils import rotate_half, SILU
 from ipex_llm.transformers.models.utils import mlp_fusion_check
 from ipex_llm.transformers.models.utils import apply_rotary_pos_emb_cache_freq_xpu
-from ipex_llm.transformers.models.utils import use_flash_attention, use_esimd_sdp
+from ipex_llm.transformers.models.utils import use_flash_attention, use_new_esimd_sdp_fp16, \
+    use_sdp_fp8
 from ipex_llm.transformers.models.utils import use_decoding_fast_path
 from ipex_llm.utils.common import invalidInputError, invalidOperationError
 from ipex_llm.ggml.quantize import ggml_tensor_qtype
@@ -290,11 +291,9 @@ def qwen_attention_forward_original(
         attn_output = attn_output.transpose(1, 2)
         attn_weights = None
     elif not self.training and not hidden_states.requires_grad and \
-            use_esimd_sdp(q_len, key.shape[2], self.head_dim, query):
-        import linear_fp16_esimd
-        attn_output = linear_fp16_esimd.sdp_forward(query,
-                                                    key,
-                                                    value)
+            use_new_esimd_sdp_fp16(q_len, key.shape[2], self.head_dim, query):
+        import linear_q4_0
+        attn_output = linear_q4_0.sdp_fp16(query, key, value, attention_mask)
         attn_output = attn_output.view(query.shape)
         attn_output = attn_output.transpose(1, 2)
         attn_weight = None
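For context on the hunk above: the old code imported linear_fp16_esimd and called its sdp_forward(query, key, value), while the new code calls the fused fp16 SDP kernel linear_q4_0.sdp_fp16(query, key, value, attention_mask), which also receives the attention mask. Purely as an illustrative reference and not the kernel source, a plain PyTorch scaled-dot-product attention with an assumed [batch, num_heads, seq_len, head_dim] layout and 1/sqrt(head_dim) scaling would look like:

import math
import torch

def sdp_fp16_reference(query, key, value, attention_mask=None):
    # Illustrative sketch only; the fused XPU kernel linear_q4_0.sdp_fp16
    # performs these steps in a single launch and may differ in scaling and
    # mask handling. Layout assumed: [batch, num_heads, seq_len, head_dim].
    scores = torch.matmul(query, key.transpose(-1, -2)) / math.sqrt(query.size(-1))
    if attention_mask is not None:
        scores = scores + attention_mask  # additive mask assumed
    weights = torch.softmax(scores, dim=-1, dtype=torch.float32).to(query.dtype)
    return torch.matmul(weights, value)

Since the hunk reshapes the result with attn_output.view(query.shape), the surrounding view/transpose bookkeeping is left unchanged by the swap.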
@@ -485,7 +484,7 @@ def qwen_attention_forward_quantized(
 
 
 def core_attn(self, query, key, value, causal_mask=None, attention_mask=None, head_mask=None):
-    if query.size(2) != 1 or query.device.type != 'xpu':
+    if not use_sdp_fp8(query.size(2), key.size(2), query):
         # We have no CPU fp8 matmul implementation for now, so just upscale to fp32
         key, value = restore_fp8_kv_cache(key, value, query.dtype)
         attn_weights = torch.matmul(query, key.transpose(-1, -2))
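The hunk above swaps the inline gate query.size(2) != 1 or query.device.type != 'xpu' for the use_sdp_fp8 helper imported in the first hunk. As a hypothetical sketch only (the real helper lives in the utils module and may apply further constraints), a gate that simply inverts the replaced condition would be:

def use_sdp_fp8_sketch(q_len, k_len, query_states):
    # Hypothetical illustration; mirrors the inline check this commit replaces.
    # k_len is accepted for parity with the call site but unused in this sketch.
    if query_states.device.type != "xpu":
        return False  # no fused fp8 SDP path off the XPU
    if q_len != 1:
        return False  # the fp8 SDP path previously targeted single-token decoding
    return True

Centralizing the decision mirrors the fp16 path, which already delegates its gating to use_new_esimd_sdp_fp16.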
@@ -380,13 +380,6 @@ def use_new_esimd_sdp_fp16(q_len, k_len, head_dim, query_states):
         # Use new sdp_fp16 only when q_len <= 32
         return False
 
-    device_name = torch.xpu.get_device_name(query_states.device.index)
-    if query_states.shape[0] > 1 and device_name.startswith("Intel(R) Arc(TM) A") \
-            and is_deepspeed_available:
-        # It seems there is an issue in DeepSpeed AutoTP when multi-card inference,
-        # Disable new sdp_fp16 for now
-        return False
-
     return True
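The final hunk edits use_new_esimd_sdp_fp16 in the utils module that the Qwen file now imports: the q_len <= 32 constraint is kept, while the special case that disabled the new sdp_fp16 kernel for multi-batch runs on Intel Arc A-series GPUs under DeepSpeed AutoTP is dropped. A simplified sketch of the helper's shape after this commit, with the earlier device and dtype checks assumed because they fall outside the hunk:

import torch

def use_new_esimd_sdp_fp16_sketch(q_len, k_len, head_dim, query_states):
    # Assumed preconditions not shown in the diff: an fp16 query on an XPU.
    # k_len and head_dim are kept only for signature parity with the call site.
    if query_states.device.type != "xpu" or query_states.dtype != torch.float16:
        return False
    if q_len > 32:
        # Use new sdp_fp16 only when q_len <= 32 (comment kept from the source)
        return False
    # The Arc A-series + DeepSpeed AutoTP exclusion removed by this commit is
    # gone, so anything passing the checks above now takes the new kernel.
    return True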