From 2bbd8a15480d1822a78eaec95e5f753ff7144a56 Mon Sep 17 00:00:00 2001
From: binbin Deng <108676127+plusbang@users.noreply.github.com>
Date: Wed, 3 Apr 2024 09:28:04 +0800
Subject: [PATCH] LLM: fix llama2 FP16 & bs>1 & autotp on PVC and ARC (#10611)

---
 .../src/ipex_llm/transformers/models/llama.py | 24 +++++++++----------
 .../src/ipex_llm/transformers/models/utils.py |  5 ++++
 2 files changed, 17 insertions(+), 12 deletions(-)

diff --git a/python/llm/src/ipex_llm/transformers/models/llama.py b/python/llm/src/ipex_llm/transformers/models/llama.py
index 25330fe0..a4044e1a 100644
--- a/python/llm/src/ipex_llm/transformers/models/llama.py
+++ b/python/llm/src/ipex_llm/transformers/models/llama.py
@@ -533,12 +533,12 @@ def llama_attention_forward_4_31_original(
                 self.k_proj.weight.data = self.qkv_proj_weight[1, :, :]
                 self.v_proj.weight.data = self.qkv_proj_weight[2, :, :]
                 torch.xpu.empty_cache()
-            query_states = torch.empty(bsz, q_len, hidden_size, dtype=hidden_states.dtype,
-                                       device=hidden_states.device)
-            key_states = torch.empty(bsz, q_len, hidden_size, dtype=hidden_states.dtype,
-                                     device=hidden_states.device)
-            value_states = torch.empty(bsz, q_len, hidden_size, dtype=hidden_states.dtype,
-                                       device=hidden_states.device)
+            query_states = torch.empty(bsz, q_len, self.qkv_proj_weight.shape[-1],
+                                       dtype=hidden_states.dtype, device=hidden_states.device)
+            key_states = torch.empty(bsz, q_len, self.qkv_proj_weight.shape[-1],
+                                     dtype=hidden_states.dtype, device=hidden_states.device)
+            value_states = torch.empty(bsz, q_len, self.qkv_proj_weight.shape[-1],
+                                       dtype=hidden_states.dtype, device=hidden_states.device)
             torch.ops.torch_ipex.mm_qkv_out(
                 hidden_states, self.qkv_proj_weight, None,
                 query_states, key_states, value_states
@@ -1165,12 +1165,12 @@ def llama_attention_forward_4_36_original(
                 self.k_proj.weight.data = self.qkv_proj_weight[1, :, :]
                 self.v_proj.weight.data = self.qkv_proj_weight[2, :, :]
                 torch.xpu.empty_cache()
-            query_states = torch.empty(bsz, q_len, hidden_size, dtype=hidden_states.dtype,
-                                       device=hidden_states.device)
-            key_states = torch.empty(bsz, q_len, hidden_size, dtype=hidden_states.dtype,
-                                     device=hidden_states.device)
-            value_states = torch.empty(bsz, q_len, hidden_size, dtype=hidden_states.dtype,
-                                       device=hidden_states.device)
+            query_states = torch.empty(bsz, q_len, self.qkv_proj_weight.shape[-1],
+                                       dtype=hidden_states.dtype, device=hidden_states.device)
+            key_states = torch.empty(bsz, q_len, self.qkv_proj_weight.shape[-1],
+                                     dtype=hidden_states.dtype, device=hidden_states.device)
+            value_states = torch.empty(bsz, q_len, self.qkv_proj_weight.shape[-1],
+                                       dtype=hidden_states.dtype, device=hidden_states.device)
             torch.ops.torch_ipex.mm_qkv_out(
                 hidden_states, self.qkv_proj_weight, None,
                 query_states, key_states, value_states
diff --git a/python/llm/src/ipex_llm/transformers/models/utils.py b/python/llm/src/ipex_llm/transformers/models/utils.py
index 3fa48489..405fdfd0 100644
--- a/python/llm/src/ipex_llm/transformers/models/utils.py
+++ b/python/llm/src/ipex_llm/transformers/models/utils.py
@@ -21,6 +21,7 @@ from ipex_llm.utils.common import invalidInputError
 from ipex_llm.ggml.quantize import ggml_tensor_qtype
 from ipex_llm.transformers.utils import get_ipex_version, get_xpu_device_type
 from ipex_llm.transformers.low_bit_linear import SYM_INT4, SYM_INT8, FP8E5, IQ2_XXS, FP4, FP8E4
+from ipex_llm.transformers.convert import is_deepspeed_available
 
 FP8_KV_ALLOC_LENGTH = 512
 
@@ -341,6 +342,10 @@ def use_esimd_sdp(q_len, k_len, head_dim, query_states, attention_mask=None):
     if query_states.shape[0] > 1 and device_name.startswith("Intel(R) Data Center GPU Max"):
         # esimd_sdp not support PVC GPU when batch size > 1 for now
         return False
+    if query_states.shape[0] > 1 and device_name.startswith("Intel(R) Arc(TM) A") \
+            and is_deepspeed_available:
+        # esimd_sdp not support ARC GPU when batch size > 1 using DeepSpeed AutoTP for now
+        return False
     if query_states.shape[0] > 1 and attention_mask is not None:
         # for batched input, can't accept attention_mask
         # TODO: this check needs some time
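The llama.py hunks size the preallocated QKV output buffers from the fused weight's last
dimension instead of the model-level hidden_size. The sketch below illustrates why that
matters under DeepSpeed AutoTP: each rank holds only a shard of the fused QKV weight, so the
per-rank projection width is smaller than hidden_size and a hidden_size-wide buffer would not
match the fused matmul output. The world_size value, the (3, in_features, out_features) shard
layout, and the plain torch.matmul standing in for torch.ops.torch_ipex.mm_qkv_out are
illustrative assumptions, not the ipex_llm implementation.

    # Illustrative only: plain torch.matmul stands in for torch.ops.torch_ipex.mm_qkv_out,
    # and the (3, in_features, out_features) layout of the fused weight is an assumption.
    import torch

    bsz, q_len, hidden_size, world_size = 2, 16, 4096, 2
    per_rank_width = hidden_size // world_size        # this rank's shard width under AutoTP

    hidden_states = torch.randn(bsz, q_len, hidden_size)
    qkv_proj_weight = torch.randn(3, hidden_size, per_rank_width)  # stand-in for self.qkv_proj_weight

    # size the preallocated outputs from the weight's last dim, as the patch does
    query_states = torch.empty(bsz, q_len, qkv_proj_weight.shape[-1])
    key_states = torch.empty(bsz, q_len, qkv_proj_weight.shape[-1])
    value_states = torch.empty(bsz, q_len, qkv_proj_weight.shape[-1])

    for out, w in zip((query_states, key_states, value_states), qkv_proj_weight):
        torch.matmul(hidden_states, w, out=out)

    print(query_states.shape)  # torch.Size([2, 16, 2048]); a hidden_size-wide buffer would not match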
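The utils.py hunk adds one more fallback condition to use_esimd_sdp: when batch size is
greater than 1 on an Intel Arc A-series GPU and DeepSpeed is available (so AutoTP may be in
use), the ESIMD SDP kernel is skipped. A standalone sketch of the guard logic follows; the
simplified signature and the local is_deepspeed_available check are assumptions made for
illustration. In the patch itself, is_deepspeed_available is imported from
ipex_llm.transformers.convert, the function takes query_states among its arguments, and
device_name is derived inside the function, with further checks after these two.

    # Illustrative only: simplified signature and a local DeepSpeed-availability check;
    # the real function has more conditions after these guards.
    import importlib.util

    is_deepspeed_available = importlib.util.find_spec("deepspeed") is not None

    def use_esimd_sdp(batch_size: int, device_name: str) -> bool:
        if batch_size > 1 and device_name.startswith("Intel(R) Data Center GPU Max"):
            # esimd_sdp does not support PVC when batch size > 1 for now
            return False
        if batch_size > 1 and device_name.startswith("Intel(R) Arc(TM) A") \
                and is_deepspeed_available:
            # guard added by this patch: ARC + batch size > 1 + DeepSpeed AutoTP falls back
            return False
        return True

    print(use_esimd_sdp(2, "Intel(R) Arc(TM) A770 Graphics"))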