LLM: fix llama2 FP16 & bs>1 & autotp on PVC and ARC (#10611)
This commit is contained in:
parent
654dc5ba57
commit
2bbd8a1548
2 changed files with 17 additions and 12 deletions
|
|
@ -533,12 +533,12 @@ def llama_attention_forward_4_31_original(
|
||||||
self.k_proj.weight.data = self.qkv_proj_weight[1, :, :]
|
self.k_proj.weight.data = self.qkv_proj_weight[1, :, :]
|
||||||
self.v_proj.weight.data = self.qkv_proj_weight[2, :, :]
|
self.v_proj.weight.data = self.qkv_proj_weight[2, :, :]
|
||||||
torch.xpu.empty_cache()
|
torch.xpu.empty_cache()
|
||||||
query_states = torch.empty(bsz, q_len, hidden_size, dtype=hidden_states.dtype,
|
query_states = torch.empty(bsz, q_len, self.qkv_proj_weight.shape[-1],
|
||||||
device=hidden_states.device)
|
dtype=hidden_states.dtype, device=hidden_states.device)
|
||||||
key_states = torch.empty(bsz, q_len, hidden_size, dtype=hidden_states.dtype,
|
key_states = torch.empty(bsz, q_len, self.qkv_proj_weight.shape[-1],
|
||||||
device=hidden_states.device)
|
dtype=hidden_states.dtype, device=hidden_states.device)
|
||||||
value_states = torch.empty(bsz, q_len, hidden_size, dtype=hidden_states.dtype,
|
value_states = torch.empty(bsz, q_len, self.qkv_proj_weight.shape[-1],
|
||||||
device=hidden_states.device)
|
dtype=hidden_states.dtype, device=hidden_states.device)
|
||||||
torch.ops.torch_ipex.mm_qkv_out(
|
torch.ops.torch_ipex.mm_qkv_out(
|
||||||
hidden_states, self.qkv_proj_weight, None,
|
hidden_states, self.qkv_proj_weight, None,
|
||||||
query_states, key_states, value_states
|
query_states, key_states, value_states
|
||||||
|
|
@ -1165,12 +1165,12 @@ def llama_attention_forward_4_36_original(
|
||||||
self.k_proj.weight.data = self.qkv_proj_weight[1, :, :]
|
self.k_proj.weight.data = self.qkv_proj_weight[1, :, :]
|
||||||
self.v_proj.weight.data = self.qkv_proj_weight[2, :, :]
|
self.v_proj.weight.data = self.qkv_proj_weight[2, :, :]
|
||||||
torch.xpu.empty_cache()
|
torch.xpu.empty_cache()
|
||||||
query_states = torch.empty(bsz, q_len, hidden_size, dtype=hidden_states.dtype,
|
query_states = torch.empty(bsz, q_len, self.qkv_proj_weight.shape[-1],
|
||||||
device=hidden_states.device)
|
dtype=hidden_states.dtype, device=hidden_states.device)
|
||||||
key_states = torch.empty(bsz, q_len, hidden_size, dtype=hidden_states.dtype,
|
key_states = torch.empty(bsz, q_len, self.qkv_proj_weight.shape[-1],
|
||||||
device=hidden_states.device)
|
dtype=hidden_states.dtype, device=hidden_states.device)
|
||||||
value_states = torch.empty(bsz, q_len, hidden_size, dtype=hidden_states.dtype,
|
value_states = torch.empty(bsz, q_len, self.qkv_proj_weight.shape[-1],
|
||||||
device=hidden_states.device)
|
dtype=hidden_states.dtype, device=hidden_states.device)
|
||||||
torch.ops.torch_ipex.mm_qkv_out(
|
torch.ops.torch_ipex.mm_qkv_out(
|
||||||
hidden_states, self.qkv_proj_weight, None,
|
hidden_states, self.qkv_proj_weight, None,
|
||||||
query_states, key_states, value_states
|
query_states, key_states, value_states
|
||||||
|
|
|
||||||
|
|
@ -21,6 +21,7 @@ from ipex_llm.utils.common import invalidInputError
|
||||||
from ipex_llm.ggml.quantize import ggml_tensor_qtype
|
from ipex_llm.ggml.quantize import ggml_tensor_qtype
|
||||||
from ipex_llm.transformers.utils import get_ipex_version, get_xpu_device_type
|
from ipex_llm.transformers.utils import get_ipex_version, get_xpu_device_type
|
||||||
from ipex_llm.transformers.low_bit_linear import SYM_INT4, SYM_INT8, FP8E5, IQ2_XXS, FP4, FP8E4
|
from ipex_llm.transformers.low_bit_linear import SYM_INT4, SYM_INT8, FP8E5, IQ2_XXS, FP4, FP8E4
|
||||||
|
from ipex_llm.transformers.convert import is_deepspeed_available
|
||||||
|
|
||||||
FP8_KV_ALLOC_LENGTH = 512
|
FP8_KV_ALLOC_LENGTH = 512
|
||||||
|
|
||||||
|
|
@ -341,6 +342,10 @@ def use_esimd_sdp(q_len, k_len, head_dim, query_states, attention_mask=None):
|
||||||
if query_states.shape[0] > 1 and device_name.startswith("Intel(R) Data Center GPU Max"):
|
if query_states.shape[0] > 1 and device_name.startswith("Intel(R) Data Center GPU Max"):
|
||||||
# esimd_sdp not support PVC GPU when batch size > 1 for now
|
# esimd_sdp not support PVC GPU when batch size > 1 for now
|
||||||
return False
|
return False
|
||||||
|
if query_states.shape[0] > 1 and device_name.startswith("Intel(R) Arc(TM) A") \
|
||||||
|
and is_deepspeed_available:
|
||||||
|
# esimd_sdp not support ARC GPU when batch size > 1 using DeepSpeed AutoTP for now
|
||||||
|
return False
|
||||||
if query_states.shape[0] > 1 and attention_mask is not None:
|
if query_states.shape[0] > 1 and attention_mask is not None:
|
||||||
# for batched input, can't accept attention_mask
|
# for batched input, can't accept attention_mask
|
||||||
# TODO: this check needs some time
|
# TODO: this check needs some time
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue