Add quantize kv cache support for chatglm2/3 (#9996)
parent 86055d76d5 · commit bf65548d29
2 changed files with 158 additions and 10 deletions
@@ -637,17 +637,12 @@ def _optimize_post(model, lightweight_bmm=False):
             # chatglm2-6b
             modeling_module_name = model.__class__.__module__
             module = importlib.import_module(modeling_module_name)
-            from bigdl.llm.transformers.models.chatglm2 import chatglm2_attention_forward_8eb45c
-            from bigdl.llm.transformers.models.chatglm2 import core_attn_forward_8eb45c
+            from bigdl.llm.transformers.models.chatglm2 import chatglm2_attention_forward
             from bigdl.llm.transformers.models.chatglm2 import chatglm_rms_norm_forward
             from bigdl.llm.transformers.models.chatglm2 import chatglm2_model_forward
             convert_forward(model,
                             module.SelfAttention,
-                            chatglm2_attention_forward_8eb45c
-                            )
-            convert_forward(model,
-                            module.CoreAttention,
-                            core_attn_forward_8eb45c)
+                            chatglm2_attention_forward)
             convert_forward(model,
                             module.ChatGLMModel,
                             chatglm2_model_forward)
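This first hunk changes which forward implementations get registered for ChatGLM2/3 in `_optimize_post`: `module.SelfAttention` is now converted to the new `chatglm2_attention_forward` dispatcher, and the separate `module.CoreAttention` conversion is dropped because `core_attn_forward_8eb45c` is called directly from the attention forward (see the later hunks). For orientation, here is a minimal sketch of the forward-replacement pattern that `convert_forward` implements, assuming the real helper simply rebinds `forward` on matching submodules (the `patch_forward` name below is hypothetical):

import torch.nn as nn

def patch_forward(model: nn.Module, target_cls: type, new_forward) -> None:
    # Recursively walk the module tree and rebind `forward` on every instance
    # of `target_cls` so it runs the optimized implementation instead.
    for child in model.children():
        if isinstance(child, target_cls):
            child.forward = new_forward.__get__(child, child.__class__)
        patch_forward(child, target_cls, new_forward)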
@@ -23,6 +23,8 @@ from typing import Optional, Tuple, List
 import torch.nn.functional as F
 from transformers.modeling_outputs import BaseModelOutputWithPast
 from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
+from bigdl.llm.transformers.models.utils import init_fp8_kv_cache, extend_fp8_kv_cache, \
+    append_fp8_kv_cache, restore_fp8_kv_cache, quantize_kv_cache
 from bigdl.llm.transformers.models.utils import use_flash_attention
 from bigdl.llm.transformers.models.llama import get_ipex_version
 
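The newly imported helpers (`init_fp8_kv_cache`, `append_fp8_kv_cache`, `extend_fp8_kv_cache`, `restore_fp8_kv_cache`, `quantize_kv_cache`) manage a KV cache held in 8-bit floating point instead of fp16, roughly halving its memory footprint at the cost of some precision. A conceptual round-trip sketch in plain PyTorch (an illustration only, assuming PyTorch >= 2.1 for the fp8 dtype; the real helpers also pre-allocate and grow cache buffers and use custom XPU kernels):

import torch

# [bs, n_kv_head, seq_len, head_dim] key slice, as cached during decoding
k = torch.randn(1, 2, 16, 64, dtype=torch.float16)

k_fp8 = k.to(torch.float8_e5m2)        # 1 byte per element instead of 2
k_restored = k_fp8.to(torch.float16)   # dequantize before the attention matmul

print(k_fp8.element_size())                 # 1
print((k - k_restored).abs().max().item())  # small quantization error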
@@ -78,6 +80,21 @@ def apply_rotary_pos_emb_chatglm(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:
     return torch.cat((x_out2, x_pass), dim=-1)
 
 
+def repeat_kv(key: torch.Tensor, value: torch.Tensor, n_head: int) -> (torch.Tensor, torch.Tensor):
+    # key, value's shape: [bs, n_kv_head, seq_len, head_dim] -> [bs, n_head, seq_len, head_dim]
+    batch_size, n_kv_head, seq_len, head_dim = key.shape
+
+    key = key.unsqueeze(2)
+    key = key.expand(-1, -1, n_head // n_kv_head, -1, -1)
+    key = key.contiguous().view(batch_size, n_head, seq_len, head_dim)
+
+    value = value.unsqueeze(2)
+    value = value.expand(-1, -1, n_head // n_kv_head, -1, -1)
+    value = value.contiguous().view(batch_size, n_head, seq_len, head_dim)
+
+    return key, value
+
+
 def chatglm_rms_norm_forward(self, hidden_states):
     if hidden_states.device.type == "xpu" and not (self.training and hidden_states.requires_grad):
         import linear_q4_0
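`repeat_kv` expands the grouped (multi-query) KV heads so each of the `n_head` query heads gets a matching key/value head; the `unsqueeze`/`expand`/`view` sequence repeats each KV head `n_head // n_kv_head` times along the head dimension. A quick self-check against `torch.repeat_interleave`, using the `repeat_kv` defined above and hypothetical shapes:

import torch

bs, n_kv_head, n_head, seq_len, head_dim = 1, 2, 8, 5, 4   # hypothetical sizes
key = torch.randn(bs, n_kv_head, seq_len, head_dim)
value = torch.randn(bs, n_kv_head, seq_len, head_dim)

k2, v2 = repeat_kv(key, value, n_head)
assert k2.shape == (bs, n_head, seq_len, head_dim)
# Each KV head is repeated n_head // n_kv_head times, in consecutive blocks.
assert torch.equal(k2, key.repeat_interleave(n_head // n_kv_head, dim=1))
assert torch.equal(v2, value.repeat_interleave(n_head // n_kv_head, dim=1))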
@@ -169,6 +186,142 @@ def chatglm2_model_forward(
     )
 
 
+def chatglm2_attention_forward(
+    self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True
+):
+    if quantize_kv_cache(self.query_key_value, hidden_states):
+        forward_function = chatglm2_quantized_attention_forward_8eb45c
+    else:
+        forward_function = chatglm2_attention_forward_8eb45c
+    return forward_function(
+        self=self,
+        hidden_states=hidden_states,
+        attention_mask=attention_mask,
+        rotary_pos_emb=rotary_pos_emb,
+        kv_cache=kv_cache,
+        use_cache=use_cache
+    )
+
+
+def chatglm2_quantized_attention_forward_8eb45c(
+    self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True
+):
+    # hidden_states: [seq_len, bs, head_dim]
+    mixed_x_layer = self.query_key_value(hidden_states)
+
+    n_head = self.num_attention_heads_per_partition
+    n_kv_head = self.num_multi_query_groups_per_partition if self.multi_query_attention else n_head
+    head_dim = self.hidden_size_per_attention_head
+
+    query_layer, key_layer, value_layer = mixed_x_layer.split(
+        [n_head * head_dim, n_kv_head * head_dim, n_kv_head * head_dim],
+        dim=-1,
+    )
+    query_layer = query_layer.view(query_layer.shape[:-1] + (n_head, head_dim))
+    key_layer = key_layer.view(key_layer.shape[:-1] + (n_kv_head, head_dim))
+    value_layer = value_layer.view(value_layer.shape[:-1] + (n_kv_head, head_dim))
+    # query, key, value's shape: [seq_len, bs, n_head/n_kv_head, head_dim]
+
+    # apply relative positional encoding (rotary embedding)
+    if rotary_pos_emb is not None:
+        if len(rotary_pos_emb) == 2 and isinstance(rotary_pos_emb, tuple):
+            # use_fuse_rope, see chatglm2_model_forward
+            cos, sin = rotary_pos_emb
+            rot_dim = cos.shape[-1]
+            query_layer = query_layer.transpose(0, 1)
+            key_layer = key_layer.transpose(0, 1)
+            query_layer_cur = query_layer[..., :rot_dim]
+            key_layer_cur = key_layer[..., :rot_dim]
+            # ipex's apply_rotary_embedding can change the origin storage, so query_layer will get
+            # the result directly.
+            torch.ops.torch_ipex.apply_rotary_embedding(query_layer_cur, sin, cos, query_layer_cur)
+            torch.ops.torch_ipex.apply_rotary_embedding(key_layer_cur, sin, cos, key_layer_cur)
+            query_layer = query_layer.transpose(0, 1)
+            key_layer = key_layer.transpose(0, 1)
+        else:
+            query_layer = apply_rotary_pos_emb_chatglm(query_layer, rotary_pos_emb)
+            key_layer = apply_rotary_pos_emb_chatglm(key_layer, rotary_pos_emb)
+
+    query_layer = query_layer.permute(1, 2, 0, 3)
+    key_layer = key_layer.permute(1, 2, 0, 3)
+    value_layer = value_layer.permute(1, 2, 0, 3)
+    # query, key, value's shape: [bs, n_head/n_kv_head, seq_len, head_dim]
+    batch_size, _, seq_len, _ = query_layer.shape
+
+    if kv_cache is None:
+        # first token
+        if self.multi_query_attention:
+            key, value = repeat_kv(key_layer, value_layer, n_head)
+        else:
+            key, value = key_layer, value_layer
+
+        if attention_mask is None:
+            context_layer = F.scaled_dot_product_attention(query_layer, key, value, is_causal=True)
+        else:
+            context_layer = F.scaled_dot_product_attention(query_layer, key, value, attention_mask)
+
+        if use_cache:
+            k_cache, v_cache = init_fp8_kv_cache(batch_size,
+                                                 n_kv_head,
+                                                 head_dim,
+                                                 0,
+                                                 seq_len + KV_CACHE_ALLOC_MIN_LENGTH,
+                                                 query_layer.device)
+            k_cache, v_cache = append_fp8_kv_cache(k_cache, v_cache, key_layer, value_layer)
+    else:
+        k_cache, v_cache = kv_cache
+        k_cache = k_cache.permute(1, 2, 0, 3)
+        v_cache = v_cache.permute(1, 2, 0, 3)
+        # k_cache, v_cache's shape: [bs, n_kv_head, seq_len, head_dim]
+
+        kv_seq_len = seq_len + k_cache.size(2)
+        if k_cache.stride(1) < kv_seq_len * k_cache.size(3):
+            k_cache, v_cache = extend_fp8_kv_cache(
+                k_cache, v_cache,
+                kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH,
+                device=query_layer.device,
+            )
+            if query_layer.device.type == 'xpu':
+                torch.xpu.empty_cache()
+        k_cache, v_cache = append_fp8_kv_cache(k_cache, v_cache, key_layer, value_layer)
+
+        if seq_len != 1:
+            key, value = restore_fp8_kv_cache(k_cache, v_cache, query_layer.dtype)
+            key, value = repeat_kv(key, value, n_head)
+            attn = torch.matmul(query_layer, key.transpose(2, 3)) / math.sqrt(head_dim)
+        else:
+            key, value = k_cache, v_cache
+            import linear_q4_0
+            attn = linear_q4_0.query_key_fp8_matmul(query_layer, key) / math.sqrt(head_dim)
+        if attention_mask is not None:
+            attention_mask = ~attention_mask
+            attn_bias = torch.zeros(attention_mask.shape, dtype=query_layer.dtype,
+                                    device=query_layer.device)
+            if attention_mask.dtype == torch.bool:
+                attn_bias.masked_fill_(attention_mask.logical_not(), float("-inf"))
+            else:
+                attn_bias += attention_mask
+            attn += attn_bias
+        attn = F.softmax(attn, dim=-1, dtype=torch.float32)
+        if seq_len != 1:
+            context_layer = torch.matmul(attn.to(value.dtype), value)
+        else:
+            import linear_q4_0
+            context_layer = linear_q4_0.attn_value_fp8_matmul(attn, value.transpose(-1, -2))
+
+    # context_layer's shape: [bs, n_head, seq_len, head_dim] -> [seq_len, bs, n_head * head_dim]
+    context_layer = context_layer.permute(2, 0, 1, 3).contiguous().view(seq_len, batch_size, -1)
+
+    if use_cache:
+        kv_cache = (k_cache.permute(2, 0, 1, 3), v_cache.permute(2, 0, 1, 3))
+    else:
+        kv_cache = None
+
+    output = self.dense(context_layer)
+
+    return output, kv_cache
+
+
 def chatglm2_attention_forward_8eb45c(
     self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True
 ):
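In the quantized path above, cached steps cannot go through `F.scaled_dot_product_attention` on the fp8 cache, so the attention is computed manually: raw scores, an additive bias built from the boolean attention mask, a float32 softmax, then the value matmul. A standalone sketch of that manual computation with already-dequantized, already-expanded K/V and hypothetical shapes, checked against SDPA:

import math
import torch
import torch.nn.functional as F

q = torch.randn(1, 8, 4, 64)      # [bs, n_head, q_len, head_dim]
k = torch.randn(1, 8, 12, 64)     # [bs, n_head, kv_len, head_dim]
v = torch.randn(1, 8, 12, 64)

# Boolean mask, True = position is blocked (here: keys too far in the "future").
idx_q = torch.arange(4).view(1, 1, 4, 1)
idx_k = torch.arange(12).view(1, 1, 1, 12)
mask = idx_k > idx_q + 8

attn = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(q.size(-1))
attn_bias = torch.zeros_like(attn)
attn_bias.masked_fill_(mask, float("-inf"))      # blocked positions get -inf
attn = F.softmax(attn + attn_bias, dim=-1, dtype=torch.float32)
out = torch.matmul(attn.to(v.dtype), v)

# SDPA's boolean attn_mask uses True = "may attend", hence the inversion.
ref = F.scaled_dot_product_attention(q, k, v, attn_mask=~mask)
print(torch.allclose(out, ref, atol=1e-5))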
@@ -354,7 +507,7 @@ def chatglm2_attention_forward_8eb45c(
                                                    save_length,
                                                    self.hidden_size_per_attention_head))
 
-    context_layer = self.core_attention(query_layer, key_layer, value_layer, attention_mask)
+    context_layer = core_attn_forward_8eb45c(query_layer, key_layer, value_layer, attention_mask)
 
     # =================
     # Output. [sq, b, h]
@@ -365,7 +518,7 @@ def chatglm2_attention_forward_8eb45c(
     return output, (cache_key_layer.permute(2, 0, 1, 3), cache_value_layer.permute(2, 0, 1, 3))
 
 
-def core_attn_forward_8eb45c(self, query_layer, key_layer, value_layer, attention_mask):
+def core_attn_forward_8eb45c(query_layer, key_layer, value_layer, attention_mask):
     pytorch_major_version = int(torch.__version__.split('.')[0])
     if pytorch_major_version >= 2:
         query_layer = query_layer.permute(1, 2, 0, 3)
@@ -392,7 +545,7 @@ def core_attn_forward_8eb45c(self, query_layer, key_layer, value_layer, attention_mask):
                                  dtype=torch.float32).to(value_layer.dtype)
         context_layer = torch.matmul(attn, value_layer)
         context_layer = context_layer.permute(2, 0, 1, 3)
-        new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
+        new_context_layer_shape = context_layer.size()[:-2] + (-1,)
         context_layer = context_layer.reshape(*new_context_layer_shape)
     else:
         # Raw attention scores
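For context, a typical way this code path gets exercised (a usage sketch, not part of this commit): load ChatGLM2/3 through bigdl-llm's low-bit `AutoModel` and run generation on an Intel GPU; whether the fp8 KV cache is used is decided at runtime by `quantize_kv_cache()` inside the patched attention forward, not by an argument here.

import torch
import intel_extension_for_pytorch as ipex  # noqa: F401  (needed for XPU tensors)
from transformers import AutoTokenizer
from bigdl.llm.transformers import AutoModel

model_path = "THUDM/chatglm2-6b"
model = AutoModel.from_pretrained(model_path, load_in_4bit=True,
                                  trust_remote_code=True).to("xpu")
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

with torch.inference_mode():
    inputs = tokenizer("What is AI?", return_tensors="pt").to("xpu")
    output = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(output[0], skip_special_tokens=True))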