parent 0e53f20edb
commit c5e8b90c8d
2 changed files with 139 additions and 6 deletions

@@ -1243,13 +1243,20 @@ def _optimize_post(model, lightweight_bmm=False):
            modeling_module_name = model.__class__.__module__
            module = importlib.import_module(modeling_module_name)
            from ipex_llm.transformers.models.qwen import qwen_attention_forward
            from ipex_llm.transformers.models.qwen import qwen_attention_forward_registered
            from ipex_llm.transformers.models.qwen import qwen_mlp_forward
            from ipex_llm.transformers.models.chatglm2 import chatglm_rms_norm_forward
            from ipex_llm.transformers.models.qwen import qwen_model_forward
            convert_forward(model,
                            module.QWenAttention,
                            qwen_attention_forward
                            )
            if model.config.max_position_embeddings == 8192:
                convert_forward(model,
                                module.QWenAttention,
                                qwen_attention_forward_registered
                                )
            else:
                convert_forward(model,
                                module.QWenAttention,
                                qwen_attention_forward
                                )
            convert_forward(model,
                            module.RMSNorm,
                            chatglm_rms_norm_forward)
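
Note: convert_forward is defined elsewhere in convert.py and does not appear in this diff; it follows the usual forward-patching pattern of walking the module tree and rebinding forward on every instance of the target class. A minimal sketch of that pattern, assuming nothing beyond plain PyTorch (patch_forward is an illustrative name, not ipex-llm API):

# Illustrative sketch only -- not part of this commit.
import torch

def patch_forward(module: torch.nn.Module, target_cls: type, new_forward) -> None:
    # Rebind forward on every submodule that is an instance of target_cls.
    if module.__class__ == target_cls:
        bound = new_forward.__get__(module, module.__class__)
        setattr(module, "forward", bound)
    for child in module.children():
        patch_forward(child, target_cls, new_forward)

With this hunk, QWen checkpoints configured with max_position_embeddings == 8192 get the new qwen_attention_forward_registered bound to QWenAttention, while every other QWen configuration keeps qwen_attention_forward.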

@@ -1513,7 +1520,7 @@ def _optimize_post(model, lightweight_bmm=False):
        from ipex_llm.transformers.models.starcoder2 import model_forward
        convert_forward(model, module.Starcoder2Attention, attention_forward)
        convert_forward(model, module.Starcoder2Model, model_forward)
    elif model.config.model_type == 'phi':
    elif model.config.model_type in ["phi3", "phi3_v"]:
        # for phi-2
        modeling_module_name = model.__class__.__module__
        module = importlib.import_module(modeling_module_name)

@@ -1521,7 +1528,7 @@ def _optimize_post(model, lightweight_bmm=False):
        from ipex_llm.transformers.models.phi import model_forward
        convert_forward(model, module.PhiAttention, attention_forward)
        convert_forward(model, module.PhiModel, model_forward)
    elif model.config.model_type in ["phi3", "phi3_v"]:
    elif model.config.model_type == "phi3":
        # for phi-3
        modeling_module_name = model.__class__.__module__
        module = importlib.import_module(modeling_module_name)
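
Note: both phi hunks rely on the same lookup idiom used throughout _optimize_post: the classes to patch are resolved from the module that actually defined the loaded model, so the patch targets whichever transformers (or trust_remote_code) modeling file is in use. A standalone sketch of that resolution step (resolve_patch_target and class_name are illustrative, not ipex-llm API):

# Illustrative sketch only -- not part of this commit.
import importlib

def resolve_patch_target(model, class_name: str):
    # model.__class__.__module__ is e.g. "transformers.models.phi.modeling_phi".
    modeling_module_name = model.__class__.__module__
    module = importlib.import_module(modeling_module_name)
    # Fetch the class to patch, e.g. "PhiAttention", from that module.
    return getattr(module, class_name)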

@@ -180,6 +180,132 @@ def qwen_attention_forward(
        return attn_output, past_key_value


def qwen_attention_forward_registered(
    self,
    hidden_states: Optional[Tuple[torch.FloatTensor]],
    rotary_pos_emb_list: Optional[List[torch.Tensor]] = None,
    registered_causal_mask: Optional[torch.Tensor] = None,
    layer_past: Optional[Tuple[torch.Tensor]] = None,
    attention_mask: Optional[torch.FloatTensor] = None,
    head_mask: Optional[torch.FloatTensor] = None,
    encoder_hidden_states: Optional[torch.Tensor] = None,
    encoder_attention_mask: Optional[torch.FloatTensor] = None,
    output_attentions: Optional[bool] = False,
    use_cache: Optional[bool] = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    # invalidInputError(not self.use_flash_attn and not self.use_cache_quantization,
    #                   "flash attn and kv_cache quantization are not supported")
    bsz, q_len, _ = hidden_states.size()
    device = hidden_states.device
    past_key_value = (None if layer_past is None
                      else (layer_past[0].transpose(1, 2), layer_past[1].transpose(1, 2)))

    qkv = self.c_attn(hidden_states)
    qkv = qkv.view(bsz, q_len, self.num_heads * 3, self.head_dim)
    qkv = qkv.transpose(1, 2)
    query_states, key_states, value_states = qkv.split([self.num_heads,
                                                        self.num_heads,
                                                        self.num_heads], dim=1)

    kv_seq_len = key_states.shape[2]
    if past_key_value is not None:
        kv_seq_len += past_key_value[0].shape[2]

    # IPEX-LLM OPT: fuse rope
    position_ids = rotary_pos_emb_list[-1]  # the last one is position_ids
    inv_freq = rotary_pos_emb_list[-2]
    rotary_pos_emb_list = rotary_pos_emb_list[:-2]
    invalidInputError(len(rotary_pos_emb_list) == 1,
                      "rotary_pos_emb_list's length cannot be larger than 1")
    use_fuse_rope = should_use_fuse_rope(hidden_states, position_ids, self.training)
    rotary_pos_emb = rotary_pos_emb_list[0]
    if use_fuse_rope:
        rot_dim = rotary_pos_emb[0].size(-1)
        import linear_q4_0
        linear_q4_0.rotary_half_inplaced(inv_freq, position_ids,
                                         query_states[..., :rot_dim], key_states[..., :rot_dim])
    else:
        rotary_pos_emb = [i[:, -q_len:, :, :].transpose(1, 2) for i in rotary_pos_emb]
        query_states = apply_rotary_pos_emb(query_states, rotary_pos_emb)
        key_states = apply_rotary_pos_emb(key_states, rotary_pos_emb)

    if kv_seq_len > self.seq_length and self.use_logn_attn and not self.training:
        seq_start = kv_seq_len - q_len
        seq_end = kv_seq_len
        logn_tensor = self.logn_tensor[:, seq_start:seq_end, :, :].transpose(1, 2)
        query_states = query_states * logn_tensor.type_as(query_states).expand_as(query_states)

    # IPEX-LLM OPT: kv cache and quantize kv cache
    use_quantize_kv = use_quantize_kv_cache(self.c_attn, hidden_states)
    key_states, value_states = update_past_key_value(
        past_key_value, key_states, value_states,
        kv_seq_len, use_quantize_kv, device
    )
    past_key_value = (key_states.transpose(1, 2),
                      value_states.transpose(1, 2)) if use_cache else None

    # IPEX-LLM OPT: sdp
    attn_weights = None
    if not self.training and not hidden_states.requires_grad and \
            use_flash_attention(query_states, key_states, attention_mask):
        attn_output = F.scaled_dot_product_attention(query_states.to(dtype=torch.float16),
                                                     key_states.to(dtype=torch.float16),
                                                     value_states.to(dtype=torch.float16),
                                                     is_causal=True).to(hidden_states.dtype)
    elif use_sdp_causal(q_len, kv_seq_len, self.head_dim, query_states, self.training):
        import linear_q4_0
        if use_quantize_kv:
            attn_output = linear_q4_0.sdp_fp8_causal(query_states, key_states, value_states)
        else:
            attn_output = linear_q4_0.sdp_causal(query_states, key_states, value_states)
    else:
        if q_len > 1:
            causal_mask = registered_causal_mask[
                :, :, kv_seq_len - q_len:kv_seq_len, :kv_seq_len
            ]
            attention_mask = torch.zeros(causal_mask.shape, dtype=query_states.dtype,
                                         device=query_states.device)
            attention_mask.masked_fill_(causal_mask.logical_not(),
                                        torch.finfo(attention_mask.dtype).min)
            attention_mask = attention_mask.expand([bsz, -1, -1, -1])
        else:
            attention_mask = None

        if use_sdp(q_len, kv_seq_len, self.head_dim, query_states):
            import linear_q4_0
            if use_quantize_kv:
                attn_output = linear_q4_0.sdp_fp8(query_states, key_states, value_states,
                                                  attention_mask)
            else:
                attn_output = linear_q4_0.sdp(query_states, key_states, value_states,
                                              attention_mask)
        else:
            if use_quantize_kv:
                key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
                                                                query_states.dtype)
            attn_weights = torch.matmul(query_states,
                                        key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
            if attention_mask is not None:
                attn_weights = attn_weights + attention_mask
            if self.softmax_in_fp32:
                attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1,
                                                           dtype=torch.float32).to(
                                                               value_states.dtype)
            else:
                attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
            attn_output = torch.matmul(attn_weights, value_states)

    attn_output = attn_output.transpose(1, 2).contiguous()
    attn_output = attn_output.view(bsz, q_len, self.hidden_size)

    attn_output = self.c_proj(attn_output)

    if output_attentions:
        return attn_output, past_key_value, attn_weights
    else:
        return attn_output, past_key_value


def qwen_mlp_forward(self, x: torch.Tensor) -> torch.Tensor:
    x_2d = x.view(-1, x.shape[-1])
    qtype = getattr(self.w1, "qtype", None)
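
Note: in the final else branch of qwen_attention_forward_registered above, the additive mask is built by zero-filling the sliced registered_causal_mask and writing the dtype minimum where attention is disallowed, after which attention runs in plain eager mode. A self-contained sketch of that math on dummy tensors (all shapes below are made up for illustration):

# Illustrative sketch only -- not part of this commit.
import math
import torch

bsz, n_heads, q_len, kv_len, head_dim = 1, 2, 4, 4, 8
query = torch.randn(bsz, n_heads, q_len, head_dim)
key = torch.randn(bsz, n_heads, kv_len, head_dim)
value = torch.randn(bsz, n_heads, kv_len, head_dim)

# Boolean lower-triangular causal mask, sliced to the current query window.
causal = torch.tril(torch.ones(kv_len, kv_len, dtype=torch.bool))[None, None]
causal = causal[:, :, kv_len - q_len:kv_len, :kv_len]

# Additive mask: 0 where attention is allowed, dtype minimum where it is not.
attn_mask = torch.zeros(causal.shape, dtype=query.dtype)
attn_mask.masked_fill_(causal.logical_not(), torch.finfo(attn_mask.dtype).min)
attn_mask = attn_mask.expand(bsz, -1, -1, -1)

# Eager scaled-dot-product attention, as in the fallback branch above.
attn_weights = query @ key.transpose(2, 3) / math.sqrt(head_dim)
attn_weights = attn_weights + attn_mask
attn_weights = torch.softmax(attn_weights, dim=-1, dtype=torch.float32).to(value.dtype)
attn_output = attn_weights @ value  # shape: (bsz, n_heads, q_len, head_dim)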