refactor baichuan, glm4 and minicpm3 (#12600)
This commit is contained in:
		
							parent
							
								
									c410d9cf73
								
							
						
					
					
						commit
						7aaf02f602
					
				
					 4 changed files with 32 additions and 167 deletions
				
			
		| 
						 | 
				
			
			@ -24,16 +24,16 @@ import torch
 | 
			
		|||
import torch.utils.checkpoint
 | 
			
		||||
from torch.nn import functional as F
 | 
			
		||||
from transformers.modeling_outputs import BaseModelOutputWithPast
 | 
			
		||||
from ipex_llm.transformers.models.common import scaled_dot_product_attention
 | 
			
		||||
from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache, \
 | 
			
		||||
    should_use_compresskv, get_compresskv_attn_mask
 | 
			
		||||
    should_use_compresskv
 | 
			
		||||
from ipex_llm.transformers.models.utils import update_past_key_value
 | 
			
		||||
from ipex_llm.transformers.models.utils import should_use_fuse_rope
 | 
			
		||||
from ipex_llm.transformers.models.utils import use_flash_attention, use_sdp, use_sdp_causal
 | 
			
		||||
from ipex_llm.transformers.models.utils import use_flash_attention, use_sdp
 | 
			
		||||
from ipex_llm.transformers.models.utils import apply_rotary_pos_emb, SILU
 | 
			
		||||
from ipex_llm.transformers.models.utils import mlp_fusion_check
 | 
			
		||||
from ipex_llm.transformers.models.utils import is_enough_kv_cache_room_4_36
 | 
			
		||||
from ipex_llm.transformers.kv import DynamicCompressFp8Cache, DynamicCompressCache
 | 
			
		||||
from ipex_llm.transformers.models.utils import extend_kv_cache, append_kv_cache
 | 
			
		||||
import warnings
 | 
			
		||||
import os
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -301,42 +301,16 @@ def baichuan_attention_forward_7b(
 | 
			
		|||
 | 
			
		||||
    # IPEX-LLM OPT: sdp
 | 
			
		||||
    attn_weights = None
 | 
			
		||||
    if not self.training and not hidden_states.requires_grad and \
 | 
			
		||||
            use_flash_attention(query_states, key_states, attention_mask):
 | 
			
		||||
    if use_flash_attention(query_states, key_states, attention_mask):
 | 
			
		||||
        attn_output = F.scaled_dot_product_attention(query_states.to(dtype=torch.float16),
 | 
			
		||||
                                                     key_states.to(dtype=torch.float16),
 | 
			
		||||
                                                     value_states.to(dtype=torch.float16),
 | 
			
		||||
                                                     is_causal=True).to(hidden_states.dtype)
 | 
			
		||||
    elif use_sdp(q_len, kv_seq_len, self.head_dim, query_states):
 | 
			
		||||
        import xe_addons
 | 
			
		||||
        if use_compresskv:
 | 
			
		||||
            attention_mask = get_compresskv_attn_mask(key_states, attention_mask)
 | 
			
		||||
        if use_quantize_kv:
 | 
			
		||||
            attn_output = xe_addons.sdp_fp8(query_states, key_states, value_states,
 | 
			
		||||
                                            attention_mask)
 | 
			
		||||
        else:
 | 
			
		||||
            attn_output = xe_addons.sdp(query_states, key_states, value_states,
 | 
			
		||||
                                        attention_mask)
 | 
			
		||||
    elif use_sdp_causal(q_len, kv_seq_len, self.head_dim, query_states, self.training):
 | 
			
		||||
        import xe_addons
 | 
			
		||||
        if use_quantize_kv:
 | 
			
		||||
            attn_output = xe_addons.sdp_fp8_causal(query_states, key_states,
 | 
			
		||||
                                                   value_states, attention_mask)
 | 
			
		||||
        else:
 | 
			
		||||
            attn_output = xe_addons.sdp_causal(query_states, key_states,
 | 
			
		||||
                                               value_states, attention_mask)
 | 
			
		||||
    else:
 | 
			
		||||
        if use_quantize_kv:
 | 
			
		||||
            key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
 | 
			
		||||
                                                            query_states.dtype)
 | 
			
		||||
        attn_weights = torch.matmul(query_states,
 | 
			
		||||
                                    key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
 | 
			
		||||
        if attention_mask is not None:
 | 
			
		||||
            attn_weights = attn_weights + attention_mask
 | 
			
		||||
        # upcast attention to fp32
 | 
			
		||||
        attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1,
 | 
			
		||||
                                                   dtype=torch.float32).to(value_states.dtype)
 | 
			
		||||
        attn_output = torch.matmul(attn_weights, value_states)
 | 
			
		||||
        attn_output = scaled_dot_product_attention(
 | 
			
		||||
            query_states, key_states, value_states,
 | 
			
		||||
            attention_mask, q_len == kv_seq_len
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    attn_output = attn_output.transpose(1, 2).contiguous()
 | 
			
		||||
    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -20,15 +20,14 @@
 | 
			
		|||
import os
 | 
			
		||||
import torch
 | 
			
		||||
from typing import Optional, Tuple, Union
 | 
			
		||||
from ipex_llm.transformers.models.utils import restore_fp8_kv_cache, update_past_key_value
 | 
			
		||||
from ipex_llm.transformers.models.utils import use_quantize_kv_cache, use_sdp, \
 | 
			
		||||
    use_sdp_causal, should_use_compresskv, is_enough_kv_cache_room_4_36, \
 | 
			
		||||
    get_compresskv_attn_mask
 | 
			
		||||
from ipex_llm.transformers.models.common import scaled_dot_product_attention
 | 
			
		||||
from ipex_llm.transformers.models.utils import update_past_key_value
 | 
			
		||||
from ipex_llm.transformers.models.utils import use_quantize_kv_cache
 | 
			
		||||
from ipex_llm.transformers.models.utils import should_use_compresskv, is_enough_kv_cache_room_4_36
 | 
			
		||||
from ipex_llm.transformers.models.utils import should_use_fuse_rope, apply_rotary_pos_emb
 | 
			
		||||
from ipex_llm.transformers.models.chatglm2 import repeat_kv
 | 
			
		||||
from ipex_llm.transformers.kv import DynamicCompressCache, DynamicCompressFp8Cache
 | 
			
		||||
from transformers.modeling_outputs import BaseModelOutputWithPast
 | 
			
		||||
import math
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
KV_CACHE_ALLOC_BLOCK_LENGTH = int(os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256))
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -241,49 +240,10 @@ def chatglm4_attention_forward(
 | 
			
		|||
            past_key_value = None
 | 
			
		||||
 | 
			
		||||
    # IPEX-LLM OPT: sdp
 | 
			
		||||
    attn_weights = None
 | 
			
		||||
    if use_sdp(q_len, kv_seq_len, head_dim, query_states):
 | 
			
		||||
        import xe_addons
 | 
			
		||||
        if use_compresskv:
 | 
			
		||||
            attention_mask = get_compresskv_attn_mask(key_states, attention_mask)
 | 
			
		||||
        if use_quantize_kv:
 | 
			
		||||
            attn_output = xe_addons.sdp_fp8(query_states, key_states, value_states, attention_mask)
 | 
			
		||||
        else:
 | 
			
		||||
            attn_output = xe_addons.sdp(query_states, key_states, value_states, attention_mask)
 | 
			
		||||
    elif use_sdp_causal(q_len, kv_seq_len, head_dim, query_states, self.training):
 | 
			
		||||
        import xe_addons
 | 
			
		||||
        if use_quantize_kv:
 | 
			
		||||
            attn_output = xe_addons.sdp_fp8_causal(query_states, key_states, value_states,
 | 
			
		||||
                                                   attention_mask)
 | 
			
		||||
        else:
 | 
			
		||||
            attn_output = xe_addons.sdp_causal(query_states, key_states, value_states,
 | 
			
		||||
                                               attention_mask)
 | 
			
		||||
    elif query_states.device.type == "cpu":
 | 
			
		||||
        # repeat k/v heads if n_kv_heads < n_heads
 | 
			
		||||
        key_states = repeat_kv(key_states, n_head // n_kv_head)
 | 
			
		||||
        value_states = repeat_kv(value_states, n_head // n_kv_head)
 | 
			
		||||
        if q_len == kv_seq_len:
 | 
			
		||||
            attn_output = torch.nn.functional.scaled_dot_product_attention(
 | 
			
		||||
                query_states, key_states, value_states, is_causal=True
 | 
			
		||||
            )
 | 
			
		||||
        else:
 | 
			
		||||
            attn_output = torch.nn.functional.scaled_dot_product_attention(
 | 
			
		||||
                query_states, key_states, value_states, attention_mask
 | 
			
		||||
            )
 | 
			
		||||
    else:
 | 
			
		||||
        if use_quantize_kv:
 | 
			
		||||
            key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
 | 
			
		||||
                                                            query_states.dtype)
 | 
			
		||||
        # repeat k/v heads if n_kv_heads < n_heads
 | 
			
		||||
        key_states = repeat_kv(key_states, n_head // n_kv_head)
 | 
			
		||||
        value_states = repeat_kv(value_states, n_head // n_kv_head)
 | 
			
		||||
        attn_weights = torch.matmul(query_states / math.sqrt(head_dim),
 | 
			
		||||
                                    key_states.transpose(2, 3))
 | 
			
		||||
        if attention_mask is not None:
 | 
			
		||||
            attn_weights = attn_weights + attention_mask
 | 
			
		||||
        attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1,
 | 
			
		||||
                                                   dtype=torch.float32).to(value_states.dtype)
 | 
			
		||||
        attn_output = torch.matmul(attn_weights, value_states)
 | 
			
		||||
    attn_output = scaled_dot_product_attention(
 | 
			
		||||
        query_states, key_states, value_states,
 | 
			
		||||
        attention_mask, q_len == kv_seq_len
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    # context_layer's shape: [bsz, n_head, seq_len, head_dim] -> [seq_len, bsz, n_head * head_dim]
 | 
			
		||||
    attn_output = attn_output.transpose(1, 2).contiguous().view(bsz, q_len, n_head * head_dim)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -20,10 +20,10 @@
 | 
			
		|||
import torch
 | 
			
		||||
from typing import Optional, Tuple, Union
 | 
			
		||||
from ipex_llm.transformers.models.common import merge_qkv_base
 | 
			
		||||
from ipex_llm.transformers.models.utils import restore_fp8_kv_cache, update_past_key_value
 | 
			
		||||
from ipex_llm.transformers.models.utils import use_quantize_kv_cache, use_sdp, use_sdp_causal
 | 
			
		||||
from ipex_llm.transformers.models.common import scaled_dot_product_attention
 | 
			
		||||
from ipex_llm.transformers.models.utils import update_past_key_value
 | 
			
		||||
from ipex_llm.transformers.models.utils import use_quantize_kv_cache, use_sdp
 | 
			
		||||
from ipex_llm.transformers.models.utils import should_use_fuse_rope, apply_rotary_pos_emb
 | 
			
		||||
from ipex_llm.transformers.models.chatglm2 import repeat_kv
 | 
			
		||||
from ipex_llm.utils.common import invalidInputError
 | 
			
		||||
from transformers.modeling_outputs import BaseModelOutputWithPast
 | 
			
		||||
import math
 | 
			
		||||
| 
						 | 
				
			
			@ -246,53 +246,10 @@ def chatglm4v_attention_forward(
 | 
			
		|||
        past_key_value = None
 | 
			
		||||
 | 
			
		||||
    # IPEX-LLM OPT: sdp
 | 
			
		||||
    attn_weights = None
 | 
			
		||||
    if use_sdp(q_len, kv_seq_len, head_dim, query_states):
 | 
			
		||||
        import xe_addons
 | 
			
		||||
        if use_quantize_kv:
 | 
			
		||||
            attn_output = xe_addons.sdp_fp8(query_states, key_states, value_states, attention_mask)
 | 
			
		||||
        else:
 | 
			
		||||
            attn_output = xe_addons.sdp(query_states, key_states, value_states, attention_mask)
 | 
			
		||||
    elif use_sdp_causal(q_len, kv_seq_len, head_dim, query_states, self.training):
 | 
			
		||||
        import xe_addons
 | 
			
		||||
        if use_quantize_kv:
 | 
			
		||||
            attn_output = xe_addons.sdp_fp8_causal(query_states, key_states, value_states,
 | 
			
		||||
                                                   attention_mask)
 | 
			
		||||
        else:
 | 
			
		||||
            attn_output = xe_addons.sdp_causal(query_states, key_states, value_states,
 | 
			
		||||
                                               attention_mask)
 | 
			
		||||
    elif query_states.device.type == "cpu":
 | 
			
		||||
        # repeat k/v heads if n_kv_heads < n_heads
 | 
			
		||||
        key_states = repeat_kv(key_states, n_head // n_kv_head)
 | 
			
		||||
        value_states = repeat_kv(value_states, n_head // n_kv_head)
 | 
			
		||||
        if q_len == kv_seq_len:
 | 
			
		||||
            attn_output = torch.nn.functional.scaled_dot_product_attention(
 | 
			
		||||
                query_states, key_states, value_states, is_causal=True
 | 
			
		||||
            )
 | 
			
		||||
        else:
 | 
			
		||||
            attn_output = torch.nn.functional.scaled_dot_product_attention(
 | 
			
		||||
                query_states, key_states, value_states, attention_mask
 | 
			
		||||
            )
 | 
			
		||||
    else:
 | 
			
		||||
        if use_quantize_kv:
 | 
			
		||||
            key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
 | 
			
		||||
                                                            query_states.dtype)
 | 
			
		||||
        # repeat k/v heads if n_kv_heads < n_heads
 | 
			
		||||
        key_states = repeat_kv(key_states, n_head // n_kv_head)
 | 
			
		||||
        value_states = repeat_kv(value_states, n_head // n_kv_head)
 | 
			
		||||
        attn_weights = torch.matmul(query_states / math.sqrt(head_dim),
 | 
			
		||||
                                    key_states.transpose(2, 3))
 | 
			
		||||
        if attention_mask is not None:
 | 
			
		||||
            attn_weights = attn_weights + attention_mask
 | 
			
		||||
        if kv_seq_len >= 2048 or bsz >= 64:
 | 
			
		||||
            # for memory considerations, do not upcast attention to fp32
 | 
			
		||||
            # for long sequences or large batches
 | 
			
		||||
            attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
 | 
			
		||||
        else:
 | 
			
		||||
            # upcast attention to fp32
 | 
			
		||||
            attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1,
 | 
			
		||||
                                                       dtype=torch.float32).to(value_states.dtype)
 | 
			
		||||
        attn_output = torch.matmul(attn_weights, value_states)
 | 
			
		||||
    attn_output = scaled_dot_product_attention(
 | 
			
		||||
        query_states, key_states, value_states,
 | 
			
		||||
        attention_mask, q_len == kv_seq_len
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    # context_layer's shape: [bsz, n_head, seq_len, head_dim] -> [seq_len, bsz, n_head * head_dim]
 | 
			
		||||
    attn_output = attn_output.transpose(1, 2).contiguous().view(bsz, q_len, n_head * head_dim)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -6,10 +6,10 @@ from typing import Optional, Tuple, List
 | 
			
		|||
from transformers.cache_utils import Cache
 | 
			
		||||
 | 
			
		||||
from ipex_llm.utils.common.log4Error import invalidInputError
 | 
			
		||||
from ipex_llm.transformers.models.common import scaled_dot_product_attention
 | 
			
		||||
from ipex_llm.transformers.models.utils import should_use_fuse_rope
 | 
			
		||||
from ipex_llm.transformers.models.utils import rotate_half
 | 
			
		||||
from ipex_llm.transformers.models.utils import use_sdp, use_sdp_causal
 | 
			
		||||
from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache
 | 
			
		||||
from ipex_llm.transformers.models.utils import use_quantize_kv_cache
 | 
			
		||||
from ipex_llm.transformers.kv import DynamicNormalCache, DynamicFp8Cache
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -25,7 +25,7 @@ def pre_compute_inv_freq(module: torch.nn.Module):
 | 
			
		|||
 | 
			
		||||
def padding_v_head_dim(module: torch.nn.Module):
 | 
			
		||||
    if module.__class__.__name__ == "MiniCPMAttention":
 | 
			
		||||
        k_head_dim = module.qk_rope_head_dim + module.qk_nope_head_dim
 | 
			
		||||
        k_head_dim = module.q_head_dim
 | 
			
		||||
        v_head_dim = module.v_head_dim
 | 
			
		||||
        invalidInputError(k_head_dim >= v_head_dim,
 | 
			
		||||
                          f"unsupported k_head_dim and v_head_dim: {k_head_dim} {v_head_dim}")
 | 
			
		||||
| 
						 | 
				
			
			@ -183,37 +183,11 @@ def minicpm3_attention_forward(
 | 
			
		|||
                                                         self.layer_idx, None)
 | 
			
		||||
 | 
			
		||||
    attn_weights = None
 | 
			
		||||
    if use_sdp(q_len, kv_seq_len, self.q_head_dim, query_states):
 | 
			
		||||
        import xe_addons
 | 
			
		||||
        if isinstance(past_key_value, DynamicFp8Cache):
 | 
			
		||||
            attn_output = xe_addons.sdp_fp8(query_states, key_states, value_states,
 | 
			
		||||
                                            attention_mask)
 | 
			
		||||
        else:
 | 
			
		||||
            attn_output = xe_addons.sdp(query_states, key_states, value_states,
 | 
			
		||||
                                        attention_mask)
 | 
			
		||||
        attn_output = attn_output[:, :, :, :self.v_head_dim]
 | 
			
		||||
    elif use_sdp_causal(q_len, kv_seq_len, self.q_head_dim, query_states, False):
 | 
			
		||||
        import xe_addons
 | 
			
		||||
        if isinstance(past_key_value, DynamicFp8Cache):
 | 
			
		||||
            attn_output = xe_addons.sdp_fp8_causal(query_states, key_states,
 | 
			
		||||
                                                   value_states, attention_mask)
 | 
			
		||||
        else:
 | 
			
		||||
            attn_output = xe_addons.sdp_causal(query_states, key_states,
 | 
			
		||||
                                               value_states, attention_mask)
 | 
			
		||||
        attn_output = attn_output[:, :, :, :self.v_head_dim]
 | 
			
		||||
    else:
 | 
			
		||||
        if isinstance(past_key_value, DynamicFp8Cache):
 | 
			
		||||
            key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
 | 
			
		||||
                                                            query_states.dtype)
 | 
			
		||||
        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.softmax_scale
 | 
			
		||||
 | 
			
		||||
        if attention_mask is not None:
 | 
			
		||||
            attn_weights = attn_weights + attention_mask
 | 
			
		||||
 | 
			
		||||
        # upcast attention to fp32
 | 
			
		||||
        attn_weights = nn.functional.softmax(attn_weights,
 | 
			
		||||
                                             dim=-1, dtype=torch.float32).to(query_states.dtype)
 | 
			
		||||
        attn_output = torch.matmul(attn_weights, value_states[:, :, :, :self.v_head_dim])
 | 
			
		||||
    attn_output = scaled_dot_product_attention(
 | 
			
		||||
        query_states, key_states, value_states,
 | 
			
		||||
        attention_mask, q_len == kv_seq_len, self.softmax_scale
 | 
			
		||||
    )
 | 
			
		||||
    attn_output = attn_output[:, :, :, :self.v_head_dim]
 | 
			
		||||
 | 
			
		||||
    attn_output = attn_output.transpose(1, 2).contiguous()
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in a new issue