LLM: fix llama2 FP16 & bs>1 & autotp on PVC and ARC (#10611)
This commit is contained in:
parent
654dc5ba57
commit
2bbd8a1548
2 changed files with 17 additions and 12 deletions
|
|
@ -533,12 +533,12 @@ def llama_attention_forward_4_31_original(
|
|||
self.k_proj.weight.data = self.qkv_proj_weight[1, :, :]
|
||||
self.v_proj.weight.data = self.qkv_proj_weight[2, :, :]
|
||||
torch.xpu.empty_cache()
|
||||
query_states = torch.empty(bsz, q_len, hidden_size, dtype=hidden_states.dtype,
|
||||
device=hidden_states.device)
|
||||
key_states = torch.empty(bsz, q_len, hidden_size, dtype=hidden_states.dtype,
|
||||
device=hidden_states.device)
|
||||
value_states = torch.empty(bsz, q_len, hidden_size, dtype=hidden_states.dtype,
|
||||
device=hidden_states.device)
|
||||
query_states = torch.empty(bsz, q_len, self.qkv_proj_weight.shape[-1],
|
||||
dtype=hidden_states.dtype, device=hidden_states.device)
|
||||
key_states = torch.empty(bsz, q_len, self.qkv_proj_weight.shape[-1],
|
||||
dtype=hidden_states.dtype, device=hidden_states.device)
|
||||
value_states = torch.empty(bsz, q_len, self.qkv_proj_weight.shape[-1],
|
||||
dtype=hidden_states.dtype, device=hidden_states.device)
|
||||
torch.ops.torch_ipex.mm_qkv_out(
|
||||
hidden_states, self.qkv_proj_weight, None,
|
||||
query_states, key_states, value_states
|
||||
|
|
@ -1165,12 +1165,12 @@ def llama_attention_forward_4_36_original(
|
|||
self.k_proj.weight.data = self.qkv_proj_weight[1, :, :]
|
||||
self.v_proj.weight.data = self.qkv_proj_weight[2, :, :]
|
||||
torch.xpu.empty_cache()
|
||||
query_states = torch.empty(bsz, q_len, hidden_size, dtype=hidden_states.dtype,
|
||||
device=hidden_states.device)
|
||||
key_states = torch.empty(bsz, q_len, hidden_size, dtype=hidden_states.dtype,
|
||||
device=hidden_states.device)
|
||||
value_states = torch.empty(bsz, q_len, hidden_size, dtype=hidden_states.dtype,
|
||||
device=hidden_states.device)
|
||||
query_states = torch.empty(bsz, q_len, self.qkv_proj_weight.shape[-1],
|
||||
dtype=hidden_states.dtype, device=hidden_states.device)
|
||||
key_states = torch.empty(bsz, q_len, self.qkv_proj_weight.shape[-1],
|
||||
dtype=hidden_states.dtype, device=hidden_states.device)
|
||||
value_states = torch.empty(bsz, q_len, self.qkv_proj_weight.shape[-1],
|
||||
dtype=hidden_states.dtype, device=hidden_states.device)
|
||||
torch.ops.torch_ipex.mm_qkv_out(
|
||||
hidden_states, self.qkv_proj_weight, None,
|
||||
query_states, key_states, value_states
|
||||
|
|
|
|||
|
|
@ -21,6 +21,7 @@ from ipex_llm.utils.common import invalidInputError
|
|||
from ipex_llm.ggml.quantize import ggml_tensor_qtype
|
||||
from ipex_llm.transformers.utils import get_ipex_version, get_xpu_device_type
|
||||
from ipex_llm.transformers.low_bit_linear import SYM_INT4, SYM_INT8, FP8E5, IQ2_XXS, FP4, FP8E4
|
||||
from ipex_llm.transformers.convert import is_deepspeed_available
|
||||
|
||||
FP8_KV_ALLOC_LENGTH = 512
|
||||
|
||||
|
|
@ -341,6 +342,10 @@ def use_esimd_sdp(q_len, k_len, head_dim, query_states, attention_mask=None):
|
|||
if query_states.shape[0] > 1 and device_name.startswith("Intel(R) Data Center GPU Max"):
|
||||
# esimd_sdp not support PVC GPU when batch size > 1 for now
|
||||
return False
|
||||
if query_states.shape[0] > 1 and device_name.startswith("Intel(R) Arc(TM) A") \
|
||||
and is_deepspeed_available:
|
||||
# esimd_sdp not support ARC GPU when batch size > 1 using DeepSpeed AutoTP for now
|
||||
return False
|
||||
if query_states.shape[0] > 1 and attention_mask is not None:
|
||||
# for batched input, can't accept attention_mask
|
||||
# TODO: this check needs some time
|
||||
|
|
|
|||
Loading…
Reference in a new issue