LLM: fix llama2 FP16 & bs>1 & autotp on PVC and ARC (#10611)
parent 654dc5ba57
commit 2bbd8a1548

2 changed files with 17 additions and 12 deletions
@@ -533,12 +533,12 @@ def llama_attention_forward_4_31_original(
                     self.k_proj.weight.data = self.qkv_proj_weight[1, :, :]
                     self.v_proj.weight.data = self.qkv_proj_weight[2, :, :]
                     torch.xpu.empty_cache()
-                query_states = torch.empty(bsz, q_len, hidden_size, dtype=hidden_states.dtype,
-                                           device=hidden_states.device)
-                key_states = torch.empty(bsz, q_len, hidden_size, dtype=hidden_states.dtype,
-                                         device=hidden_states.device)
-                value_states = torch.empty(bsz, q_len, hidden_size, dtype=hidden_states.dtype,
-                                           device=hidden_states.device)
+                query_states = torch.empty(bsz, q_len, self.qkv_proj_weight.shape[-1],
+                                           dtype=hidden_states.dtype, device=hidden_states.device)
+                key_states = torch.empty(bsz, q_len, self.qkv_proj_weight.shape[-1],
+                                         dtype=hidden_states.dtype, device=hidden_states.device)
+                value_states = torch.empty(bsz, q_len, self.qkv_proj_weight.shape[-1],
+                                           dtype=hidden_states.dtype, device=hidden_states.device)
                 torch.ops.torch_ipex.mm_qkv_out(
                     hidden_states, self.qkv_proj_weight, None,
                     query_states, key_states, value_states
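Why the buffers follow self.qkv_proj_weight.shape[-1] instead of hidden_size: under DeepSpeed AutoTP each rank presumably holds only a shard of the projection weights, so the fused QKV weight's last dimension is the per-rank output width and no longer matches the full hidden_size. The same re-allocation is applied to the transformers 4.36 path in the next hunk. A minimal, shape-only sketch of that invariant in plain PyTorch (hypothetical sizes and weight layout, not the ipex-llm kernel path):

import torch

# Hypothetical sizes: full model hidden_size = 4096, tensor-parallel world_size = 2,
# so under AutoTP each rank holds a 4096 x 2048 shard of every projection weight.
hidden_size, world_size = 4096, 2
per_rank_out = hidden_size // world_size

bsz, q_len = 2, 16
hidden_states = torch.randn(bsz, q_len, hidden_size)

# Fused QKV weight, assumed here to be stacked as [3, in_features, out_features];
# the exact layout used by the ipex-llm FP16 path may differ.
qkv_proj_weight = torch.randn(3, hidden_size, per_rank_out)

# The pre-allocated output buffer must follow the fused weight's last dim, not
# hidden_size: on a sharded rank a (bsz, q_len, hidden_size) buffer would be too wide.
query_states = torch.empty(bsz, q_len, qkv_proj_weight.shape[-1],
                           dtype=hidden_states.dtype, device=hidden_states.device)
query_states.copy_(hidden_states @ qkv_proj_weight[0])
assert query_states.shape == (bsz, q_len, per_rank_out)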
@@ -1165,12 +1165,12 @@ def llama_attention_forward_4_36_original(
                     self.k_proj.weight.data = self.qkv_proj_weight[1, :, :]
                     self.v_proj.weight.data = self.qkv_proj_weight[2, :, :]
                     torch.xpu.empty_cache()
-                query_states = torch.empty(bsz, q_len, hidden_size, dtype=hidden_states.dtype,
-                                           device=hidden_states.device)
-                key_states = torch.empty(bsz, q_len, hidden_size, dtype=hidden_states.dtype,
-                                         device=hidden_states.device)
-                value_states = torch.empty(bsz, q_len, hidden_size, dtype=hidden_states.dtype,
-                                           device=hidden_states.device)
+                query_states = torch.empty(bsz, q_len, self.qkv_proj_weight.shape[-1],
+                                           dtype=hidden_states.dtype, device=hidden_states.device)
+                key_states = torch.empty(bsz, q_len, self.qkv_proj_weight.shape[-1],
+                                         dtype=hidden_states.dtype, device=hidden_states.device)
+                value_states = torch.empty(bsz, q_len, self.qkv_proj_weight.shape[-1],
+                                           dtype=hidden_states.dtype, device=hidden_states.device)
                 torch.ops.torch_ipex.mm_qkv_out(
                     hidden_states, self.qkv_proj_weight, None,
                     query_states, key_states, value_states
@@ -21,6 +21,7 @@ from ipex_llm.utils.common import invalidInputError
 from ipex_llm.ggml.quantize import ggml_tensor_qtype
 from ipex_llm.transformers.utils import get_ipex_version, get_xpu_device_type
 from ipex_llm.transformers.low_bit_linear import SYM_INT4, SYM_INT8, FP8E5, IQ2_XXS, FP4, FP8E4
+from ipex_llm.transformers.convert import is_deepspeed_available
 
 FP8_KV_ALLOC_LENGTH = 512
 
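For readers outside the repo: the new guard in the next hunk relies on is_deepspeed_available imported from ipex_llm.transformers.convert. Its definition is not shown in this diff; a module-level availability flag of this kind is commonly built as in the sketch below (an assumption for illustration only; the actual convert.py may define it differently, for example as a function).

import importlib.util

# Probe for the DeepSpeed package once at import time and expose a boolean flag.
is_deepspeed_available = importlib.util.find_spec("deepspeed") is not None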
@@ -341,6 +342,10 @@ def use_esimd_sdp(q_len, k_len, head_dim, query_states, attention_mask=None):
     if query_states.shape[0] > 1 and device_name.startswith("Intel(R) Data Center GPU Max"):
         # esimd_sdp not support PVC GPU when batch size > 1 for now
         return False
+    if query_states.shape[0] > 1 and device_name.startswith("Intel(R) Arc(TM) A") \
+            and is_deepspeed_available:
+        # esimd_sdp not support ARC GPU when batch size > 1 using DeepSpeed AutoTP for now
+        return False
     if query_states.shape[0] > 1 and attention_mask is not None:
         # for batched input, can't accept attention_mask
         # TODO: this check needs some time
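Pulled out of context, the two device-specific early exits above amount to the standalone check sketched below. This is a simplified illustration, not the full use_esimd_sdp logic (the real function also checks q_len, head_dim and attention_mask, as the surrounding hunk shows); device_name stands for the XPU device name string already resolved earlier in the function, and batched_esimd_sdp_supported is a hypothetical helper name.

def batched_esimd_sdp_supported(batch_size, device_name, deepspeed_in_use):
    # PVC never takes the ESIMD SDP path for batched input (for now);
    # ARC A-series skips it only when running under DeepSpeed AutoTP.
    if batch_size > 1 and device_name.startswith("Intel(R) Data Center GPU Max"):
        return False
    if batch_size > 1 and device_name.startswith("Intel(R) Arc(TM) A") and deepspeed_in_use:
        return False
    return True

# Example: batched decoding on an ARC A770 with DeepSpeed AutoTP falls back
# to the non-ESIMD attention path, while the same setup without DeepSpeed does not.
assert batched_esimd_sdp_supported(2, "Intel(R) Arc(TM) A770 Graphics", True) is False
assert batched_esimd_sdp_supported(2, "Intel(R) Arc(TM) A770 Graphics", False) is True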