refactor baichuan2-7b (#11062)
This commit is contained in:
		
							parent
							
								
									84239d0bd3
								
							
						
					
					
						commit
						981d668be6
					
				
					 2 changed files with 139 additions and 287 deletions
				
			
		| 
						 | 
				
			
			@ -710,6 +710,12 @@ def _optimize_pre(model):
 | 
			
		|||
        model.apply(pre_compute_inv_freq)
 | 
			
		||||
        from ipex_llm.transformers.models.phi3 import split_mlp
 | 
			
		||||
        model.apply(split_mlp)
 | 
			
		||||
    # for baichuan2
 | 
			
		||||
    if model.config.model_type == "baichuan" and model.config.vocab_size == 125696:
 | 
			
		||||
        if model.config.hidden_size in [4096, 2048]:
 | 
			
		||||
            # baichuan2-7B
 | 
			
		||||
            from ipex_llm.transformers.models.baichuan2 import pre_compute_inv_freq
 | 
			
		||||
            model.apply(pre_compute_inv_freq)
 | 
			
		||||
    if model.config.model_type == "qwen":
 | 
			
		||||
        rope_base = model.config.rotary_emb_base
 | 
			
		||||
        from accelerate.big_modeling import init_empty_weights
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -23,50 +23,30 @@ from typing import Optional, Tuple
 | 
			
		|||
import torch
 | 
			
		||||
import torch.utils.checkpoint
 | 
			
		||||
from torch.nn import functional as F
 | 
			
		||||
from ipex_llm.ggml.quantize import ggml_tensor_qtype
 | 
			
		||||
from ipex_llm.transformers.models.utils import init_fp8_kv_cache, append_fp8_kv_cache, \
 | 
			
		||||
    restore_fp8_kv_cache, use_quantize_kv_cache
 | 
			
		||||
from ipex_llm.transformers.models.utils import init_kv_cache, extend_kv_cache, \
 | 
			
		||||
    append_kv_cache, is_enough_kv_cache_room_4_31
 | 
			
		||||
from ipex_llm.transformers.models.utils import use_flash_attention, use_sdp
 | 
			
		||||
from ipex_llm.transformers.models.utils import should_use_fuse_rope
 | 
			
		||||
from ipex_llm.transformers.models.utils import use_flash_attention, use_sdp, use_sdp_causal
 | 
			
		||||
from ipex_llm.transformers.models.utils import apply_rotary_pos_emb, SILU
 | 
			
		||||
from ipex_llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
 | 
			
		||||
from ipex_llm.transformers.models.utils import mlp_fusion_check
 | 
			
		||||
from ipex_llm.utils.common.log4Error import invalidInputError
 | 
			
		||||
from transformers.utils import logging
 | 
			
		||||
logger = logging.get_logger(__name__)
 | 
			
		||||
 | 
			
		||||
try:
 | 
			
		||||
    from xformers import ops as xops
 | 
			
		||||
except ImportError:
 | 
			
		||||
    xops = None
 | 
			
		||||
    logger.warning(
 | 
			
		||||
        "Xformers is not installed correctly. If you want to use memory_efficient_attention to "
 | 
			
		||||
        "accelerate training use the following command to install Xformers\npip install xformers."
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
import warnings
 | 
			
		||||
import os
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
KV_CACHE_ALLOC_BLOCK_LENGTH = int(os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def should_split_qkv_tensor(query_states, bsz, num_heads, q_len, kv_seq_len, output_attentions):
 | 
			
		||||
    if not output_attentions:
 | 
			
		||||
        if os.environ.get("IPEX_LLM_SPLIT_QKV", None) is not None:
 | 
			
		||||
            return os.environ.get("IPEX_LLM_SPLIT_QKV", None) == "1"
 | 
			
		||||
        elif query_states.dtype == torch.float16 and \
 | 
			
		||||
                query_states.shape[2] >= 5400:
 | 
			
		||||
            # split tensor for memory block limitation
 | 
			
		||||
            # support fp16 and set input length threshold at 5400 for now
 | 
			
		||||
            return True
 | 
			
		||||
        elif query_states.element_size()*bsz*num_heads*q_len*kv_seq_len >= 4*1024**3:
 | 
			
		||||
            # attn_weight size larger than memory block limitation 4GB
 | 
			
		||||
            return True
 | 
			
		||||
    return False
 | 
			
		||||
def pre_compute_inv_freq(module: torch.nn.Module):
 | 
			
		||||
    if module.__class__.__name__ == "RotaryEmbedding":
 | 
			
		||||
        inv_freq = module.inv_freq
 | 
			
		||||
        del module.inv_freq
 | 
			
		||||
        module.register_buffer("inv_freq", inv_freq, persistent=False)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def baichuan_13b_rms_norm_forward(self, hidden_states):
 | 
			
		||||
    if hidden_states.device.type == "xpu" and not (self.training and hidden_states.requires_grad):
 | 
			
		||||
    if hidden_states.device.type == "xpu" and not (self.training or hidden_states.requires_grad):
 | 
			
		||||
        import linear_q4_0
 | 
			
		||||
        x_2d = hidden_states.reshape(-1, hidden_states.size(-1)).contiguous()
 | 
			
		||||
        output = linear_q4_0.rms_norm(self.weight, x_2d, self.epsilon)
 | 
			
		||||
| 
						 | 
				
			
			@ -105,95 +85,117 @@ def baichuan_attention_forward_7b(
 | 
			
		|||
    past_key_value: Optional[Tuple[torch.Tensor]] = None,
 | 
			
		||||
    output_attentions: bool = False,
 | 
			
		||||
    use_cache: bool = False,
 | 
			
		||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
 | 
			
		||||
    if use_quantize_kv_cache(self.W_pack, hidden_states):
 | 
			
		||||
        forward_function = baichuan_attention_forward_7b_quantized
 | 
			
		||||
    else:
 | 
			
		||||
        forward_function = baichuan_attention_forward_7b_origin
 | 
			
		||||
    return forward_function(
 | 
			
		||||
        self=self,
 | 
			
		||||
        hidden_states=hidden_states,
 | 
			
		||||
        attention_mask=attention_mask,
 | 
			
		||||
        position_ids=position_ids,
 | 
			
		||||
        past_key_value=past_key_value,
 | 
			
		||||
        output_attentions=output_attentions,
 | 
			
		||||
        use_cache=use_cache
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def baichuan_attention_forward_7b_quantized(
 | 
			
		||||
    self,
 | 
			
		||||
    hidden_states: torch.Tensor,
 | 
			
		||||
    attention_mask: Optional[torch.Tensor] = None,
 | 
			
		||||
    position_ids: Optional[torch.LongTensor] = None,
 | 
			
		||||
    past_key_value: Optional[Tuple[torch.Tensor]] = None,
 | 
			
		||||
    output_attentions: bool = False,
 | 
			
		||||
    use_cache: bool = False,
 | 
			
		||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
 | 
			
		||||
):
 | 
			
		||||
    bsz, q_len, _ = hidden_states.size()
 | 
			
		||||
    device = hidden_states.device
 | 
			
		||||
 | 
			
		||||
    proj = self.W_pack(hidden_states)
 | 
			
		||||
    proj = torch.chunk(proj, 3, -1)
 | 
			
		||||
    # batch_size x source_len x hidden_size
 | 
			
		||||
    query_states = proj[0].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
 | 
			
		||||
    # batch_size x target_len x head_size
 | 
			
		||||
    key_states = proj[1].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
 | 
			
		||||
    # batch_size x source_len x hidden_size
 | 
			
		||||
    value_states = proj[2].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
 | 
			
		||||
    qkv = self.W_pack(hidden_states)
 | 
			
		||||
    qkv = qkv.view(bsz, q_len, self.num_heads * 3, self.head_dim)
 | 
			
		||||
    qkv = qkv.transpose(1, 2)
 | 
			
		||||
    query_states, key_states, value_states = qkv.split([self.num_heads,
 | 
			
		||||
                                                        self.num_heads,
 | 
			
		||||
                                                        self.num_heads], dim=1)
 | 
			
		||||
 | 
			
		||||
    kv_seq_len = key_states.shape[-2]
 | 
			
		||||
    kv_seq_len = key_states.shape[2]
 | 
			
		||||
    if past_key_value is not None:
 | 
			
		||||
        kv_seq_len += past_key_value[0].shape[-2]
 | 
			
		||||
    if query_states.device.type == "xpu" and not (self.training and query_states.requires_grad):
 | 
			
		||||
        query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
 | 
			
		||||
                                                                     key_states,
 | 
			
		||||
                                                                     position_ids,
 | 
			
		||||
                                                                     "baichuan")
 | 
			
		||||
        kv_seq_len += past_key_value[0].shape[2]
 | 
			
		||||
 | 
			
		||||
    # IPEX-LLM OPT: fuse rope
 | 
			
		||||
    if should_use_fuse_rope(hidden_states, position_ids, self.training):
 | 
			
		||||
        import linear_q4_0
 | 
			
		||||
        linear_q4_0.rotary_half_inplaced(self.rotary_emb.inv_freq, position_ids,
 | 
			
		||||
                                         query_states, key_states)
 | 
			
		||||
    else:
 | 
			
		||||
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
 | 
			
		||||
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
 | 
			
		||||
                                                        cos, sin, position_ids, "baichuan")
 | 
			
		||||
    if past_key_value is None:
 | 
			
		||||
        kv_seq_len = key_states.shape[-2]
 | 
			
		||||
        k_cache, v_cache = init_fp8_kv_cache(
 | 
			
		||||
            bsz, self.num_heads, kv_seq_len, self.head_dim,
 | 
			
		||||
            device=device
 | 
			
		||||
        )
 | 
			
		||||
        query_states = query_states.to(hidden_states.dtype)
 | 
			
		||||
        key_states = key_states.to(hidden_states.dtype)
 | 
			
		||||
 | 
			
		||||
    # IPEX-LLM OPT: kv cache and quantize kv
 | 
			
		||||
    use_quantize_kv = use_quantize_kv_cache(self.W_pack, hidden_states)
 | 
			
		||||
    if use_quantize_kv:
 | 
			
		||||
        if past_key_value is None:
 | 
			
		||||
            k_cache, v_cache = init_fp8_kv_cache(
 | 
			
		||||
                bsz, self.num_heads, kv_seq_len, self.head_dim,
 | 
			
		||||
                device=device
 | 
			
		||||
            )
 | 
			
		||||
        else:
 | 
			
		||||
            k_cache, v_cache = past_key_value
 | 
			
		||||
        key_states, value_states = append_fp8_kv_cache(k_cache, v_cache,
 | 
			
		||||
                                                       key_states, value_states)
 | 
			
		||||
    else:
 | 
			
		||||
        k_cache, v_cache = past_key_value
 | 
			
		||||
    key_states, value_states = append_fp8_kv_cache(k_cache, v_cache,
 | 
			
		||||
                                                   key_states, value_states)
 | 
			
		||||
        if past_key_value is None:
 | 
			
		||||
            max_cache_length = kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH
 | 
			
		||||
            k_cache, v_cache = init_kv_cache(bsz,
 | 
			
		||||
                                             self.num_heads,
 | 
			
		||||
                                             self.head_dim,
 | 
			
		||||
                                             kv_seq_len,
 | 
			
		||||
                                             max_cache_length,
 | 
			
		||||
                                             dtype=key_states.dtype,
 | 
			
		||||
                                             device=device)
 | 
			
		||||
            k_cache[...] = key_states
 | 
			
		||||
            v_cache[...] = value_states
 | 
			
		||||
            key_states = k_cache
 | 
			
		||||
            value_states = v_cache
 | 
			
		||||
        else:
 | 
			
		||||
            k_cache, v_cache = past_key_value
 | 
			
		||||
            if k_cache.stride(1) < kv_seq_len * k_cache.size(3):
 | 
			
		||||
                max_cache_length = kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH
 | 
			
		||||
                new_k_cache, new_v_cache = extend_kv_cache(bsz,
 | 
			
		||||
                                                           self.num_heads,
 | 
			
		||||
                                                           self.head_dim,
 | 
			
		||||
                                                           k_cache.size(2),
 | 
			
		||||
                                                           max_cache_length,
 | 
			
		||||
                                                           dtype=k_cache.dtype,
 | 
			
		||||
                                                           device=device)
 | 
			
		||||
                new_k_cache[...] = k_cache
 | 
			
		||||
                new_v_cache[...] = v_cache
 | 
			
		||||
                k_cache = new_k_cache
 | 
			
		||||
                v_cache = new_v_cache
 | 
			
		||||
            key_states, value_states = append_kv_cache(k_cache, v_cache, key_states, value_states)
 | 
			
		||||
 | 
			
		||||
    past_key_value = (key_states, value_states) if use_cache else None
 | 
			
		||||
 | 
			
		||||
    invalidInputError(attention_mask is None or attention_mask.dtype != torch.bool,
 | 
			
		||||
                      "attention_mask's dtype cannot be bool")
 | 
			
		||||
    if self.training:
 | 
			
		||||
        warnings.warn("xops is not supported on Intel GPU, so just use normal implementation")
 | 
			
		||||
 | 
			
		||||
    scaling_factor = 1 / math.sqrt(query_states.size(-1))
 | 
			
		||||
    if query_states.size(2) != 1 or device.type != 'xpu':
 | 
			
		||||
        key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
 | 
			
		||||
                                                        query_states.dtype)
 | 
			
		||||
        if should_split_qkv_tensor(query_states, bsz, self.num_heads,
 | 
			
		||||
                                   q_len, kv_seq_len, output_attentions):
 | 
			
		||||
            attn_output, attn_weights = native_sdp_split_qkv_tensor(query_states, key_states,
 | 
			
		||||
                                                                    value_states, attention_mask)
 | 
			
		||||
        else:
 | 
			
		||||
            attn_output = torch.matmul(query_states * scaling_factor, key_states.transpose(-2, -1))
 | 
			
		||||
 | 
			
		||||
            if attention_mask is not None:
 | 
			
		||||
                attn_output += attention_mask
 | 
			
		||||
            attn_output = torch.softmax(attn_output, -1)
 | 
			
		||||
            attn_output = attn_output.to(hidden_states.dtype)
 | 
			
		||||
            attn_output = torch.matmul(attn_output, value_states)
 | 
			
		||||
    else:
 | 
			
		||||
    attn_weights = None
 | 
			
		||||
    if not self.training and not hidden_states.requires_grad and \
 | 
			
		||||
            use_flash_attention(query_states, key_states, attention_mask):
 | 
			
		||||
        attn_output = F.scaled_dot_product_attention(query_states.to(dtype=torch.float16),
 | 
			
		||||
                                                     key_states.to(dtype=torch.float16),
 | 
			
		||||
                                                     value_states.to(dtype=torch.float16),
 | 
			
		||||
                                                     is_causal=True).to(hidden_states.dtype)
 | 
			
		||||
    elif use_sdp(q_len, kv_seq_len, self.head_dim, query_states):
 | 
			
		||||
        import linear_q4_0
 | 
			
		||||
        attn_output = linear_q4_0.sdp_fp8(query_states, key_states, value_states,
 | 
			
		||||
                                          attention_mask)
 | 
			
		||||
        attn_weights = None
 | 
			
		||||
    attn_output = attn_output.transpose(1, 2)
 | 
			
		||||
        if use_quantize_kv:
 | 
			
		||||
            attn_output = linear_q4_0.sdp_fp8(query_states, key_states, value_states,
 | 
			
		||||
                                              attention_mask)
 | 
			
		||||
        else:
 | 
			
		||||
            attn_output = linear_q4_0.sdp(query_states, key_states, value_states, attention_mask)
 | 
			
		||||
    elif use_sdp_causal(q_len, kv_seq_len, self.head_dim, query_states, self.training):
 | 
			
		||||
        import linear_q4_0
 | 
			
		||||
        if use_quantize_kv:
 | 
			
		||||
            attn_output = linear_q4_0.sdp_fp8_causal(query_states, key_states, value_states)
 | 
			
		||||
        else:
 | 
			
		||||
            attn_output = linear_q4_0.sdp_causal(query_states, key_states, value_states)
 | 
			
		||||
    else:
 | 
			
		||||
        if use_quantize_kv:
 | 
			
		||||
            key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
 | 
			
		||||
                                                            query_states.dtype)
 | 
			
		||||
        attn_weights = torch.matmul(query_states,
 | 
			
		||||
                                    key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
 | 
			
		||||
        if attention_mask is not None:
 | 
			
		||||
            attn_weights = attn_weights + attention_mask
 | 
			
		||||
        # upcast attention to fp32
 | 
			
		||||
        attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1,
 | 
			
		||||
                                                   dtype=torch.float32).to(value_states.dtype)
 | 
			
		||||
        attn_output = torch.matmul(attn_weights, value_states)
 | 
			
		||||
 | 
			
		||||
    attn_output = attn_output.transpose(1, 2).contiguous()
 | 
			
		||||
    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
 | 
			
		||||
 | 
			
		||||
    attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim)
 | 
			
		||||
    attn_output = self.o_proj(attn_output)
 | 
			
		||||
 | 
			
		||||
    if not output_attentions:
 | 
			
		||||
| 
						 | 
				
			
			@ -202,134 +204,6 @@ def baichuan_attention_forward_7b_quantized(
 | 
			
		|||
    return attn_output, attn_weights, past_key_value
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def baichuan_attention_forward_7b_origin(
 | 
			
		||||
    self,
 | 
			
		||||
    hidden_states: torch.Tensor,
 | 
			
		||||
    attention_mask: Optional[torch.Tensor] = None,
 | 
			
		||||
    position_ids: Optional[torch.LongTensor] = None,
 | 
			
		||||
    past_key_value: Optional[Tuple[torch.Tensor]] = None,
 | 
			
		||||
    output_attentions: bool = False,
 | 
			
		||||
    use_cache: bool = False,
 | 
			
		||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
 | 
			
		||||
    bsz, q_len, _ = hidden_states.size()
 | 
			
		||||
    device = hidden_states.device
 | 
			
		||||
 | 
			
		||||
    proj = self.W_pack(hidden_states)
 | 
			
		||||
    proj = torch.chunk(proj, 3, -1)
 | 
			
		||||
    # batch_size x source_len x hidden_size
 | 
			
		||||
    query_states = proj[0].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
 | 
			
		||||
    # batch_size x target_len x head_size
 | 
			
		||||
    key_states = proj[1].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
 | 
			
		||||
    # batch_size x source_len x hidden_size
 | 
			
		||||
    value_states = proj[2].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
 | 
			
		||||
 | 
			
		||||
    kv_seq_len = key_states.shape[-2]
 | 
			
		||||
    enough_kv_room = True
 | 
			
		||||
    if past_key_value is not None:
 | 
			
		||||
        enough_kv_room = is_enough_kv_cache_room_4_31(past_key_value, seq_len=kv_seq_len)
 | 
			
		||||
        kv_seq_len += past_key_value[0].shape[-2]
 | 
			
		||||
    if query_states.device.type == "xpu" and not (self.training and query_states.requires_grad):
 | 
			
		||||
        query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
 | 
			
		||||
                                                                     key_states,
 | 
			
		||||
                                                                     position_ids,
 | 
			
		||||
                                                                     "baichuan")
 | 
			
		||||
    else:
 | 
			
		||||
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
 | 
			
		||||
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
 | 
			
		||||
                                                        cos, sin, position_ids, "baichuan")
 | 
			
		||||
    # [bsz, nh, t, hd]
 | 
			
		||||
 | 
			
		||||
    # if past_key_value is not None:
 | 
			
		||||
    #     # reuse k, v, self_attention
 | 
			
		||||
    #     key_states = torch.cat([past_key_value[0], key_states], dim=2)
 | 
			
		||||
    #     value_states = torch.cat([past_key_value[1], value_states], dim=2)
 | 
			
		||||
    if past_key_value is not None:
 | 
			
		||||
        # reuse k, v, self_attention
 | 
			
		||||
        cache_k = past_key_value[0]
 | 
			
		||||
        cache_v = past_key_value[1]
 | 
			
		||||
        if not enough_kv_room:
 | 
			
		||||
            # allocate new
 | 
			
		||||
            new_cache_k, new_cache_v = extend_kv_cache(bsz,
 | 
			
		||||
                                                       self.num_heads,
 | 
			
		||||
                                                       self.head_dim,
 | 
			
		||||
                                                       cache_k.size(2),
 | 
			
		||||
                                                       kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH,
 | 
			
		||||
                                                       dtype=cache_k.dtype,
 | 
			
		||||
                                                       device=device)
 | 
			
		||||
            new_cache_k[:] = cache_k
 | 
			
		||||
            new_cache_v[:] = cache_v
 | 
			
		||||
            cache_k = new_cache_k
 | 
			
		||||
            cache_v = new_cache_v
 | 
			
		||||
 | 
			
		||||
        key_states, value_states = append_kv_cache(cache_k, cache_v, key_states, value_states)
 | 
			
		||||
 | 
			
		||||
    elif use_cache:
 | 
			
		||||
        max_cache_length = kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH
 | 
			
		||||
        new_key_states, new_value_states = init_kv_cache(bsz,
 | 
			
		||||
                                                         self.num_heads,
 | 
			
		||||
                                                         self.head_dim,
 | 
			
		||||
                                                         kv_seq_len,
 | 
			
		||||
                                                         max_cache_length,
 | 
			
		||||
                                                         dtype=key_states.dtype,
 | 
			
		||||
                                                         device=device)
 | 
			
		||||
        new_key_states[:] = key_states
 | 
			
		||||
        new_value_states[:] = value_states
 | 
			
		||||
        key_states = new_key_states
 | 
			
		||||
        value_states = new_value_states
 | 
			
		||||
 | 
			
		||||
    past_key_value = (key_states, value_states) if use_cache else None
 | 
			
		||||
 | 
			
		||||
    invalidInputError(attention_mask is None or attention_mask.dtype != torch.bool,
 | 
			
		||||
                      "attention_mask's dtype cannot be bool")
 | 
			
		||||
 | 
			
		||||
    if xops is not None and self.training:
 | 
			
		||||
        attn_weights = None
 | 
			
		||||
        query_states = query_states.transpose(1, 2)
 | 
			
		||||
        key_states = key_states.transpose(1, 2)
 | 
			
		||||
        value_states = value_states.transpose(1, 2)
 | 
			
		||||
        attn_output = xops.memory_efficient_attention(
 | 
			
		||||
            query_states, key_states, value_states, attn_bias=xops.LowerTriangularMask()
 | 
			
		||||
        )
 | 
			
		||||
    else:
 | 
			
		||||
        if not self.training and not hidden_states.requires_grad and \
 | 
			
		||||
                use_flash_attention(query_states, key_states, attention_mask):
 | 
			
		||||
            attn_output = F.scaled_dot_product_attention(query_states.to(dtype=torch.float16),
 | 
			
		||||
                                                         key_states.to(dtype=torch.float16),
 | 
			
		||||
                                                         value_states.to(dtype=torch.float16),
 | 
			
		||||
                                                         is_causal=True)
 | 
			
		||||
            attn_weights = None
 | 
			
		||||
        elif not self.training and not hidden_states.requires_grad and \
 | 
			
		||||
                use_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
 | 
			
		||||
            import linear_q4_0
 | 
			
		||||
            attn_output = linear_q4_0.sdp(query_states, key_states, value_states, attention_mask)
 | 
			
		||||
            attn_output = attn_output.view(query_states.shape)
 | 
			
		||||
            attn_weights = None
 | 
			
		||||
        else:
 | 
			
		||||
            if should_split_qkv_tensor(query_states, bsz, self.num_heads,
 | 
			
		||||
                                       q_len, kv_seq_len, output_attentions):
 | 
			
		||||
                attn_output, attn_weights = native_sdp_split_qkv_tensor(query_states,
 | 
			
		||||
                                                                        key_states,
 | 
			
		||||
                                                                        value_states,
 | 
			
		||||
                                                                        attention_mask)
 | 
			
		||||
            else:
 | 
			
		||||
                scaling_factor = 1 / math.sqrt(query_states.size(-1))
 | 
			
		||||
                attn_output = torch.matmul(query_states * scaling_factor,
 | 
			
		||||
                                           key_states.transpose(-2, -1))
 | 
			
		||||
                if attention_mask is not None:
 | 
			
		||||
                    attn_output += attention_mask
 | 
			
		||||
                attn_output = torch.softmax(attn_output, -1)
 | 
			
		||||
                attn_output = torch.matmul(attn_output, value_states)
 | 
			
		||||
 | 
			
		||||
        attn_output = attn_output.transpose(1, 2)
 | 
			
		||||
    attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim)
 | 
			
		||||
    attn_output = self.o_proj(attn_output)
 | 
			
		||||
 | 
			
		||||
    if not output_attentions:
 | 
			
		||||
        attn_weights = None
 | 
			
		||||
 | 
			
		||||
    return attn_output.to(hidden_states.dtype), attn_weights, past_key_value
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def baichuan_attention_forward_13b(
 | 
			
		||||
    self,
 | 
			
		||||
    hidden_states: torch.Tensor,
 | 
			
		||||
| 
						 | 
				
			
			@ -507,48 +381,38 @@ def baichuan_attention_forward_13b_origin(
 | 
			
		|||
        value_states = new_value_states
 | 
			
		||||
 | 
			
		||||
    past_key_value = (key_states, value_states) if use_cache else None
 | 
			
		||||
    if xops is not None and self.training:
 | 
			
		||||
        attn_weights = None
 | 
			
		||||
        # query_states = query_states.transpose(1, 2)
 | 
			
		||||
        # key_states = key_states.transpose(1, 2)
 | 
			
		||||
        # value_states = value_states.transpose(1, 2)
 | 
			
		||||
        # attn_output = xops.memory_efficient_attention(
 | 
			
		||||
        #     query_states, key_states, value_states, attn_bias=attention_mask
 | 
			
		||||
        # )
 | 
			
		||||
        with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True,
 | 
			
		||||
                                            enable_mem_efficient=True):
 | 
			
		||||
            attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states,
 | 
			
		||||
                                                         attn_mask=attention_mask)
 | 
			
		||||
        attn_output = attn_output.transpose(1, 2)
 | 
			
		||||
    else:
 | 
			
		||||
        attn_weights = torch.matmul(
 | 
			
		||||
            query_states.to(dtype=key_states.dtype), key_states.transpose(2, 3)
 | 
			
		||||
        ) / math.sqrt(self.head_dim)
 | 
			
		||||
 | 
			
		||||
        if attention_mask is not None:
 | 
			
		||||
            if q_len == 1:  # inference with cache
 | 
			
		||||
                if len(attention_mask.size()) == 4:
 | 
			
		||||
                    attention_mask = attention_mask[:, :, -1:, :]
 | 
			
		||||
                else:
 | 
			
		||||
                    attention_mask = attention_mask[:, -1:, :]
 | 
			
		||||
            if attention_mask.shape[-2] == attn_weights.shape[-2]:
 | 
			
		||||
                attn_weights = attn_weights + attention_mask
 | 
			
		||||
    if self.training:
 | 
			
		||||
        warnings.warn("xops is not supported on Intel GPU, so just use normal implementation")
 | 
			
		||||
 | 
			
		||||
    attn_weights = torch.matmul(
 | 
			
		||||
        query_states.to(dtype=key_states.dtype), key_states.transpose(2, 3)
 | 
			
		||||
    ) / math.sqrt(self.head_dim)
 | 
			
		||||
 | 
			
		||||
    if attention_mask is not None:
 | 
			
		||||
        if q_len == 1:  # inference with cache
 | 
			
		||||
            if len(attention_mask.size()) == 4:
 | 
			
		||||
                attention_mask = attention_mask[:, :, -1:, :]
 | 
			
		||||
            else:
 | 
			
		||||
                # support for Baichuan/Baichuan2 13B Chat running speculative decoding
 | 
			
		||||
                # split attention mask on dim -2
 | 
			
		||||
                split_sizes = [attention_mask.shape[-2] - attn_weights.shape[-2],
 | 
			
		||||
                               attn_weights.shape[-2]]
 | 
			
		||||
                # the last chunk of splited is the new attention mask
 | 
			
		||||
                attention_mask = attention_mask.split(split_sizes, dim=-2)[-1]
 | 
			
		||||
                attn_weights = attn_weights + attention_mask
 | 
			
		||||
            attn_weights = torch.max(
 | 
			
		||||
                attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)
 | 
			
		||||
            )
 | 
			
		||||
                attention_mask = attention_mask[:, -1:, :]
 | 
			
		||||
        if attention_mask.shape[-2] == attn_weights.shape[-2]:
 | 
			
		||||
            attn_weights = attn_weights + attention_mask
 | 
			
		||||
        else:
 | 
			
		||||
            # support for Baichuan/Baichuan2 13B Chat running speculative decoding
 | 
			
		||||
            # split attention mask on dim -2
 | 
			
		||||
            split_sizes = [attention_mask.shape[-2] - attn_weights.shape[-2],
 | 
			
		||||
                           attn_weights.shape[-2]]
 | 
			
		||||
            # the last chunk of splited is the new attention mask
 | 
			
		||||
            attention_mask = attention_mask.split(split_sizes, dim=-2)[-1]
 | 
			
		||||
            attn_weights = attn_weights + attention_mask
 | 
			
		||||
        attn_weights = torch.max(
 | 
			
		||||
            attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
 | 
			
		||||
        attn_output = torch.matmul(attn_weights.to(dtype=value_states.dtype), value_states)
 | 
			
		||||
    attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
 | 
			
		||||
    attn_output = torch.matmul(attn_weights.to(dtype=value_states.dtype), value_states)
 | 
			
		||||
 | 
			
		||||
        attn_output = attn_output.transpose(1, 2)
 | 
			
		||||
    attn_output = attn_output.transpose(1, 2)
 | 
			
		||||
    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
 | 
			
		||||
    attn_output = self.o_proj(attn_output)
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -647,21 +511,3 @@ def baichuan_13b_get_alibi_mask(self, tensor, seq_length_with_past):
 | 
			
		|||
            : self.n_head, :seq_length_with_past, :seq_length_with_past
 | 
			
		||||
        ]
 | 
			
		||||
    return mask
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def native_sdp_split_qkv_tensor(query, key, value, attention_mask):
 | 
			
		||||
    block_size = 8
 | 
			
		||||
    query_split = torch.split(query, block_size, dim=1)
 | 
			
		||||
    key_split = torch.split(key.transpose(-2, -1), block_size, dim=1)
 | 
			
		||||
    value_split = torch.split(value, block_size, dim=1)
 | 
			
		||||
    attn_outputs = []
 | 
			
		||||
    scaling_factor = 1 / math.sqrt(query.size(-1))
 | 
			
		||||
    for q, k, v in zip(query_split, key_split, value_split):
 | 
			
		||||
        attn_output_split = torch.matmul(q * scaling_factor, k)
 | 
			
		||||
        if attention_mask is not None:
 | 
			
		||||
            attn_output_split += attention_mask
 | 
			
		||||
        attn_output_split = torch.softmax(attn_output_split, -1)
 | 
			
		||||
        attn_output_split = torch.matmul(attn_output_split, v)
 | 
			
		||||
        attn_outputs.append(attn_output_split)
 | 
			
		||||
    attn_output = torch.cat(attn_outputs, dim=1)
 | 
			
		||||
    return attn_output, None
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in a new issue