LLM: fix llama2 FP16 & bs>1 & autotp on PVC and ARC (#10611)

Author: binbin Deng, 2024-04-03 09:28:04 +08:00 (committed by GitHub)
parent 654dc5ba57
commit 2bbd8a1548
2 changed files with 17 additions and 12 deletions
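For context (not part of the commit), the failing scenario is llama2 loaded by ipex-llm in FP16 and run with batch size > 1 under DeepSpeed AutoTP on a PVC or ARC GPU. A minimal, illustrative sketch of that setup follows; the model name, launcher flow, and argument choices are assumptions rather than something taken from this commit (see the project's Deepspeed-AutoTP example for the supported recipe).

# Hypothetical repro sketch: llama2 in FP16, batch size > 1, DeepSpeed AutoTP on Intel GPUs.
import os
import torch
import deepspeed
from transformers import AutoTokenizer
from ipex_llm.transformers import AutoModelForCausalLM

world_size = int(os.environ.get("WORLD_SIZE", "1"))
local_rank = int(os.environ.get("LOCAL_RANK", "0"))

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf",
                                             load_in_low_bit="fp16",
                                             torch_dtype=torch.float16)
# AutoTP: shard the model across ranks, then move the local shard to XPU
model = deepspeed.init_inference(model, mp_size=world_size,
                                 dtype=torch.float16,
                                 replace_with_kernel_inject=False)
model = model.module.to(f"xpu:{local_rank}")

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"
prompts = ["What is AI?", "Tell me a joke."]  # batch size > 1
inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(f"xpu:{local_rank}")
output = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.batch_decode(output, skip_special_tokens=True))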

@@ -533,12 +533,12 @@ def llama_attention_forward_4_31_original(
                 self.k_proj.weight.data = self.qkv_proj_weight[1, :, :]
                 self.v_proj.weight.data = self.qkv_proj_weight[2, :, :]
                 torch.xpu.empty_cache()
-            query_states = torch.empty(bsz, q_len, hidden_size, dtype=hidden_states.dtype,
-                                       device=hidden_states.device)
-            key_states = torch.empty(bsz, q_len, hidden_size, dtype=hidden_states.dtype,
-                                     device=hidden_states.device)
-            value_states = torch.empty(bsz, q_len, hidden_size, dtype=hidden_states.dtype,
-                                       device=hidden_states.device)
+            query_states = torch.empty(bsz, q_len, self.qkv_proj_weight.shape[-1],
+                                       dtype=hidden_states.dtype, device=hidden_states.device)
+            key_states = torch.empty(bsz, q_len, self.qkv_proj_weight.shape[-1],
+                                     dtype=hidden_states.dtype, device=hidden_states.device)
+            value_states = torch.empty(bsz, q_len, self.qkv_proj_weight.shape[-1],
+                                       dtype=hidden_states.dtype, device=hidden_states.device)
             torch.ops.torch_ipex.mm_qkv_out(
                 hidden_states, self.qkv_proj_weight, None,
                 query_states, key_states, value_states
@@ -1165,12 +1165,12 @@ def llama_attention_forward_4_36_original(
                 self.k_proj.weight.data = self.qkv_proj_weight[1, :, :]
                 self.v_proj.weight.data = self.qkv_proj_weight[2, :, :]
                 torch.xpu.empty_cache()
-            query_states = torch.empty(bsz, q_len, hidden_size, dtype=hidden_states.dtype,
-                                       device=hidden_states.device)
-            key_states = torch.empty(bsz, q_len, hidden_size, dtype=hidden_states.dtype,
-                                     device=hidden_states.device)
-            value_states = torch.empty(bsz, q_len, hidden_size, dtype=hidden_states.dtype,
-                                       device=hidden_states.device)
+            query_states = torch.empty(bsz, q_len, self.qkv_proj_weight.shape[-1],
+                                       dtype=hidden_states.dtype, device=hidden_states.device)
+            key_states = torch.empty(bsz, q_len, self.qkv_proj_weight.shape[-1],
+                                     dtype=hidden_states.dtype, device=hidden_states.device)
+            value_states = torch.empty(bsz, q_len, self.qkv_proj_weight.shape[-1],
+                                       dtype=hidden_states.dtype, device=hidden_states.device)
             torch.ops.torch_ipex.mm_qkv_out(
                 hidden_states, self.qkv_proj_weight, None,
                 query_states, key_states, value_states

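The two hunks above size the fused Q/K/V output buffers from the weight's own last dimension rather than from hidden_size. Under DeepSpeed AutoTP each rank holds only a shard of q_proj/k_proj/v_proj, so the per-rank output width is smaller than hidden_size, and buffers allocated with hidden_size no longer match what mm_qkv_out writes. Below is a minimal sketch of the shape reasoning; the [3, in_features, per-rank out_features] layout of the fused weight and the concrete numbers are assumptions for illustration only.

# Hypothetical numbers: llama2-7b sharded over 2 ranks by AutoTP.
import torch

hidden_size, tp_size = 4096, 2
bsz, q_len = 2, 32

# Assumed fused layout: [3, in_features, per-rank out_features].
qkv_proj_weight = torch.empty(3, hidden_size, hidden_size // tp_size)

# Old sizing: uses the full hidden_size even though each rank only produces a shard.
old_query_states = torch.empty(bsz, q_len, hidden_size)
# New sizing: follows the weight itself, so it stays correct with or without AutoTP.
new_query_states = torch.empty(bsz, q_len, qkv_proj_weight.shape[-1])

assert new_query_states.shape[-1] == hidden_size // tp_size
assert old_query_states.shape[-1] != new_query_states.shape[-1]  # the AutoTP shape mismatch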

@@ -21,6 +21,7 @@ from ipex_llm.utils.common import invalidInputError
 from ipex_llm.ggml.quantize import ggml_tensor_qtype
 from ipex_llm.transformers.utils import get_ipex_version, get_xpu_device_type
 from ipex_llm.transformers.low_bit_linear import SYM_INT4, SYM_INT8, FP8E5, IQ2_XXS, FP4, FP8E4
+from ipex_llm.transformers.convert import is_deepspeed_available

 FP8_KV_ALLOC_LENGTH = 512
@@ -341,6 +342,10 @@ def use_esimd_sdp(q_len, k_len, head_dim, query_states, attention_mask=None):
     if query_states.shape[0] > 1 and device_name.startswith("Intel(R) Data Center GPU Max"):
         # esimd_sdp not support PVC GPU when batch size > 1 for now
         return False
+    if query_states.shape[0] > 1 and device_name.startswith("Intel(R) Arc(TM) A") \
+            and is_deepspeed_available:
+        # esimd_sdp not support ARC GPU when batch size > 1 using DeepSpeed AutoTP for now
+        return False
     if query_states.shape[0] > 1 and attention_mask is not None:
         # for batched input, can't accept attention_mask
         # TODO: this check needs some time
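The last hunk adds one more early-out to the ESIMD SDP gate: on ARC (Intel(R) Arc(TM) A-series) with batch size > 1, the fused kernel is skipped whenever DeepSpeed is present, i.e. when AutoTP may be in use, and the regular attention path is taken instead. A minimal sketch of that gating pattern is below; apart from torch.xpu.get_device_name, the names are illustrative stand-ins rather than the library's API.

# Illustrative gating sketch: decide per call whether the fused SDP kernel is safe.
import importlib.util
import torch

def _deepspeed_present() -> bool:
    # stand-in for ipex-llm's is_deepspeed_available import
    return importlib.util.find_spec("deepspeed") is not None

def can_use_fused_sdp(query_states: torch.Tensor) -> bool:
    device_name = torch.xpu.get_device_name(query_states.device.index)
    batch_size = query_states.shape[0]
    if batch_size > 1 and device_name.startswith("Intel(R) Data Center GPU Max"):
        return False  # PVC: fused kernel not ready for batch size > 1
    if batch_size > 1 and device_name.startswith("Intel(R) Arc(TM) A") and _deepspeed_present():
        return False  # ARC under DeepSpeed AutoTP: fall back for batch size > 1
    return True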