parent 151fcf37bb
commit dbc3c2d72d

1 changed file with 59 additions and 28 deletions
@@ -22,6 +22,9 @@ from typing import Optional, Tuple, Union, List, Callable, Dict, Any
 import torch.nn.functional as F
 from ipex_llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
 from ipex_llm.transformers.models.utils import use_quantize_kv_cache, apply_ipex_rotate_every_two
+from ipex_llm.transformers.models.utils import use_sdp
+from ipex_llm.transformers.models.chatglm2 import should_split_qkv_tensor
+from ipex_llm.transformers.models.chatglm2 import split_tensor_along_last_dim
 from transformers.modeling_outputs import BaseModelOutputWithPast
 
 
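Note: the three added imports bring in use_sdp plus two helpers that already live in the chatglm2 module; the next hunk deletes this file's local copy of split_tensor_along_last_dim in favor of the shared one. A minimal sketch of what that helper computes, assuming a toy tensor and an even 3-way split (torch.split is all it wraps):

    import torch

    x = torch.randn(2, 4, 12)                          # [batch, seq, hidden]
    chunks = torch.split(x, x.size(-1) // 3, dim=-1)   # split_tensor_along_last_dim(x, 3)
    print([tuple(c.shape) for c in chunks])            # [(2, 4, 4), (2, 4, 4), (2, 4, 4)]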
@@ -31,32 +34,6 @@ KV_CACHE_ALLOC_BLOCK_LENGTH = int(os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH",
 KV_CACHE_ALLOC_MIN_LENGTH = 512
 
 
-def split_tensor_along_last_dim(
-        tensor: torch.Tensor,
-        num_partitions: int,
-        contiguous_split_chunks: bool = False,
-) -> List[torch.Tensor]:
-    """Split a tensor along its last dimension.
-    Arguments:
-        tensor: input tensor.
-        num_partitions: number of partitions to split the tensor
-        contiguous_split_chunks: If True, make each chunk contiguous
-                                 in memory.
-    Returns:
-        A list of Tensors
-    """
-    # Get the size and dimension.
-    last_dim = tensor.dim() - 1
-    last_dim_size = tensor.size()[last_dim] // num_partitions
-    # Split.
-    tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
-    # Note: torch.split does not create contiguous tensors by default.
-    if contiguous_split_chunks:
-        return tuple(chunk.contiguous() for chunk in tensor_list)
-
-    return tensor_list
-
-
 def chatglm4_model_forward(
     self,
     input_ids,
@@ -236,7 +213,7 @@ def chatglm4_attention_forward(
 
     # apply relative positional encoding (rotary embedding)
     if isinstance(rotary_pos_emb, tuple) and len(rotary_pos_emb) == 2:
-        # use_fuse_rope, see chatglm2_model_forward
+        # use_fuse_rope, see chatglm4_model_forward
         cos, sin = rotary_pos_emb
         rot_dim = cos.shape[-1]
         query_layer = query_layer.transpose(1, 2)
@@ -310,7 +287,7 @@ def chatglm4_attention_forward(
     # core attention computation
     # ==================================
 
-    context_layer = self.core_attention(query_layer, key_layer, value_layer, attention_mask)
+    context_layer = core_attn_forward(query_layer, key_layer, value_layer, attention_mask)
 
     # =================
     # Output. [sq, b, h]
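Note: with this change the attention module's self.core_attention is bypassed and the free function core_attn_forward (added at the end of this diff) handles the core computation. A hedged usage sketch, assuming q/k/v are laid out as [batch, n_head, seq_len, head_dim]:

    # Illustrative call only; the shapes are assumptions. mask=None with equal
    # query/key lengths takes the causal SDPA path.
    context_layer = core_attn_forward(query_layer, key_layer, value_layer, None)
    # core_attn_forward returns [batch, seq_len, n_head * head_dim], ready for self.dense.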
@@ -319,3 +296,57 @@ def chatglm4_attention_forward(
     output = self.dense(context_layer)
 
     return output, kv_cache
+
+
+def core_attn_forward(query_layer, key_layer, value_layer, attention_mask):
+    L, S = query_layer.shape[2], key_layer.shape[2]
+    if attention_mask is None and L == S:
+        batch_size, n_head, seq_len, head_dim = query_layer.shape
+        if should_split_qkv_tensor(query_layer, batch_size, n_head, seq_len):
+            # split second dim to block size = 8
+            block_size = 8
+            query_split = torch.split(query_layer.to(key_layer.dtype), block_size, dim=1)
+            key_split = torch.split(key_layer, block_size, dim=1)
+            value_split = torch.split(value_layer, block_size, dim=1)
+            results = []
+            for q, k, v in zip(query_split, key_split, value_split):
+                result = F.scaled_dot_product_attention(q, k, v, is_causal=True).to(k.dtype)
+                results.append(result)
+            context_layer = torch.cat(results, dim=1)
+        else:
+            context_layer = F.scaled_dot_product_attention(query_layer.to(key_layer.dtype),
+                                                           key_layer,
+                                                           value_layer,
+                                                           is_causal=True).to(key_layer.dtype)
+    else:
+        # attention_mask is not None only when past_key_value is not None and q_len > 1
+        if attention_mask is not None:
+            attn_bias = torch.zeros(attention_mask.shape, dtype=query_layer.dtype,
+                                    device=query_layer.device)
+            attention_mask = ~attention_mask
+            if attention_mask.dtype == torch.bool:
+                attn_bias.masked_fill_(attention_mask.logical_not(), float("-inf"))
+            else:
+                attn_bias += attention_mask
+        else:
+            attn_bias = None
+
+        if use_sdp(query_layer.shape[2], key_layer.shape[2],
+                   query_layer.shape[-1], query_layer):
+            import xe_addons
+            attn_output = xe_addons.sdp(query_layer, key_layer, value_layer, attn_bias)
+            context_layer = attn_output.view(query_layer.shape)
+        else:
+            head_dim = query_layer.size(-1)
+            attn = torch.matmul(query_layer.to(key_layer.dtype),
+                                key_layer.transpose(2, 3)) / math.sqrt(head_dim)
+            if attn_bias is not None:
+                attn += attn_bias
+            attn = F.softmax(attn, dim=-1,
+                             dtype=torch.float32).to(value_layer.dtype)
+            context_layer = torch.matmul(attn, value_layer)
+    context_layer = context_layer.transpose(1, 2).contiguous()
+    new_context_layer_shape = context_layer.size()[:-2] + (-1,)
+    context_layer = context_layer.reshape(*new_context_layer_shape)
+
+    return context_layer
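Note: the should_split_qkv_tensor branch trades one large scaled_dot_product_attention call for several calls over blocks of 8 heads, presumably to keep each kernel invocation's working set small; since every head attends independently, splitting along the head dimension is exact. A self-contained sketch of that equivalence (toy shapes on CPU float32 are assumptions, not the kernel's real inputs):

    import torch
    import torch.nn.functional as F

    q = torch.randn(1, 32, 64, 128)   # [batch, n_head, seq_len, head_dim]
    k = torch.randn(1, 32, 64, 128)
    v = torch.randn(1, 32, 64, 128)

    full = F.scaled_dot_product_attention(q, k, v, is_causal=True)
    parts = [F.scaled_dot_product_attention(qs, ks, vs, is_causal=True)
             for qs, ks, vs in zip(q.split(8, dim=1), k.split(8, dim=1), v.split(8, dim=1))]
    assert torch.allclose(full, torch.cat(parts, dim=1), atol=1e-5)

When a mask is present, the function instead builds an additive attn_bias from the boolean mask and either hands it to xe_addons.sdp (IPEX-LLM's fused SDP kernel, taken when use_sdp approves the shapes and dtype) or falls back to the explicit matmul / softmax / matmul reference path.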