Use new fp16 sdp in Qwen and modify the constraint (#10882)

Author: Yina Chen
Date: 2024-04-25 19:23:37 +08:00 (committed by GitHub)
Parent: 0213c1c1da
Commit: 8811f268ff
2 changed files with 6 additions and 14 deletions

Changed file 1 of 2:

@@ -42,7 +42,8 @@ from ipex_llm.transformers.models.utils import init_fp8_kv_cache, append_fp8_kv_
 from ipex_llm.transformers.models.utils import rotate_half, SILU
 from ipex_llm.transformers.models.utils import mlp_fusion_check
 from ipex_llm.transformers.models.utils import apply_rotary_pos_emb_cache_freq_xpu
-from ipex_llm.transformers.models.utils import use_flash_attention, use_esimd_sdp
+from ipex_llm.transformers.models.utils import use_flash_attention, use_new_esimd_sdp_fp16, \
+    use_sdp_fp8
 from ipex_llm.transformers.models.utils import use_decoding_fast_path
 from ipex_llm.utils.common import invalidInputError, invalidOperationError
 from ipex_llm.ggml.quantize import ggml_tensor_qtype
@@ -290,11 +291,9 @@ def qwen_attention_forward_original(
         attn_output = attn_output.transpose(1, 2)
         attn_weights = None
     elif not self.training and not hidden_states.requires_grad and \
-            use_esimd_sdp(q_len, key.shape[2], self.head_dim, query):
-        import linear_fp16_esimd
-        attn_output = linear_fp16_esimd.sdp_forward(query,
-                                                    key,
-                                                    value)
+            use_new_esimd_sdp_fp16(q_len, key.shape[2], self.head_dim, query):
+        import linear_q4_0
+        attn_output = linear_q4_0.sdp_fp16(query, key, value, attention_mask)
         attn_output = attn_output.view(query.shape)
         attn_output = attn_output.transpose(1, 2)
         attn_weight = None
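
For context, this hunk replaces the three-argument linear_fp16_esimd.sdp_forward call with a single fused linear_q4_0.sdp_fp16 call that also receives the attention mask. The sketch below mirrors that dispatch; the wrapper name qwen_sdp_dispatch and the PyTorch fallback branch are illustrative and not part of the commit.

import torch.nn.functional as F

from ipex_llm.transformers.models.utils import use_new_esimd_sdp_fp16


def qwen_sdp_dispatch(query, key, value, attention_mask, q_len, head_dim,
                      training=False, requires_grad=False):
    # Hypothetical wrapper mirroring the patched branch in
    # qwen_attention_forward_original; sketch only, not ipex-llm code.
    if not training and not requires_grad and \
            use_new_esimd_sdp_fp16(q_len, key.shape[2], head_dim, query):
        import linear_q4_0  # native XPU kernels shipped with ipex-llm
        # Fused fp16 SDP: one call, attention_mask passed to the kernel
        attn_output = linear_q4_0.sdp_fp16(query, key, value, attention_mask)
        attn_output = attn_output.view(query.shape)
    else:
        # Illustrative fallback: stock PyTorch scaled dot-product attention
        attn_output = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask)
    return attn_output.transpose(1, 2)
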
@@ -485,7 +484,7 @@ def qwen_attention_forward_quantized(
 def core_attn(self, query, key, value, causal_mask=None, attention_mask=None, head_mask=None):
-    if query.size(2) != 1 or query.device.type != 'xpu':
+    if not use_sdp_fp8(query.size(2), key.size(2), query):
         # We have no CPU fp8 matmul implementation for now, so just upscale to fp32
         key, value = restore_fp8_kv_cache(key, value, query.dtype)
         attn_weights = torch.matmul(query, key.transpose(-1, -2))
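
This hunk swaps the inline device/length check for the shared use_sdp_fp8 gate imported at the top of the file. Its body is not part of this diff; a rough stand-in, reconstructed from the old inline condition it replaces, is sketched below. Treat it as an assumption; the real helper in utils.py may apply additional constraints.

def use_sdp_fp8(q_len, k_len, query_states):
    # Rough stand-in for ipex_llm.transformers.models.utils.use_sdp_fp8;
    # k_len is accepted only to match the call site, further k_len-based
    # limits may exist upstream.
    if query_states.device.type != "xpu":
        # There is no CPU fp8 matmul path, so only fuse on XPU
        return False
    if q_len != 1:
        # The old inline check only took the fp8 SDP path for
        # single-token decoding
        return False
    return True
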

Changed file 2 of 2:

@@ -380,13 +380,6 @@ def use_new_esimd_sdp_fp16(q_len, k_len, head_dim, query_states):
         # Use new sdp_fp16 only when q_len <= 32
         return False
-    device_name = torch.xpu.get_device_name(query_states.device.index)
-    if query_states.shape[0] > 1 and device_name.startswith("Intel(R) Arc(TM) A") \
-            and is_deepspeed_available:
-        # It seems there is an issue in DeepSpeed AutoTP when multi-card inference,
-        # Disable new sdp_fp16 for now
-        return False
     return True
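
With the DeepSpeed AutoTP carve-out removed, the constraint reduces to the checks still visible in the hunk. The reconstruction below is only a sketch of the simplified helper; the leading device and dtype checks are assumed from context and are not shown in this diff.

import torch


def use_new_esimd_sdp_fp16(q_len, k_len, head_dim, query_states):
    # Sketch only: the first two checks are assumptions; k_len and head_dim
    # may gate additional constraints upstream that are not shown here.
    if query_states.device.type != "xpu":
        return False  # the fused ESIMD kernel targets Intel XPUs (assumed)
    if query_states.dtype != torch.float16:
        return False  # fp16 kernel (assumed)
    if q_len > 32:
        # Use new sdp_fp16 only when q_len <= 32 (comment kept from the diff)
        return False
    # The Arc multi-card / DeepSpeed AutoTP restriction is removed by this commit
    return True
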