Refine glm1-4 sdp (#11276)
* chatglm * update * update * change chatglm * update sdpa * update * fix style * fix * fix glm * update glm2-32k * update glm2-32k * fix cpu * update * change lower_bound
This commit is contained in:
parent
cffb932f05
commit
592f7aa61e
5 changed files with 144 additions and 334 deletions
|
|
@ -23,6 +23,7 @@ import torch.utils.checkpoint
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from typing import Optional, Tuple
|
from typing import Optional, Tuple
|
||||||
from ipex_llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
|
from ipex_llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
|
||||||
|
from ipex_llm.transformers.models.chatglm2 import glm_sdpa
|
||||||
|
|
||||||
|
|
||||||
def rotate_half(x):
|
def rotate_half(x):
|
||||||
|
|
@ -103,8 +104,6 @@ def attention_fn(
|
||||||
else:
|
else:
|
||||||
present = None
|
present = None
|
||||||
|
|
||||||
pytorch_major_version = int(torch.__version__.split('.')[0])
|
|
||||||
if pytorch_major_version >= 2:
|
|
||||||
query_layer = query_layer.permute(1, 2, 0, 3)
|
query_layer = query_layer.permute(1, 2, 0, 3)
|
||||||
if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]:
|
if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]:
|
||||||
|
|
||||||
|
|
@ -123,7 +122,7 @@ def attention_fn(
|
||||||
attention_mask,
|
attention_mask,
|
||||||
is_causal=False)
|
is_causal=False)
|
||||||
else:
|
else:
|
||||||
context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer,
|
context_layer = glm_sdpa(query_layer,
|
||||||
key_layer,
|
key_layer,
|
||||||
value_layer,
|
value_layer,
|
||||||
attention_mask,
|
attention_mask,
|
||||||
|
|
@ -163,97 +162,6 @@ def attention_fn(
|
||||||
context_layer = context_layer.reshape(*new_context_layer_shape)
|
context_layer = context_layer.reshape(*new_context_layer_shape)
|
||||||
attention_probs = None
|
attention_probs = None
|
||||||
|
|
||||||
else:
|
|
||||||
query_key_layer_scaling_coeff = float(layer_id + 1)
|
|
||||||
if scaling_attention_score:
|
|
||||||
query_layer = query_layer / (math.sqrt(hidden_size) * query_key_layer_scaling_coeff)
|
|
||||||
|
|
||||||
# ===================================
|
|
||||||
# Raw attention scores. [b, np, s, s]
|
|
||||||
# ===================================
|
|
||||||
|
|
||||||
# [b, np, sq, sk]
|
|
||||||
output_size = (query_layer.size(1), query_layer.size(2),
|
|
||||||
query_layer.size(0), key_layer.size(2))
|
|
||||||
|
|
||||||
# [sq, b, np, hn] -> [sq, b * np, hn]
|
|
||||||
query_layer = query_layer.view(output_size[2], output_size[0] * output_size[1], -1)
|
|
||||||
# [sk, b, np, hn] -> [sk, b * np, hn]
|
|
||||||
key_layer = key_layer.view(output_size[0] * output_size[1], output_size[3], -1)
|
|
||||||
|
|
||||||
matmul_input_buffer = torch.empty(
|
|
||||||
output_size[0] * output_size[1],
|
|
||||||
output_size[2], output_size[3], dtype=query_layer.dtype,
|
|
||||||
device=query_layer.device
|
|
||||||
)
|
|
||||||
|
|
||||||
matmul_result = torch.empty(
|
|
||||||
output_size[0] * output_size[1],
|
|
||||||
output_size[2], output_size[3], dtype=query_layer.dtype,
|
|
||||||
device=query_layer.device
|
|
||||||
)
|
|
||||||
|
|
||||||
torch.baddbmm(
|
|
||||||
matmul_input_buffer,
|
|
||||||
query_layer.transpose(0, 1), # [b * np, sq, hn]
|
|
||||||
key_layer.transpose(1, 2), # [b * np, hn, sk]
|
|
||||||
beta=0.0,
|
|
||||||
alpha=1.0,
|
|
||||||
out=matmul_result)
|
|
||||||
|
|
||||||
# change view to [b, np, sq, sk]
|
|
||||||
attention_scores = matmul_result.view(*output_size)
|
|
||||||
|
|
||||||
if self.scale_mask_softmax:
|
|
||||||
self.scale_mask_softmax.scale = query_key_layer_scaling_coeff
|
|
||||||
attention_probs = self.scale_mask_softmax(attention_scores,
|
|
||||||
attention_mask.contiguous())
|
|
||||||
else:
|
|
||||||
if not (attention_mask == 0).all():
|
|
||||||
# if auto-regressive, skip
|
|
||||||
attention_scores.masked_fill_(attention_mask, -10000.0)
|
|
||||||
dtype = attention_scores.dtype
|
|
||||||
attention_scores = attention_scores.float()
|
|
||||||
attention_scores = attention_scores * query_key_layer_scaling_coeff
|
|
||||||
|
|
||||||
attention_probs = F.softmax(attention_scores, dim=-1)
|
|
||||||
|
|
||||||
attention_probs = attention_probs.type(dtype)
|
|
||||||
|
|
||||||
# =========================
|
|
||||||
# Context layer. [sq, b, hp]
|
|
||||||
# =========================
|
|
||||||
|
|
||||||
# value_layer -> context layer.
|
|
||||||
# [sk, b, np, hn] --> [b, np, sq, hn]
|
|
||||||
|
|
||||||
# context layer shape: [b, np, sq, hn]
|
|
||||||
output_size = (value_layer.size(0), value_layer.size(1),
|
|
||||||
query_layer.size(0), value_layer.size(3))
|
|
||||||
|
|
||||||
# change view [sk, b * np, hn]
|
|
||||||
value_layer = value_layer.view(output_size[0] * output_size[1], value_layer.size(2), -1)
|
|
||||||
|
|
||||||
# change view [b * np, sq, sk]
|
|
||||||
attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1)
|
|
||||||
|
|
||||||
# matmul: [b * np, sq, hn]
|
|
||||||
context_layer = torch.empty(
|
|
||||||
output_size[0] * output_size[1],
|
|
||||||
output_size[2], value_layer.size(-1), dtype=value_layer.dtype,
|
|
||||||
device=query_layer.device)
|
|
||||||
torch.bmm(attention_probs, value_layer, out=context_layer)
|
|
||||||
|
|
||||||
# change view [b, np, sq, hn]
|
|
||||||
context_layer = context_layer.view(*output_size)
|
|
||||||
|
|
||||||
# [b, np, sq, hn] --> [sq, b, np, hn]
|
|
||||||
context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
|
|
||||||
|
|
||||||
# [sq, b, np, hn] --> [sq, b, hp]
|
|
||||||
new_context_layer_shape = context_layer.size()[:-2] + (hidden_size_per_partition,)
|
|
||||||
context_layer = context_layer.view(*new_context_layer_shape)
|
|
||||||
|
|
||||||
outputs = (context_layer, present, attention_probs)
|
outputs = (context_layer, present, attention_probs)
|
||||||
|
|
||||||
return outputs
|
return outputs
|
||||||
|
|
|
||||||
|
|
@ -24,7 +24,7 @@ import torch.nn.functional as F
|
||||||
from transformers.modeling_outputs import BaseModelOutputWithPast
|
from transformers.modeling_outputs import BaseModelOutputWithPast
|
||||||
from ipex_llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
|
from ipex_llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
|
||||||
from ipex_llm.transformers.models.utils import init_fp8_kv_cache, append_fp8_kv_cache, \
|
from ipex_llm.transformers.models.utils import init_fp8_kv_cache, append_fp8_kv_cache, \
|
||||||
restore_fp8_kv_cache, use_quantize_kv_cache
|
restore_fp8_kv_cache, use_quantize_kv_cache, use_flash_attention
|
||||||
from ipex_llm.transformers.models.utils import use_sdp
|
from ipex_llm.transformers.models.utils import use_sdp
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -60,6 +60,48 @@ def split_tensor_along_last_dim(
|
||||||
return tensor_list
|
return tensor_list
|
||||||
|
|
||||||
|
|
||||||
|
def glm_sdpa(query, key, value, attention_mask=None, is_causal=False):
|
||||||
|
if use_flash_attention(query, key, attention_mask) or query.device.type == 'cpu':
|
||||||
|
context_layer = F.scaled_dot_product_attention(query.to(key.dtype),
|
||||||
|
key,
|
||||||
|
value,
|
||||||
|
attention_mask,
|
||||||
|
is_causal=is_causal).to(key.dtype)
|
||||||
|
else:
|
||||||
|
# attention_mask is not None only when past_key_value is not None and q_len > 1
|
||||||
|
if attention_mask is not None:
|
||||||
|
attn_bias = torch.zeros(attention_mask.shape, dtype=query.dtype,
|
||||||
|
device=query.device)
|
||||||
|
attention_mask = ~attention_mask
|
||||||
|
if attention_mask.dtype == torch.bool:
|
||||||
|
attn_bias.masked_fill_(attention_mask.logical_not(), float("-inf"))
|
||||||
|
else:
|
||||||
|
attn_bias += attention_mask
|
||||||
|
elif is_causal:
|
||||||
|
L, S = query.size(-2), key.size(-2)
|
||||||
|
attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
|
||||||
|
temp_mask = torch.ones(L, S, dtype=torch.bool, device=query.device).tril(diagonal=0)
|
||||||
|
attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
|
||||||
|
attn_bias.to(key.dtype)
|
||||||
|
else:
|
||||||
|
attn_bias = None
|
||||||
|
if use_sdp(query.shape[2], key.shape[2],
|
||||||
|
query.shape[-1], query):
|
||||||
|
import xe_addons
|
||||||
|
attn_output = xe_addons.sdp(query, key, value, attn_bias)
|
||||||
|
context_layer = attn_output.view(query.shape)
|
||||||
|
else:
|
||||||
|
head_dim = query.size(-1)
|
||||||
|
attn = torch.matmul(query.to(key.dtype),
|
||||||
|
key.transpose(2, 3)) / math.sqrt(head_dim)
|
||||||
|
if attn_bias is not None:
|
||||||
|
attn += attn_bias
|
||||||
|
attn = F.softmax(attn, dim=-1,
|
||||||
|
dtype=torch.float32).to(value.dtype)
|
||||||
|
context_layer = torch.matmul(attn, value)
|
||||||
|
return context_layer
|
||||||
|
|
||||||
|
|
||||||
@torch.jit.script
|
@torch.jit.script
|
||||||
def apply_rotary_pos_emb_chatglm(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:
|
def apply_rotary_pos_emb_chatglm(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:
|
||||||
# x: [sq, b, np, hn]
|
# x: [sq, b, np, hn]
|
||||||
|
|
@ -271,19 +313,11 @@ def chatglm2_quantized_attention_forward_8eb45c(
|
||||||
value_split = torch.split(value, block_size, dim=1)
|
value_split = torch.split(value, block_size, dim=1)
|
||||||
results = []
|
results = []
|
||||||
for q, k, v in zip(query_split, key_split, value_split):
|
for q, k, v in zip(query_split, key_split, value_split):
|
||||||
if attention_mask is None:
|
result = glm_sdpa(q, k, v, is_causal=True)
|
||||||
result = F.scaled_dot_product_attention(q, k, v, is_causal=True)
|
|
||||||
else:
|
|
||||||
result = F.scaled_dot_product_attention(q, k, v, attention_mask)
|
|
||||||
results.append(result)
|
results.append(result)
|
||||||
context_layer = torch.cat(results, dim=1)
|
context_layer = torch.cat(results, dim=1)
|
||||||
else:
|
else:
|
||||||
if attention_mask is None:
|
context_layer = glm_sdpa(query_layer, key, value, is_causal=True)
|
||||||
context_layer = F.scaled_dot_product_attention(query_layer, key,
|
|
||||||
value, is_causal=True)
|
|
||||||
else:
|
|
||||||
context_layer = F.scaled_dot_product_attention(query_layer, key,
|
|
||||||
value, attention_mask)
|
|
||||||
context_layer = context_layer.to(query_layer.dtype)
|
context_layer = context_layer.to(query_layer.dtype)
|
||||||
|
|
||||||
if use_cache:
|
if use_cache:
|
||||||
|
|
@ -535,145 +569,34 @@ def chatglm2_attention_forward_8eb45c(
|
||||||
|
|
||||||
|
|
||||||
def core_attn_forward_8eb45c(query_layer, key_layer, value_layer, attention_mask):
|
def core_attn_forward_8eb45c(query_layer, key_layer, value_layer, attention_mask):
|
||||||
pytorch_major_version = int(torch.__version__.split('.')[0])
|
|
||||||
if pytorch_major_version >= 2:
|
|
||||||
query_layer = query_layer.permute(1, 2, 0, 3)
|
query_layer = query_layer.permute(1, 2, 0, 3)
|
||||||
L, S = query_layer.shape[2], key_layer.shape[2]
|
L, S = query_layer.shape[2], key_layer.shape[2]
|
||||||
if attention_mask is None and L == S:
|
|
||||||
batch_size, n_head, seq_len, head_dim = query_layer.shape
|
batch_size, n_head, seq_len, head_dim = query_layer.shape
|
||||||
|
if attention_mask is None and L == S:
|
||||||
if should_split_qkv_tensor(query_layer, batch_size, n_head, seq_len):
|
if should_split_qkv_tensor(query_layer, batch_size, n_head, seq_len):
|
||||||
# split second dim to block size = 8
|
# split second dim to block size = 8
|
||||||
block_size = 8
|
block_size = 8
|
||||||
query_split = torch.split(query_layer.to(key_layer.dtype), block_size, dim=1)
|
query_layer = query_layer.to(key_layer.dtype)
|
||||||
|
query_split = torch.split(query_layer, block_size, dim=1)
|
||||||
key_split = torch.split(key_layer, block_size, dim=1)
|
key_split = torch.split(key_layer, block_size, dim=1)
|
||||||
value_split = torch.split(value_layer, block_size, dim=1)
|
value_split = torch.split(value_layer, block_size, dim=1)
|
||||||
results = []
|
results = []
|
||||||
for q, k, v in zip(query_split, key_split, value_split):
|
for q, k, v in zip(query_split, key_split, value_split):
|
||||||
result = F.scaled_dot_product_attention(q, k, v, is_causal=True).to(k.dtype)
|
result = glm_sdpa(q, k, v, is_causal=True)
|
||||||
results.append(result)
|
results.append(result)
|
||||||
context_layer = torch.cat(results, dim=1)
|
context_layer = torch.cat(results, dim=1)
|
||||||
else:
|
else:
|
||||||
context_layer = F.scaled_dot_product_attention(query_layer.to(key_layer.dtype),
|
context_layer = glm_sdpa(query_layer,
|
||||||
key_layer,
|
key_layer,
|
||||||
value_layer,
|
value_layer,
|
||||||
is_causal=True).to(key_layer.dtype)
|
is_causal=True)
|
||||||
else:
|
else:
|
||||||
# attention_mask is not None only when past_key_value is not None and q_len > 1
|
context_layer = glm_sdpa(query_layer,
|
||||||
if attention_mask is not None:
|
key_layer,
|
||||||
attn_bias = torch.zeros(attention_mask.shape, dtype=query_layer.dtype,
|
value_layer,
|
||||||
device=query_layer.device)
|
attention_mask)
|
||||||
attention_mask = ~attention_mask
|
|
||||||
if attention_mask.dtype == torch.bool:
|
|
||||||
attn_bias.masked_fill_(attention_mask.logical_not(), float("-inf"))
|
|
||||||
else:
|
|
||||||
attn_bias += attention_mask
|
|
||||||
else:
|
|
||||||
attn_bias = None
|
|
||||||
|
|
||||||
if use_sdp(query_layer.shape[2], key_layer.shape[2],
|
|
||||||
query_layer.shape[-1], query_layer):
|
|
||||||
import xe_addons
|
|
||||||
attn_output = xe_addons.sdp(query_layer, key_layer, value_layer, attn_bias)
|
|
||||||
context_layer = attn_output.view(query_layer.shape)
|
|
||||||
else:
|
|
||||||
head_dim = query_layer.size(-1)
|
|
||||||
attn = torch.matmul(query_layer.to(key_layer.dtype),
|
|
||||||
key_layer.transpose(2, 3)) / math.sqrt(head_dim)
|
|
||||||
if attn_bias is not None:
|
|
||||||
attn += attn_bias
|
|
||||||
attn = F.softmax(attn, dim=-1,
|
|
||||||
dtype=torch.float32).to(value_layer.dtype)
|
|
||||||
context_layer = torch.matmul(attn, value_layer)
|
|
||||||
context_layer = context_layer.permute(2, 0, 1, 3)
|
context_layer = context_layer.permute(2, 0, 1, 3)
|
||||||
new_context_layer_shape = context_layer.size()[:-2] + (-1,)
|
new_context_layer_shape = context_layer.size()[:-2] + (-1,)
|
||||||
context_layer = context_layer.reshape(*new_context_layer_shape)
|
context_layer = context_layer.reshape(*new_context_layer_shape)
|
||||||
else:
|
|
||||||
# Raw attention scores
|
|
||||||
|
|
||||||
# [b, np, sq, sk]
|
|
||||||
output_size = (query_layer.size(1), query_layer.size(2),
|
|
||||||
query_layer.size(0), key_layer.size(2))
|
|
||||||
|
|
||||||
# [sq, b, np, hn] -> [sq, b * np, hn]
|
|
||||||
query_layer = query_layer.view(output_size[2], output_size[0] * output_size[1], -1)
|
|
||||||
# [sk, b, np, hn] -> [sk, b * np, hn]
|
|
||||||
key_layer = key_layer.view(output_size[0] * output_size[1], output_size[3], -1)
|
|
||||||
|
|
||||||
# preallocting input tensor: [b * np, sq, sk]
|
|
||||||
matmul_input_buffer = torch.empty(
|
|
||||||
output_size[0] * output_size[1],
|
|
||||||
output_size[2], output_size[3], dtype=query_layer.dtype,
|
|
||||||
device=query_layer.device
|
|
||||||
)
|
|
||||||
|
|
||||||
matmul_result = torch.empty(
|
|
||||||
output_size[0] * output_size[1],
|
|
||||||
output_size[2], output_size[3], dtype=query_layer.dtype,
|
|
||||||
device=query_layer.device
|
|
||||||
)
|
|
||||||
|
|
||||||
# Raw attention scores. [b * np, sq, sk]
|
|
||||||
torch.baddbmm(
|
|
||||||
matmul_input_buffer,
|
|
||||||
query_layer.transpose(0, 1), # [b * np, sq, hn]
|
|
||||||
key_layer.transpose(1, 2), # [b * np, hn, sk]
|
|
||||||
beta=0.0,
|
|
||||||
alpha=(1.0 / self.norm_factor),
|
|
||||||
out=matmul_result
|
|
||||||
)
|
|
||||||
|
|
||||||
# change view to [b, np, sq, sk]
|
|
||||||
attention_scores = matmul_result.view(*output_size)
|
|
||||||
|
|
||||||
# ===========================
|
|
||||||
# Attention probs and dropout
|
|
||||||
# ===========================
|
|
||||||
|
|
||||||
# attention scores and attention mask [b, np, sq, sk]
|
|
||||||
if self.attention_softmax_in_fp32:
|
|
||||||
attention_scores = attention_scores.float()
|
|
||||||
if self.coeff is not None:
|
|
||||||
attention_scores = attention_scores * self.coeff
|
|
||||||
if attention_mask is None and attention_scores.shape[2] == attention_scores.shape[3]:
|
|
||||||
attention_mask = torch.ones(output_size[0], 1, output_size[2], output_size[3],
|
|
||||||
device=attention_scores.device, dtype=torch.bool)
|
|
||||||
attention_mask.tril_()
|
|
||||||
attention_mask = ~attention_mask
|
|
||||||
if attention_mask is not None:
|
|
||||||
attention_scores = attention_scores.masked_fill(attention_mask, float("-inf"))
|
|
||||||
attention_probs = F.softmax(attention_scores, dim=-1)
|
|
||||||
attention_probs = attention_probs.type_as(value_layer)
|
|
||||||
|
|
||||||
# This is actually dropping out entire tokens to attend to, which might
|
|
||||||
# seem a bit unusual, but is taken from the original Transformer paper.
|
|
||||||
attention_probs = self.attention_dropout(attention_probs)
|
|
||||||
# =========================
|
|
||||||
# Context layer. [sq, b, hp]
|
|
||||||
# =========================
|
|
||||||
|
|
||||||
# value_layer -> context layer.
|
|
||||||
# [sk, b, np, hn] --> [b, np, sq, hn]
|
|
||||||
|
|
||||||
# context layer shape: [b, np, sq, hn]
|
|
||||||
output_size = (value_layer.size(0), value_layer.size(1),
|
|
||||||
query_layer.size(0), value_layer.size(3))
|
|
||||||
# change view [sk, b * np, hn]
|
|
||||||
value_layer = value_layer.view(output_size[0] * output_size[1], value_layer.size(2), -1)
|
|
||||||
# change view [b * np, sq, sk]
|
|
||||||
attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1)
|
|
||||||
# matmul: [b * np, sq, hn]
|
|
||||||
context_layer = torch.empty(
|
|
||||||
output_size[0] * output_size[1],
|
|
||||||
output_size[2], value_layer.size(-1), dtype=value_layer.dtype,
|
|
||||||
device=value_layer.device,
|
|
||||||
)
|
|
||||||
torch.bmm(attention_probs, value_layer, out=context_layer)
|
|
||||||
# change view [b, np, sq, hn]
|
|
||||||
context_layer = context_layer.view(*output_size)
|
|
||||||
# [b, np, sq, hn] --> [sq, b, np, hn]
|
|
||||||
context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
|
|
||||||
# [sq, b, np, hn] --> [sq, b, hp]
|
|
||||||
new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
|
|
||||||
context_layer = context_layer.view(*new_context_layer_shape)
|
|
||||||
|
|
||||||
return context_layer
|
return context_layer
|
||||||
|
|
|
||||||
|
|
@ -21,6 +21,7 @@ import torch
|
||||||
from typing import Optional, Tuple, Union, List, Callable, Dict, Any
|
from typing import Optional, Tuple, Union, List, Callable, Dict, Any
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from ipex_llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
|
from ipex_llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
|
||||||
|
from ipex_llm.transformers.models.chatglm2 import core_attn_forward_8eb45c
|
||||||
|
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
|
@ -179,10 +180,9 @@ def chatglm2_32k_attention_forward(
|
||||||
key_layer = key_cache
|
key_layer = key_cache
|
||||||
value_layer = value_cache
|
value_layer = value_cache
|
||||||
|
|
||||||
|
if use_cache:
|
||||||
key_layer = key_layer.permute(2, 0, 1, 3)
|
key_layer = key_layer.permute(2, 0, 1, 3)
|
||||||
value_layer = value_layer.permute(2, 0, 1, 3)
|
value_layer = value_layer.permute(2, 0, 1, 3)
|
||||||
|
|
||||||
if use_cache:
|
|
||||||
if kv_cache is None:
|
if kv_cache is None:
|
||||||
kv_cache = torch.cat((key_layer.unsqueeze(0).unsqueeze(0),
|
kv_cache = torch.cat((key_layer.unsqueeze(0).unsqueeze(0),
|
||||||
value_layer.unsqueeze(0).unsqueeze(0)), dim=1)
|
value_layer.unsqueeze(0).unsqueeze(0)), dim=1)
|
||||||
|
|
@ -195,7 +195,7 @@ def chatglm2_32k_attention_forward(
|
||||||
# core attention computation
|
# core attention computation
|
||||||
# ==================================
|
# ==================================
|
||||||
|
|
||||||
context_layer = self.core_attention(query_layer, key_layer, value_layer, attention_mask)
|
context_layer = core_attn_forward_8eb45c(query_layer, key_layer, value_layer, attention_mask)
|
||||||
|
|
||||||
# =================
|
# =================
|
||||||
# Output. [sq, b, h]
|
# Output. [sq, b, h]
|
||||||
|
|
|
||||||
|
|
@ -23,7 +23,7 @@ import torch.nn.functional as F
|
||||||
from ipex_llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
|
from ipex_llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
|
||||||
from ipex_llm.transformers.models.utils import use_quantize_kv_cache, apply_ipex_rotate_every_two
|
from ipex_llm.transformers.models.utils import use_quantize_kv_cache, apply_ipex_rotate_every_two
|
||||||
from ipex_llm.transformers.models.utils import use_sdp
|
from ipex_llm.transformers.models.utils import use_sdp
|
||||||
from ipex_llm.transformers.models.chatglm2 import should_split_qkv_tensor
|
from ipex_llm.transformers.models.chatglm2 import should_split_qkv_tensor, glm_sdpa
|
||||||
from ipex_llm.transformers.models.chatglm2 import split_tensor_along_last_dim
|
from ipex_llm.transformers.models.chatglm2 import split_tensor_along_last_dim
|
||||||
from transformers.modeling_outputs import BaseModelOutputWithPast
|
from transformers.modeling_outputs import BaseModelOutputWithPast
|
||||||
|
|
||||||
|
|
@ -300,51 +300,30 @@ def chatglm4_attention_forward(
|
||||||
|
|
||||||
def core_attn_forward(query_layer, key_layer, value_layer, attention_mask):
|
def core_attn_forward(query_layer, key_layer, value_layer, attention_mask):
|
||||||
L, S = query_layer.shape[2], key_layer.shape[2]
|
L, S = query_layer.shape[2], key_layer.shape[2]
|
||||||
if attention_mask is None and L == S:
|
|
||||||
batch_size, n_head, seq_len, head_dim = query_layer.shape
|
batch_size, n_head, seq_len, head_dim = query_layer.shape
|
||||||
|
if attention_mask is None and L == S:
|
||||||
if should_split_qkv_tensor(query_layer, batch_size, n_head, seq_len):
|
if should_split_qkv_tensor(query_layer, batch_size, n_head, seq_len):
|
||||||
# split second dim to block size = 8
|
# split second dim to block size = 8
|
||||||
block_size = 8
|
block_size = 8
|
||||||
query_split = torch.split(query_layer.to(key_layer.dtype), block_size, dim=1)
|
query_layer = query_layer.to(key_layer.dtype)
|
||||||
|
query_split = torch.split(query_layer, block_size, dim=1)
|
||||||
key_split = torch.split(key_layer, block_size, dim=1)
|
key_split = torch.split(key_layer, block_size, dim=1)
|
||||||
value_split = torch.split(value_layer, block_size, dim=1)
|
value_split = torch.split(value_layer, block_size, dim=1)
|
||||||
results = []
|
results = []
|
||||||
for q, k, v in zip(query_split, key_split, value_split):
|
for q, k, v in zip(query_split, key_split, value_split):
|
||||||
result = F.scaled_dot_product_attention(q, k, v, is_causal=True).to(k.dtype)
|
result = glm_sdpa(q, k, v, is_causal=True)
|
||||||
results.append(result)
|
results.append(result)
|
||||||
context_layer = torch.cat(results, dim=1)
|
context_layer = torch.cat(results, dim=1)
|
||||||
else:
|
else:
|
||||||
context_layer = F.scaled_dot_product_attention(query_layer.to(key_layer.dtype),
|
context_layer = glm_sdpa(query_layer,
|
||||||
key_layer,
|
key_layer,
|
||||||
value_layer,
|
value_layer,
|
||||||
is_causal=True).to(key_layer.dtype)
|
is_causal=True)
|
||||||
else:
|
else:
|
||||||
# attention_mask is not None only when past_key_value is not None and q_len > 1
|
context_layer = glm_sdpa(query_layer,
|
||||||
if attention_mask is not None:
|
key_layer,
|
||||||
attn_bias = torch.zeros(attention_mask.shape, dtype=query_layer.dtype,
|
value_layer,
|
||||||
device=query_layer.device)
|
attention_mask)
|
||||||
attention_mask = ~attention_mask
|
|
||||||
if attention_mask.dtype == torch.bool:
|
|
||||||
attn_bias.masked_fill_(attention_mask.logical_not(), float("-inf"))
|
|
||||||
else:
|
|
||||||
attn_bias += attention_mask
|
|
||||||
else:
|
|
||||||
attn_bias = None
|
|
||||||
|
|
||||||
if use_sdp(query_layer.shape[2], key_layer.shape[2],
|
|
||||||
query_layer.shape[-1], query_layer):
|
|
||||||
import xe_addons
|
|
||||||
attn_output = xe_addons.sdp(query_layer, key_layer, value_layer, attn_bias)
|
|
||||||
context_layer = attn_output.view(query_layer.shape)
|
|
||||||
else:
|
|
||||||
head_dim = query_layer.size(-1)
|
|
||||||
attn = torch.matmul(query_layer.to(key_layer.dtype),
|
|
||||||
key_layer.transpose(2, 3)) / math.sqrt(head_dim)
|
|
||||||
if attn_bias is not None:
|
|
||||||
attn += attn_bias
|
|
||||||
attn = F.softmax(attn, dim=-1,
|
|
||||||
dtype=torch.float32).to(value_layer.dtype)
|
|
||||||
context_layer = torch.matmul(attn, value_layer)
|
|
||||||
context_layer = context_layer.transpose(1, 2).contiguous()
|
context_layer = context_layer.transpose(1, 2).contiguous()
|
||||||
new_context_layer_shape = context_layer.size()[:-2] + (-1,)
|
new_context_layer_shape = context_layer.size()[:-2] + (-1,)
|
||||||
context_layer = context_layer.reshape(*new_context_layer_shape)
|
context_layer = context_layer.reshape(*new_context_layer_shape)
|
||||||
|
|
|
||||||
|
|
@ -171,7 +171,7 @@ class Test_Optimize_Gpu_Model:
|
||||||
# currently only need to compare the output of one self-attention layer.
|
# currently only need to compare the output of one self-attention layer.
|
||||||
layer_norm = "transformer.encoder.layers.27.input_layernorm"
|
layer_norm = "transformer.encoder.layers.27.input_layernorm"
|
||||||
self_attn = "transformer.encoder.layers.27.self_attention"
|
self_attn = "transformer.encoder.layers.27.self_attention"
|
||||||
lower_bound = 4e-3
|
lower_bound = 8e-3
|
||||||
self.run_optimize_gpu_model(Name, Model, Tokenizer, model_path, self_attn, layer_norm, lower_bound)
|
self.run_optimize_gpu_model(Name, Model, Tokenizer, model_path, self_attn, layer_norm, lower_bound)
|
||||||
|
|
||||||
def Mistral_gpu_model(self, Name, Model, Tokenizer, model_path):
|
def Mistral_gpu_model(self, Name, Model, Tokenizer, model_path):
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue