Refine glm1-4 sdp (#11276)

* chatglm

* update

* update

* change chatglm

* update sdpa

* update

* fix style

* fix

* fix glm

* update glm2-32k

* update glm2-32k

* fix cpu

* update

* change lower_bound
This commit is contained in:
Xin Qiu 2024-06-12 17:11:56 +08:00 committed by GitHub
parent cffb932f05
commit 592f7aa61e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 144 additions and 334 deletions

View file

@ -23,6 +23,7 @@ import torch.utils.checkpoint
import torch.nn.functional as F import torch.nn.functional as F
from typing import Optional, Tuple from typing import Optional, Tuple
from ipex_llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache from ipex_llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
from ipex_llm.transformers.models.chatglm2 import glm_sdpa
def rotate_half(x): def rotate_half(x):
@ -103,8 +104,6 @@ def attention_fn(
else: else:
present = None present = None
pytorch_major_version = int(torch.__version__.split('.')[0])
if pytorch_major_version >= 2:
query_layer = query_layer.permute(1, 2, 0, 3) query_layer = query_layer.permute(1, 2, 0, 3)
if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]: if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]:
@ -123,7 +122,7 @@ def attention_fn(
attention_mask, attention_mask,
is_causal=False) is_causal=False)
else: else:
context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, context_layer = glm_sdpa(query_layer,
key_layer, key_layer,
value_layer, value_layer,
attention_mask, attention_mask,
@ -163,97 +162,6 @@ def attention_fn(
context_layer = context_layer.reshape(*new_context_layer_shape) context_layer = context_layer.reshape(*new_context_layer_shape)
attention_probs = None attention_probs = None
else:
query_key_layer_scaling_coeff = float(layer_id + 1)
if scaling_attention_score:
query_layer = query_layer / (math.sqrt(hidden_size) * query_key_layer_scaling_coeff)
# ===================================
# Raw attention scores. [b, np, s, s]
# ===================================
# [b, np, sq, sk]
output_size = (query_layer.size(1), query_layer.size(2),
query_layer.size(0), key_layer.size(2))
# [sq, b, np, hn] -> [sq, b * np, hn]
query_layer = query_layer.view(output_size[2], output_size[0] * output_size[1], -1)
# [sk, b, np, hn] -> [sk, b * np, hn]
key_layer = key_layer.view(output_size[0] * output_size[1], output_size[3], -1)
matmul_input_buffer = torch.empty(
output_size[0] * output_size[1],
output_size[2], output_size[3], dtype=query_layer.dtype,
device=query_layer.device
)
matmul_result = torch.empty(
output_size[0] * output_size[1],
output_size[2], output_size[3], dtype=query_layer.dtype,
device=query_layer.device
)
torch.baddbmm(
matmul_input_buffer,
query_layer.transpose(0, 1), # [b * np, sq, hn]
key_layer.transpose(1, 2), # [b * np, hn, sk]
beta=0.0,
alpha=1.0,
out=matmul_result)
# change view to [b, np, sq, sk]
attention_scores = matmul_result.view(*output_size)
if self.scale_mask_softmax:
self.scale_mask_softmax.scale = query_key_layer_scaling_coeff
attention_probs = self.scale_mask_softmax(attention_scores,
attention_mask.contiguous())
else:
if not (attention_mask == 0).all():
# if auto-regressive, skip
attention_scores.masked_fill_(attention_mask, -10000.0)
dtype = attention_scores.dtype
attention_scores = attention_scores.float()
attention_scores = attention_scores * query_key_layer_scaling_coeff
attention_probs = F.softmax(attention_scores, dim=-1)
attention_probs = attention_probs.type(dtype)
# =========================
# Context layer. [sq, b, hp]
# =========================
# value_layer -> context layer.
# [sk, b, np, hn] --> [b, np, sq, hn]
# context layer shape: [b, np, sq, hn]
output_size = (value_layer.size(0), value_layer.size(1),
query_layer.size(0), value_layer.size(3))
# change view [sk, b * np, hn]
value_layer = value_layer.view(output_size[0] * output_size[1], value_layer.size(2), -1)
# change view [b * np, sq, sk]
attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1)
# matmul: [b * np, sq, hn]
context_layer = torch.empty(
output_size[0] * output_size[1],
output_size[2], value_layer.size(-1), dtype=value_layer.dtype,
device=query_layer.device)
torch.bmm(attention_probs, value_layer, out=context_layer)
# change view [b, np, sq, hn]
context_layer = context_layer.view(*output_size)
# [b, np, sq, hn] --> [sq, b, np, hn]
context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
# [sq, b, np, hn] --> [sq, b, hp]
new_context_layer_shape = context_layer.size()[:-2] + (hidden_size_per_partition,)
context_layer = context_layer.view(*new_context_layer_shape)
outputs = (context_layer, present, attention_probs) outputs = (context_layer, present, attention_probs)
return outputs return outputs

View file

@ -24,7 +24,7 @@ import torch.nn.functional as F
from transformers.modeling_outputs import BaseModelOutputWithPast from transformers.modeling_outputs import BaseModelOutputWithPast
from ipex_llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache from ipex_llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
from ipex_llm.transformers.models.utils import init_fp8_kv_cache, append_fp8_kv_cache, \ from ipex_llm.transformers.models.utils import init_fp8_kv_cache, append_fp8_kv_cache, \
restore_fp8_kv_cache, use_quantize_kv_cache restore_fp8_kv_cache, use_quantize_kv_cache, use_flash_attention
from ipex_llm.transformers.models.utils import use_sdp from ipex_llm.transformers.models.utils import use_sdp
@ -60,6 +60,48 @@ def split_tensor_along_last_dim(
return tensor_list return tensor_list
def glm_sdpa(query, key, value, attention_mask=None, is_causal=False):
if use_flash_attention(query, key, attention_mask) or query.device.type == 'cpu':
context_layer = F.scaled_dot_product_attention(query.to(key.dtype),
key,
value,
attention_mask,
is_causal=is_causal).to(key.dtype)
else:
# attention_mask is not None only when past_key_value is not None and q_len > 1
if attention_mask is not None:
attn_bias = torch.zeros(attention_mask.shape, dtype=query.dtype,
device=query.device)
attention_mask = ~attention_mask
if attention_mask.dtype == torch.bool:
attn_bias.masked_fill_(attention_mask.logical_not(), float("-inf"))
else:
attn_bias += attention_mask
elif is_causal:
L, S = query.size(-2), key.size(-2)
attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
temp_mask = torch.ones(L, S, dtype=torch.bool, device=query.device).tril(diagonal=0)
attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
attn_bias.to(key.dtype)
else:
attn_bias = None
if use_sdp(query.shape[2], key.shape[2],
query.shape[-1], query):
import xe_addons
attn_output = xe_addons.sdp(query, key, value, attn_bias)
context_layer = attn_output.view(query.shape)
else:
head_dim = query.size(-1)
attn = torch.matmul(query.to(key.dtype),
key.transpose(2, 3)) / math.sqrt(head_dim)
if attn_bias is not None:
attn += attn_bias
attn = F.softmax(attn, dim=-1,
dtype=torch.float32).to(value.dtype)
context_layer = torch.matmul(attn, value)
return context_layer
@torch.jit.script @torch.jit.script
def apply_rotary_pos_emb_chatglm(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor: def apply_rotary_pos_emb_chatglm(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:
# x: [sq, b, np, hn] # x: [sq, b, np, hn]
@ -271,19 +313,11 @@ def chatglm2_quantized_attention_forward_8eb45c(
value_split = torch.split(value, block_size, dim=1) value_split = torch.split(value, block_size, dim=1)
results = [] results = []
for q, k, v in zip(query_split, key_split, value_split): for q, k, v in zip(query_split, key_split, value_split):
if attention_mask is None: result = glm_sdpa(q, k, v, is_causal=True)
result = F.scaled_dot_product_attention(q, k, v, is_causal=True)
else:
result = F.scaled_dot_product_attention(q, k, v, attention_mask)
results.append(result) results.append(result)
context_layer = torch.cat(results, dim=1) context_layer = torch.cat(results, dim=1)
else: else:
if attention_mask is None: context_layer = glm_sdpa(query_layer, key, value, is_causal=True)
context_layer = F.scaled_dot_product_attention(query_layer, key,
value, is_causal=True)
else:
context_layer = F.scaled_dot_product_attention(query_layer, key,
value, attention_mask)
context_layer = context_layer.to(query_layer.dtype) context_layer = context_layer.to(query_layer.dtype)
if use_cache: if use_cache:
@ -535,145 +569,34 @@ def chatglm2_attention_forward_8eb45c(
def core_attn_forward_8eb45c(query_layer, key_layer, value_layer, attention_mask): def core_attn_forward_8eb45c(query_layer, key_layer, value_layer, attention_mask):
pytorch_major_version = int(torch.__version__.split('.')[0])
if pytorch_major_version >= 2:
query_layer = query_layer.permute(1, 2, 0, 3) query_layer = query_layer.permute(1, 2, 0, 3)
L, S = query_layer.shape[2], key_layer.shape[2] L, S = query_layer.shape[2], key_layer.shape[2]
if attention_mask is None and L == S:
batch_size, n_head, seq_len, head_dim = query_layer.shape batch_size, n_head, seq_len, head_dim = query_layer.shape
if attention_mask is None and L == S:
if should_split_qkv_tensor(query_layer, batch_size, n_head, seq_len): if should_split_qkv_tensor(query_layer, batch_size, n_head, seq_len):
# split second dim to block size = 8 # split second dim to block size = 8
block_size = 8 block_size = 8
query_split = torch.split(query_layer.to(key_layer.dtype), block_size, dim=1) query_layer = query_layer.to(key_layer.dtype)
query_split = torch.split(query_layer, block_size, dim=1)
key_split = torch.split(key_layer, block_size, dim=1) key_split = torch.split(key_layer, block_size, dim=1)
value_split = torch.split(value_layer, block_size, dim=1) value_split = torch.split(value_layer, block_size, dim=1)
results = [] results = []
for q, k, v in zip(query_split, key_split, value_split): for q, k, v in zip(query_split, key_split, value_split):
result = F.scaled_dot_product_attention(q, k, v, is_causal=True).to(k.dtype) result = glm_sdpa(q, k, v, is_causal=True)
results.append(result) results.append(result)
context_layer = torch.cat(results, dim=1) context_layer = torch.cat(results, dim=1)
else: else:
context_layer = F.scaled_dot_product_attention(query_layer.to(key_layer.dtype), context_layer = glm_sdpa(query_layer,
key_layer, key_layer,
value_layer, value_layer,
is_causal=True).to(key_layer.dtype) is_causal=True)
else: else:
# attention_mask is not None only when past_key_value is not None and q_len > 1 context_layer = glm_sdpa(query_layer,
if attention_mask is not None: key_layer,
attn_bias = torch.zeros(attention_mask.shape, dtype=query_layer.dtype, value_layer,
device=query_layer.device) attention_mask)
attention_mask = ~attention_mask
if attention_mask.dtype == torch.bool:
attn_bias.masked_fill_(attention_mask.logical_not(), float("-inf"))
else:
attn_bias += attention_mask
else:
attn_bias = None
if use_sdp(query_layer.shape[2], key_layer.shape[2],
query_layer.shape[-1], query_layer):
import xe_addons
attn_output = xe_addons.sdp(query_layer, key_layer, value_layer, attn_bias)
context_layer = attn_output.view(query_layer.shape)
else:
head_dim = query_layer.size(-1)
attn = torch.matmul(query_layer.to(key_layer.dtype),
key_layer.transpose(2, 3)) / math.sqrt(head_dim)
if attn_bias is not None:
attn += attn_bias
attn = F.softmax(attn, dim=-1,
dtype=torch.float32).to(value_layer.dtype)
context_layer = torch.matmul(attn, value_layer)
context_layer = context_layer.permute(2, 0, 1, 3) context_layer = context_layer.permute(2, 0, 1, 3)
new_context_layer_shape = context_layer.size()[:-2] + (-1,) new_context_layer_shape = context_layer.size()[:-2] + (-1,)
context_layer = context_layer.reshape(*new_context_layer_shape) context_layer = context_layer.reshape(*new_context_layer_shape)
else:
# Raw attention scores
# [b, np, sq, sk]
output_size = (query_layer.size(1), query_layer.size(2),
query_layer.size(0), key_layer.size(2))
# [sq, b, np, hn] -> [sq, b * np, hn]
query_layer = query_layer.view(output_size[2], output_size[0] * output_size[1], -1)
# [sk, b, np, hn] -> [sk, b * np, hn]
key_layer = key_layer.view(output_size[0] * output_size[1], output_size[3], -1)
# preallocting input tensor: [b * np, sq, sk]
matmul_input_buffer = torch.empty(
output_size[0] * output_size[1],
output_size[2], output_size[3], dtype=query_layer.dtype,
device=query_layer.device
)
matmul_result = torch.empty(
output_size[0] * output_size[1],
output_size[2], output_size[3], dtype=query_layer.dtype,
device=query_layer.device
)
# Raw attention scores. [b * np, sq, sk]
torch.baddbmm(
matmul_input_buffer,
query_layer.transpose(0, 1), # [b * np, sq, hn]
key_layer.transpose(1, 2), # [b * np, hn, sk]
beta=0.0,
alpha=(1.0 / self.norm_factor),
out=matmul_result
)
# change view to [b, np, sq, sk]
attention_scores = matmul_result.view(*output_size)
# ===========================
# Attention probs and dropout
# ===========================
# attention scores and attention mask [b, np, sq, sk]
if self.attention_softmax_in_fp32:
attention_scores = attention_scores.float()
if self.coeff is not None:
attention_scores = attention_scores * self.coeff
if attention_mask is None and attention_scores.shape[2] == attention_scores.shape[3]:
attention_mask = torch.ones(output_size[0], 1, output_size[2], output_size[3],
device=attention_scores.device, dtype=torch.bool)
attention_mask.tril_()
attention_mask = ~attention_mask
if attention_mask is not None:
attention_scores = attention_scores.masked_fill(attention_mask, float("-inf"))
attention_probs = F.softmax(attention_scores, dim=-1)
attention_probs = attention_probs.type_as(value_layer)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
attention_probs = self.attention_dropout(attention_probs)
# =========================
# Context layer. [sq, b, hp]
# =========================
# value_layer -> context layer.
# [sk, b, np, hn] --> [b, np, sq, hn]
# context layer shape: [b, np, sq, hn]
output_size = (value_layer.size(0), value_layer.size(1),
query_layer.size(0), value_layer.size(3))
# change view [sk, b * np, hn]
value_layer = value_layer.view(output_size[0] * output_size[1], value_layer.size(2), -1)
# change view [b * np, sq, sk]
attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1)
# matmul: [b * np, sq, hn]
context_layer = torch.empty(
output_size[0] * output_size[1],
output_size[2], value_layer.size(-1), dtype=value_layer.dtype,
device=value_layer.device,
)
torch.bmm(attention_probs, value_layer, out=context_layer)
# change view [b, np, sq, hn]
context_layer = context_layer.view(*output_size)
# [b, np, sq, hn] --> [sq, b, np, hn]
context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
# [sq, b, np, hn] --> [sq, b, hp]
new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
context_layer = context_layer.view(*new_context_layer_shape)
return context_layer return context_layer

View file

@ -21,6 +21,7 @@ import torch
from typing import Optional, Tuple, Union, List, Callable, Dict, Any from typing import Optional, Tuple, Union, List, Callable, Dict, Any
import torch.nn.functional as F import torch.nn.functional as F
from ipex_llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache from ipex_llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
from ipex_llm.transformers.models.chatglm2 import core_attn_forward_8eb45c
import os import os
@ -179,10 +180,9 @@ def chatglm2_32k_attention_forward(
key_layer = key_cache key_layer = key_cache
value_layer = value_cache value_layer = value_cache
if use_cache:
key_layer = key_layer.permute(2, 0, 1, 3) key_layer = key_layer.permute(2, 0, 1, 3)
value_layer = value_layer.permute(2, 0, 1, 3) value_layer = value_layer.permute(2, 0, 1, 3)
if use_cache:
if kv_cache is None: if kv_cache is None:
kv_cache = torch.cat((key_layer.unsqueeze(0).unsqueeze(0), kv_cache = torch.cat((key_layer.unsqueeze(0).unsqueeze(0),
value_layer.unsqueeze(0).unsqueeze(0)), dim=1) value_layer.unsqueeze(0).unsqueeze(0)), dim=1)
@ -195,7 +195,7 @@ def chatglm2_32k_attention_forward(
# core attention computation # core attention computation
# ================================== # ==================================
context_layer = self.core_attention(query_layer, key_layer, value_layer, attention_mask) context_layer = core_attn_forward_8eb45c(query_layer, key_layer, value_layer, attention_mask)
# ================= # =================
# Output. [sq, b, h] # Output. [sq, b, h]

View file

@ -23,7 +23,7 @@ import torch.nn.functional as F
from ipex_llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache from ipex_llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
from ipex_llm.transformers.models.utils import use_quantize_kv_cache, apply_ipex_rotate_every_two from ipex_llm.transformers.models.utils import use_quantize_kv_cache, apply_ipex_rotate_every_two
from ipex_llm.transformers.models.utils import use_sdp from ipex_llm.transformers.models.utils import use_sdp
from ipex_llm.transformers.models.chatglm2 import should_split_qkv_tensor from ipex_llm.transformers.models.chatglm2 import should_split_qkv_tensor, glm_sdpa
from ipex_llm.transformers.models.chatglm2 import split_tensor_along_last_dim from ipex_llm.transformers.models.chatglm2 import split_tensor_along_last_dim
from transformers.modeling_outputs import BaseModelOutputWithPast from transformers.modeling_outputs import BaseModelOutputWithPast
@ -300,51 +300,30 @@ def chatglm4_attention_forward(
def core_attn_forward(query_layer, key_layer, value_layer, attention_mask): def core_attn_forward(query_layer, key_layer, value_layer, attention_mask):
L, S = query_layer.shape[2], key_layer.shape[2] L, S = query_layer.shape[2], key_layer.shape[2]
if attention_mask is None and L == S:
batch_size, n_head, seq_len, head_dim = query_layer.shape batch_size, n_head, seq_len, head_dim = query_layer.shape
if attention_mask is None and L == S:
if should_split_qkv_tensor(query_layer, batch_size, n_head, seq_len): if should_split_qkv_tensor(query_layer, batch_size, n_head, seq_len):
# split second dim to block size = 8 # split second dim to block size = 8
block_size = 8 block_size = 8
query_split = torch.split(query_layer.to(key_layer.dtype), block_size, dim=1) query_layer = query_layer.to(key_layer.dtype)
query_split = torch.split(query_layer, block_size, dim=1)
key_split = torch.split(key_layer, block_size, dim=1) key_split = torch.split(key_layer, block_size, dim=1)
value_split = torch.split(value_layer, block_size, dim=1) value_split = torch.split(value_layer, block_size, dim=1)
results = [] results = []
for q, k, v in zip(query_split, key_split, value_split): for q, k, v in zip(query_split, key_split, value_split):
result = F.scaled_dot_product_attention(q, k, v, is_causal=True).to(k.dtype) result = glm_sdpa(q, k, v, is_causal=True)
results.append(result) results.append(result)
context_layer = torch.cat(results, dim=1) context_layer = torch.cat(results, dim=1)
else: else:
context_layer = F.scaled_dot_product_attention(query_layer.to(key_layer.dtype), context_layer = glm_sdpa(query_layer,
key_layer, key_layer,
value_layer, value_layer,
is_causal=True).to(key_layer.dtype) is_causal=True)
else: else:
# attention_mask is not None only when past_key_value is not None and q_len > 1 context_layer = glm_sdpa(query_layer,
if attention_mask is not None: key_layer,
attn_bias = torch.zeros(attention_mask.shape, dtype=query_layer.dtype, value_layer,
device=query_layer.device) attention_mask)
attention_mask = ~attention_mask
if attention_mask.dtype == torch.bool:
attn_bias.masked_fill_(attention_mask.logical_not(), float("-inf"))
else:
attn_bias += attention_mask
else:
attn_bias = None
if use_sdp(query_layer.shape[2], key_layer.shape[2],
query_layer.shape[-1], query_layer):
import xe_addons
attn_output = xe_addons.sdp(query_layer, key_layer, value_layer, attn_bias)
context_layer = attn_output.view(query_layer.shape)
else:
head_dim = query_layer.size(-1)
attn = torch.matmul(query_layer.to(key_layer.dtype),
key_layer.transpose(2, 3)) / math.sqrt(head_dim)
if attn_bias is not None:
attn += attn_bias
attn = F.softmax(attn, dim=-1,
dtype=torch.float32).to(value_layer.dtype)
context_layer = torch.matmul(attn, value_layer)
context_layer = context_layer.transpose(1, 2).contiguous() context_layer = context_layer.transpose(1, 2).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (-1,) new_context_layer_shape = context_layer.size()[:-2] + (-1,)
context_layer = context_layer.reshape(*new_context_layer_shape) context_layer = context_layer.reshape(*new_context_layer_shape)

View file

@ -171,7 +171,7 @@ class Test_Optimize_Gpu_Model:
# currently only need to compare the output of one self-attention layer. # currently only need to compare the output of one self-attention layer.
layer_norm = "transformer.encoder.layers.27.input_layernorm" layer_norm = "transformer.encoder.layers.27.input_layernorm"
self_attn = "transformer.encoder.layers.27.self_attention" self_attn = "transformer.encoder.layers.27.self_attention"
lower_bound = 4e-3 lower_bound = 8e-3
self.run_optimize_gpu_model(Name, Model, Tokenizer, model_path, self_attn, layer_norm, lower_bound) self.run_optimize_gpu_model(Name, Model, Tokenizer, model_path, self_attn, layer_norm, lower_bound)
def Mistral_gpu_model(self, Name, Model, Tokenizer, model_path): def Mistral_gpu_model(self, Name, Model, Tokenizer, model_path):