optimize new minicpm model (#12579)
This commit is contained in:
		
							parent
							
								
									4540424271
								
							
						
					
					
						commit
						80f2fdc37b
					
				
					 3 changed files with 15 additions and 62 deletions
				
			
		| 
						 | 
				
			
			@ -217,8 +217,8 @@ def prepare_mask(mask, bsz, n_heads, seq_length, kv_length, is_causal, dtype, de
 | 
			
		|||
    return mask
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def scaled_dot_product_attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
 | 
			
		||||
                                 mask: torch.Tensor = None,
 | 
			
		||||
def scaled_dot_product_attention(query: torch.Tensor, key: torch.Tensor,
 | 
			
		||||
                                 value: torch.Tensor, mask: torch.Tensor = None,
 | 
			
		||||
                                 is_causal: bool = False, scale: float = None) -> torch.Tensor:
 | 
			
		||||
    bsz, n_heads, seq_length, head_dim = query.shape
 | 
			
		||||
    _, n_kv_heads, kv_length, _ = key.shape
 | 
			
		||||
| 
						 | 
				
			
			@ -268,7 +268,7 @@ def scaled_dot_product_attention(query: torch.Tensor, key: torch.Tensor, value:
 | 
			
		|||
                attn_output = xe_addons.sdp(query, key, value, mask)
 | 
			
		||||
        else:
 | 
			
		||||
            if key.dtype == torch.uint8:
 | 
			
		||||
                attn_output = xe_addons.sdp_fp8(query, key, value, mask)
 | 
			
		||||
                attn_output = xe_addons.sdp_fp8_non_causal(query, key, value, mask)
 | 
			
		||||
            else:
 | 
			
		||||
                attn_output = xe_addons.sdp_non_causal(query, key, value, mask)
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -281,6 +281,8 @@ def scaled_dot_product_attention(query: torch.Tensor, key: torch.Tensor, value:
 | 
			
		|||
            key = repeat_kv(key, n_heads // n_kv_heads)
 | 
			
		||||
            value = repeat_kv(value, n_heads // n_kv_heads)
 | 
			
		||||
 | 
			
		||||
        return torch.nn.functional.scaled_dot_product_attention(
 | 
			
		||||
        attn_output = torch.nn.functional.scaled_dot_product_attention(
 | 
			
		||||
            query, key, value, mask, is_causal=is_causal, scale=scale
 | 
			
		||||
        )
 | 
			
		||||
        attn_output = attn_output.to(dtype)    # workaround ipex 2.1's bug
 | 
			
		||||
        return attn_output
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -127,49 +127,12 @@ def minicpm_attention_forward(
 | 
			
		|||
            key_states, value_states = past_key_value.update(key_states, value_states,
 | 
			
		||||
                                                             self.layer_idx, None)
 | 
			
		||||
 | 
			
		||||
    from ipex_llm.transformers.models.common import scaled_dot_product_attention
 | 
			
		||||
    attn_weights = None
 | 
			
		||||
    if use_sdp(q_len, kv_seq_len, self.head_dim, query_states):
 | 
			
		||||
        import xe_addons
 | 
			
		||||
        # [CompressKV]
 | 
			
		||||
        if use_compresskv:
 | 
			
		||||
            attention_mask = get_compresskv_attn_mask(key_states, attention_mask)
 | 
			
		||||
 | 
			
		||||
        if use_quantizekv:
 | 
			
		||||
            attn_output = xe_addons.sdp_fp8(query_states, key_states, value_states,
 | 
			
		||||
                                            attention_mask)
 | 
			
		||||
        else:
 | 
			
		||||
            attn_output = xe_addons.sdp(query_states, key_states, value_states,
 | 
			
		||||
                                        attention_mask)
 | 
			
		||||
    elif use_sdp_causal(q_len, kv_seq_len, self.head_dim, query_states, self.training):
 | 
			
		||||
        import xe_addons
 | 
			
		||||
        if use_quantizekv:
 | 
			
		||||
            attn_output = xe_addons.sdp_fp8_causal(query_states, key_states,
 | 
			
		||||
                                                   value_states, attention_mask)
 | 
			
		||||
        else:
 | 
			
		||||
            attn_output = xe_addons.sdp_causal(query_states, key_states,
 | 
			
		||||
                                               value_states, attention_mask)
 | 
			
		||||
    else:
 | 
			
		||||
        if use_quantizekv:
 | 
			
		||||
            key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
 | 
			
		||||
                                                            query_states.dtype)
 | 
			
		||||
        key_states = repeat_kv(key_states, self.num_key_value_groups)
 | 
			
		||||
        value_states = repeat_kv(value_states, self.num_key_value_groups)
 | 
			
		||||
 | 
			
		||||
        attn_weights = torch.matmul(
 | 
			
		||||
            query_states, key_states.transpose(2, 3)
 | 
			
		||||
        ) / math.sqrt(self.head_dim)
 | 
			
		||||
 | 
			
		||||
        if attention_mask is not None:
 | 
			
		||||
            attn_weights = attn_weights + attention_mask
 | 
			
		||||
 | 
			
		||||
        # upcast attention to fp32
 | 
			
		||||
        attn_weights = nn.functional.softmax(
 | 
			
		||||
            attn_weights, dim=-1, dtype=torch.float32
 | 
			
		||||
        ).to(query_states.dtype)
 | 
			
		||||
        attn_weights = nn.functional.dropout(
 | 
			
		||||
            attn_weights, p=self.attention_dropout, training=self.training
 | 
			
		||||
        )
 | 
			
		||||
        attn_output = torch.matmul(attn_weights, value_states)
 | 
			
		||||
    attn_output = scaled_dot_product_attention(
 | 
			
		||||
        query_states, key_states, value_states,
 | 
			
		||||
        attention_mask, q_len == kv_seq_len, math.sqrt(self.head_dim)
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    attn_output = attn_output.transpose(1, 2).contiguous()
 | 
			
		||||
    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -28,7 +28,6 @@ from typing import Optional, List
 | 
			
		|||
from torch.nn.functional import linear
 | 
			
		||||
from ipex_llm.transformers.models.common import merge_qkv_base, padding_qkv_hd
 | 
			
		||||
from ipex_llm.transformers.models.common import attention_softmax
 | 
			
		||||
from ipex_llm.transformers.models.utils import use_sdp_non_causal
 | 
			
		||||
from transformers import AutoProcessor, TextIteratorStreamer
 | 
			
		||||
from transformers.generation.logits_process import RepetitionPenaltyLogitsProcessor
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -73,21 +72,10 @@ def siglip_attention_forward(
 | 
			
		|||
            72, 80
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        if use_sdp_non_causal(query_states.size(-1), query_states.device, query_states.dtype):
 | 
			
		||||
            import xe_addons
 | 
			
		||||
            attn_weights = None
 | 
			
		||||
            attn_output = xe_addons.sdp_non_causal(query_states, key_states.contiguous(),
 | 
			
		||||
                                                   value_states.contiguous(), attention_mask)
 | 
			
		||||
        else:
 | 
			
		||||
            attn_weights = torch.matmul(query_states * self.scale, key_states.transpose(2, 3))
 | 
			
		||||
            if attention_mask is not None:
 | 
			
		||||
                attn_weights = attn_weights + attention_mask
 | 
			
		||||
 | 
			
		||||
            attn_weights = attention_softmax(attn_weights)
 | 
			
		||||
 | 
			
		||||
            attn_weights = torch.nn.functional.dropout(attn_weights,
 | 
			
		||||
                                                       p=self.dropout, training=self.training)
 | 
			
		||||
            attn_output = torch.matmul(attn_weights, value_states)
 | 
			
		||||
        from ipex_llm.transformers.models.common import scaled_dot_product_attention
 | 
			
		||||
        attn_weights = None
 | 
			
		||||
        attn_output = scaled_dot_product_attention(query_states, key_states, value_states,
 | 
			
		||||
                                                   attention_mask, False, math.sqrt(self.head_dim))
 | 
			
		||||
 | 
			
		||||
        attn_output = attn_output[:, :, :, :self.head_dim]
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in a new issue