diff --git a/python/llm/src/ipex_llm/transformers/convert.py b/python/llm/src/ipex_llm/transformers/convert.py
index 6e5a33a7..bdca4092 100644
--- a/python/llm/src/ipex_llm/transformers/convert.py
+++ b/python/llm/src/ipex_llm/transformers/convert.py
@@ -1243,13 +1243,20 @@ def _optimize_post(model, lightweight_bmm=False):
             modeling_module_name = model.__class__.__module__
             module = importlib.import_module(modeling_module_name)
             from ipex_llm.transformers.models.qwen import qwen_attention_forward
+            from ipex_llm.transformers.models.qwen import qwen_attention_forward_registered
             from ipex_llm.transformers.models.qwen import qwen_mlp_forward
             from ipex_llm.transformers.models.chatglm2 import chatglm_rms_norm_forward
             from ipex_llm.transformers.models.qwen import qwen_model_forward
-            convert_forward(model,
-                            module.QWenAttention,
-                            qwen_attention_forward
-                            )
+            if model.config.max_position_embeddings == 8192:
+                convert_forward(model,
+                                module.QWenAttention,
+                                qwen_attention_forward_registered
+                                )
+            else:
+                convert_forward(model,
+                                module.QWenAttention,
+                                qwen_attention_forward
+                                )
             convert_forward(model,
                             module.RMSNorm,
                             chatglm_rms_norm_forward)
@@ -1513,7 +1520,7 @@
         from ipex_llm.transformers.models.starcoder2 import model_forward
         convert_forward(model, module.Starcoder2Attention, attention_forward)
         convert_forward(model, module.Starcoder2Model, model_forward)
-    elif model.config.model_type == 'phi':
+    elif model.config.model_type in ["phi3", "phi3_v"]:
         # for phi-2
         modeling_module_name = model.__class__.__module__
         module = importlib.import_module(modeling_module_name)
@@ -1521,7 +1528,7 @@
         from ipex_llm.transformers.models.phi import model_forward
         convert_forward(model, module.PhiAttention, attention_forward)
         convert_forward(model, module.PhiModel, model_forward)
-    elif model.config.model_type in ["phi3", "phi3_v"]:
+    elif model.config.model_type == "phi3":
         # for phi-3
         modeling_module_name = model.__class__.__module__
         module = importlib.import_module(modeling_module_name)
diff --git a/python/llm/src/ipex_llm/transformers/models/qwen.py b/python/llm/src/ipex_llm/transformers/models/qwen.py
index 6aad5208..2856eb1c 100644
--- a/python/llm/src/ipex_llm/transformers/models/qwen.py
+++ b/python/llm/src/ipex_llm/transformers/models/qwen.py
@@ -180,6 +180,132 @@ def qwen_attention_forward(
     return attn_output, past_key_value
 
 
+def qwen_attention_forward_registered(
+    self,
+    hidden_states: Optional[Tuple[torch.FloatTensor]],
+    rotary_pos_emb_list: Optional[List[torch.Tensor]] = None,
+    registered_causal_mask: Optional[torch.Tensor] = None,
+    layer_past: Optional[Tuple[torch.Tensor]] = None,
+    attention_mask: Optional[torch.FloatTensor] = None,
+    head_mask: Optional[torch.FloatTensor] = None,
+    encoder_hidden_states: Optional[torch.Tensor] = None,
+    encoder_attention_mask: Optional[torch.FloatTensor] = None,
+    output_attentions: Optional[bool] = False,
+    use_cache: Optional[bool] = False,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+    # invalidInputError(not self.use_flash_attn and not self.use_cache_quantization,
+    #                   "flash attn and kv_cache quantization are not supported")
+    bsz, q_len, _ = hidden_states.size()
+    device = hidden_states.device
+    past_key_value = (None if layer_past is None
+                      else (layer_past[0].transpose(1, 2), layer_past[1].transpose(1, 2)))
+
+    qkv = self.c_attn(hidden_states)
+    qkv = qkv.view(bsz, q_len, self.num_heads * 3, self.head_dim)
+    qkv = qkv.transpose(1, 2)
+    query_states, key_states, value_states = qkv.split([self.num_heads,
+                                                        self.num_heads,
+                                                        self.num_heads], dim=1)
+
+    kv_seq_len = key_states.shape[2]
+    if past_key_value is not None:
+        kv_seq_len += past_key_value[0].shape[2]
+
+    # IPEX-LLM OPT: fuse rope
+    position_ids = rotary_pos_emb_list[-1]  # the last one is position_ids
+    inv_freq = rotary_pos_emb_list[-2]
+    rotary_pos_emb_list = rotary_pos_emb_list[:-2]
+    invalidInputError(len(rotary_pos_emb_list) == 1,
+                      "rotary_pos_emb_list's length cannot be larger than 1")
+    use_fuse_rope = should_use_fuse_rope(hidden_states, position_ids, self.training)
+    rotary_pos_emb = rotary_pos_emb_list[0]
+    if use_fuse_rope:
+        rot_dim = rotary_pos_emb[0].size(-1)
+        import linear_q4_0
+        linear_q4_0.rotary_half_inplaced(inv_freq, position_ids,
+                                         query_states[..., :rot_dim], key_states[..., :rot_dim])
+    else:
+        rotary_pos_emb = [i[:, -q_len:, :, :].transpose(1, 2) for i in rotary_pos_emb]
+        query_states = apply_rotary_pos_emb(query_states, rotary_pos_emb)
+        key_states = apply_rotary_pos_emb(key_states, rotary_pos_emb)
+
+    if kv_seq_len > self.seq_length and self.use_logn_attn and not self.training:
+        seq_start = kv_seq_len - q_len
+        seq_end = kv_seq_len
+        logn_tensor = self.logn_tensor[:, seq_start:seq_end, :, :].transpose(1, 2)
+        query_states = query_states * logn_tensor.type_as(query_states).expand_as(query_states)
+
+    # IPEX-LLM OPT: kv cache and quantize kv cache
+    use_quantize_kv = use_quantize_kv_cache(self.c_attn, hidden_states)
+    key_states, value_states = update_past_key_value(
+        past_key_value, key_states, value_states,
+        kv_seq_len, use_quantize_kv, device
+    )
+    past_key_value = (key_states.transpose(1, 2),
+                      value_states.transpose(1, 2)) if use_cache else None
+
+    # IPEX-LLM OPT: sdp
+    attn_weights = None
+    if not self.training and not hidden_states.requires_grad and \
+            use_flash_attention(query_states, key_states, attention_mask):
+        attn_output = F.scaled_dot_product_attention(query_states.to(dtype=torch.float16),
+                                                     key_states.to(dtype=torch.float16),
+                                                     value_states.to(dtype=torch.float16),
+                                                     is_causal=True).to(hidden_states.dtype)
+    elif use_sdp_causal(q_len, kv_seq_len, self.head_dim, query_states, self.training):
+        import linear_q4_0
+        if use_quantize_kv:
+            attn_output = linear_q4_0.sdp_fp8_causal(query_states, key_states, value_states)
+        else:
+            attn_output = linear_q4_0.sdp_causal(query_states, key_states, value_states)
+    else:
+        if q_len > 1:
+            causal_mask = registered_causal_mask[
+                :, :, kv_seq_len - q_len:kv_seq_len, :kv_seq_len
+            ]
+            attention_mask = torch.zeros(causal_mask.shape, dtype=query_states.dtype,
+                                         device=query_states.device)
+            attention_mask.masked_fill_(causal_mask.logical_not(),
+                                        torch.finfo(attention_mask.dtype).min)
+            attention_mask = attention_mask.expand([bsz, -1, -1, -1])
+        else:
+            attention_mask = None
+
+        if use_sdp(q_len, kv_seq_len, self.head_dim, query_states):
+            import linear_q4_0
+            if use_quantize_kv:
+                attn_output = linear_q4_0.sdp_fp8(query_states, key_states, value_states,
+                                                  attention_mask)
+            else:
+                attn_output = linear_q4_0.sdp(query_states, key_states, value_states,
+                                              attention_mask)
+        else:
+            if use_quantize_kv:
+                key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
+                                                                query_states.dtype)
+            attn_weights = torch.matmul(query_states,
+                                        key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+            if attention_mask is not None:
+                attn_weights = attn_weights + attention_mask
+            if self.softmax_in_fp32:
+                attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1,
+                                                           dtype=torch.float32).to(
+                    value_states.dtype)
+            else:
+                attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
+            attn_output = torch.matmul(attn_weights, value_states)
+
+    attn_output = attn_output.transpose(1, 2).contiguous()
+    attn_output = attn_output.view(bsz, q_len, self.hidden_size)
+
+    attn_output = self.c_proj(attn_output)
+
+    if output_attentions:
+        return attn_output, past_key_value, attn_weights
+    else:
+        return attn_output, past_key_value
+
+
 def qwen_mlp_forward(self, x: torch.Tensor) -> torch.Tensor:
     x_2d = x.view(-1, x.shape[-1])
     qtype = getattr(self.w1, "qtype", None)
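For reference, a minimal standalone sketch of the mask handling in the fallback branch of qwen_attention_forward_registered above: a causal mask pre-registered as a full buffer is sliced to the current query/key window and turned into an additive bias. The mask size and batch shapes below are toy assumptions for illustration; this is not the ipex-llm API.

import torch

max_positions = 8192  # assumed size of the pre-registered mask buffer
# Lower-triangular boolean mask, registered once (in the real model it lives on the module).
registered_causal_mask = torch.tril(
    torch.ones(max_positions, max_positions, dtype=torch.bool)
).view(1, 1, max_positions, max_positions)

bsz, q_len, kv_seq_len = 2, 4, 16  # toy shapes: 4 new tokens, 12 cached tokens
# Slice the rows for the new queries and the columns for all keys seen so far.
causal_mask = registered_causal_mask[:, :, kv_seq_len - q_len:kv_seq_len, :kv_seq_len]

# Convert the boolean mask into an additive bias: 0 where attention is allowed,
# a large negative value where it is masked, then broadcast over the batch.
attention_mask = torch.zeros(causal_mask.shape, dtype=torch.float32)
attention_mask.masked_fill_(causal_mask.logical_not(),
                            torch.finfo(attention_mask.dtype).min)
attention_mask = attention_mask.expand([bsz, -1, -1, -1])
# attention_mask is then added to the raw attention scores before softmax,
# mirroring the q_len > 1 branch of the new forward.

The separate qwen_attention_forward_registered entry point appears to exist because these Qwen checkpoints pass a registered_causal_mask argument into attention, so the replacement forward must accept and use it; convert.py selects that path via the max_position_embeddings == 8192 check and otherwise keeps the existing qwen_attention_forward.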