parent
151fcf37bb
commit
dbc3c2d72d
1 changed file with 59 additions and 28 deletions
@@ -22,6 +22,9 @@ from typing import Optional, Tuple, Union, List, Callable, Dict, Any
 import torch.nn.functional as F
 from ipex_llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
 from ipex_llm.transformers.models.utils import use_quantize_kv_cache, apply_ipex_rotate_every_two
+from ipex_llm.transformers.models.utils import use_sdp
+from ipex_llm.transformers.models.chatglm2 import should_split_qkv_tensor
+from ipex_llm.transformers.models.chatglm2 import split_tensor_along_last_dim
 from transformers.modeling_outputs import BaseModelOutputWithPast
 
 
@@ -31,32 +34,6 @@ KV_CACHE_ALLOC_BLOCK_LENGTH = int(os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH",
 KV_CACHE_ALLOC_MIN_LENGTH = 512
 
 
-def split_tensor_along_last_dim(
-        tensor: torch.Tensor,
-        num_partitions: int,
-        contiguous_split_chunks: bool = False,
-) -> List[torch.Tensor]:
-    """Split a tensor along its last dimension.
-    Arguments:
-        tensor: input tensor.
-        num_partitions: number of partitions to split the tensor
-        contiguous_split_chunks: If True, make each chunk contiguous
-                                 in memory.
-    Returns:
-        A list of Tensors
-    """
-    # Get the size and dimension.
-    last_dim = tensor.dim() - 1
-    last_dim_size = tensor.size()[last_dim] // num_partitions
-    # Split.
-    tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
-    # Note: torch.split does not create contiguous tensors by default.
-    if contiguous_split_chunks:
-        return tuple(chunk.contiguous() for chunk in tensor_list)
-
-    return tensor_list
-
-
 def chatglm4_model_forward(
         self,
         input_ids,
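Note on the deletion above: split_tensor_along_last_dim is not dropped, it is now imported from ipex_llm.transformers.models.chatglm2 (see the first hunk). As a quick illustration of what the helper does, here is a minimal usage sketch; the shapes and values are made up for the example and are not part of the commit:

    import torch
    from ipex_llm.transformers.models.chatglm2 import split_tensor_along_last_dim

    # e.g. a fused QKV projection output of shape (batch, seq, 3 * hidden)
    mixed_qkv = torch.randn(1, 4, 3 * 8)
    # split the last dimension into 3 equal partitions -> query, key, value
    q, k, v = split_tensor_along_last_dim(mixed_qkv, 3)
    assert q.shape == (1, 4, 8) and k.shape == (1, 4, 8) and v.shape == (1, 4, 8)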
@@ -236,7 +213,7 @@ def chatglm4_attention_forward(
 
     # apply relative positional encoding (rotary embedding)
     if isinstance(rotary_pos_emb, tuple) and len(rotary_pos_emb) == 2:
-        # use_fuse_rope, see chatglm2_model_forward
+        # use_fuse_rope, see chatglm4_model_forward
         cos, sin = rotary_pos_emb
         rot_dim = cos.shape[-1]
         query_layer = query_layer.transpose(1, 2)
@@ -310,7 +287,7 @@ def chatglm4_attention_forward(
     # core attention computation
     # ==================================
 
-    context_layer = self.core_attention(query_layer, key_layer, value_layer, attention_mask)
+    context_layer = core_attn_forward(query_layer, key_layer, value_layer, attention_mask)
 
     # =================
     # Output. [sq, b, h]
@@ -319,3 +296,57 @@ def chatglm4_attention_forward(
     output = self.dense(context_layer)
 
     return output, kv_cache
+
+
+def core_attn_forward(query_layer, key_layer, value_layer, attention_mask):
+    L, S = query_layer.shape[2], key_layer.shape[2]
+    if attention_mask is None and L == S:
+        batch_size, n_head, seq_len, head_dim = query_layer.shape
+        if should_split_qkv_tensor(query_layer, batch_size, n_head, seq_len):
+            # split second dim to block size = 8
+            block_size = 8
+            query_split = torch.split(query_layer.to(key_layer.dtype), block_size, dim=1)
+            key_split = torch.split(key_layer, block_size, dim=1)
+            value_split = torch.split(value_layer, block_size, dim=1)
+            results = []
+            for q, k, v in zip(query_split, key_split, value_split):
+                result = F.scaled_dot_product_attention(q, k, v, is_causal=True).to(k.dtype)
+                results.append(result)
+            context_layer = torch.cat(results, dim=1)
+        else:
+            context_layer = F.scaled_dot_product_attention(query_layer.to(key_layer.dtype),
+                                                           key_layer,
+                                                           value_layer,
+                                                           is_causal=True).to(key_layer.dtype)
+    else:
+        # attention_mask is not None only when past_key_value is not None and q_len > 1
+        if attention_mask is not None:
+            attn_bias = torch.zeros(attention_mask.shape, dtype=query_layer.dtype,
+                                    device=query_layer.device)
+            attention_mask = ~attention_mask
+            if attention_mask.dtype == torch.bool:
+                attn_bias.masked_fill_(attention_mask.logical_not(), float("-inf"))
+            else:
+                attn_bias += attention_mask
+        else:
+            attn_bias = None
+
+        if use_sdp(query_layer.shape[2], key_layer.shape[2],
+                   query_layer.shape[-1], query_layer):
+            import xe_addons
+            attn_output = xe_addons.sdp(query_layer, key_layer, value_layer, attn_bias)
+            context_layer = attn_output.view(query_layer.shape)
+        else:
+            head_dim = query_layer.size(-1)
+            attn = torch.matmul(query_layer.to(key_layer.dtype),
+                                key_layer.transpose(2, 3)) / math.sqrt(head_dim)
+            if attn_bias is not None:
+                attn += attn_bias
+            attn = F.softmax(attn, dim=-1,
+                             dtype=torch.float32).to(value_layer.dtype)
+            context_layer = torch.matmul(attn, value_layer)
+    context_layer = context_layer.transpose(1, 2).contiguous()
+    new_context_layer_shape = context_layer.size()[:-2] + (-1,)
+    context_layer = context_layer.reshape(*new_context_layer_shape)
+
+    return context_layer
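Note on the new core_attn_forward: in the unmasked, equal-length case it runs F.scaled_dot_product_attention either on the full (batch, n_head, seq, head_dim) tensors or, when should_split_qkv_tensor opts in, over blocks of 8 heads at a time (dim=1), presumably to keep each individual SDPA call small. Since attention is computed independently per head, the blocked result matches the unblocked one. A self-contained sketch of that equivalence, with made-up shapes (not from the commit):

    import torch
    import torch.nn.functional as F

    def blocked_sdpa(q, k, v, block_size=8):
        # q, k, v: (batch, n_head, seq, head_dim); causal attention, no explicit mask
        outs = [F.scaled_dot_product_attention(qb, kb, vb, is_causal=True)
                for qb, kb, vb in zip(torch.split(q, block_size, dim=1),
                                      torch.split(k, block_size, dim=1),
                                      torch.split(v, block_size, dim=1))]
        return torch.cat(outs, dim=1)

    q, k, v = (torch.randn(1, 32, 128, 64) for _ in range(3))
    full = F.scaled_dot_product_attention(q, k, v, is_causal=True)
    assert torch.allclose(blocked_sdpa(q, k, v), full, atol=1e-4)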
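For the masked branch, the new code builds an additive float bias from the attention mask (for a boolean mask, the double inversion works out to a -inf bias wherever the incoming mask was True) and then either hands it to xe_addons.sdp or adds it to the raw attention scores in the matmul/softmax fallback. A small sketch of that fallback path with a hand-made bias; the shapes are illustrative and the xe_addons / use_sdp pieces are not exercised here:

    import math
    import torch
    import torch.nn.functional as F

    # True marks key positions that should be blocked (receive a -inf bias).
    mask = torch.tensor([[False, False, True, True]])             # (batch, seq_kv)
    attn_bias = torch.zeros(mask.shape, dtype=torch.float32)
    attn_bias.masked_fill_(mask, float("-inf"))

    q = torch.randn(1, 2, 1, 8)    # (batch, n_head, seq_q, head_dim)
    k = torch.randn(1, 2, 4, 8)
    v = torch.randn(1, 2, 4, 8)

    scores = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(q.size(-1))
    scores = scores + attn_bias                                    # broadcasts over heads / queries
    probs = F.softmax(scores, dim=-1, dtype=torch.float32).to(v.dtype)
    out = torch.matmul(probs, v)                                   # (1, 2, 1, 8)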