Add quantize kv cache support for chatglm2/3 (#9996)

Yishuo Wang 2024-01-25 16:55:59 +08:00 committed by GitHub
parent 86055d76d5
commit bf65548d29
2 changed files with 158 additions and 10 deletions
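
Background for reviewers: the change keeps the chatglm2/3 key/value cache in an 8-bit (fp8) layout via the *_fp8_kv_cache helpers instead of the usual fp16/bf16 cache, which roughly halves KV-cache memory at long context lengths. The snippet below is a minimal plain-PyTorch sketch of the general quantize-on-append / dequantize-on-use idea, using per-token int8 absmax scaling; quantize_kv and dequantize_kv are toy helpers for illustration only, not the implementation behind init_fp8_kv_cache / restore_fp8_kv_cache.

    import torch

    def quantize_kv(x):
        # x: [bs, n_kv_head, seq_len, head_dim]; one absmax scale per (batch, head, token)
        scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-6) / 127.0
        return torch.round(x / scale).to(torch.int8), scale

    def dequantize_kv(q, scale, dtype):
        # restore an approximation of the original cache entries before the attention matmul
        return q.to(dtype) * scale.to(dtype)

    k = torch.randn(1, 2, 16, 64, dtype=torch.float16)
    k_q, k_scale = quantize_kv(k)                      # int8 payload + small fp16 scale tensor
    k_restored = dequantize_kv(k_q, k_scale, k.dtype)
    print((k - k_restored).abs().max())                # small quantization error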


@@ -637,17 +637,12 @@ def _optimize_post(model, lightweight_bmm=False):
         # chatglm2-6b
         modeling_module_name = model.__class__.__module__
         module = importlib.import_module(modeling_module_name)
-        from bigdl.llm.transformers.models.chatglm2 import chatglm2_attention_forward_8eb45c
-        from bigdl.llm.transformers.models.chatglm2 import core_attn_forward_8eb45c
+        from bigdl.llm.transformers.models.chatglm2 import chatglm2_attention_forward
         from bigdl.llm.transformers.models.chatglm2 import chatglm_rms_norm_forward
         from bigdl.llm.transformers.models.chatglm2 import chatglm2_model_forward
         convert_forward(model,
                         module.SelfAttention,
-                        chatglm2_attention_forward_8eb45c
-                        )
-        convert_forward(model,
-                        module.CoreAttention,
-                        core_attn_forward_8eb45c)
+                        chatglm2_attention_forward)
         convert_forward(model,
                         module.ChatGLMModel,
                         chatglm2_model_forward)
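
For context, convert_forward (defined elsewhere in convert.py) is bigdl-llm's hook for swapping in the optimized forward implementations above. Assuming it simply rebinds the forward of every instance of the target module class, the pattern looks roughly like this toy sketch; ToyAttention, patched_forward and convert_forward_sketch are made-up names, not library API.

    import types
    import torch

    class ToyAttention(torch.nn.Module):      # stand-in for module.SelfAttention
        def forward(self, x):
            return x

    def patched_forward(self, x):             # stand-in for chatglm2_attention_forward
        return x * 2

    def convert_forward_sketch(model, target_cls, new_forward):
        # assumed behavior: walk the model and rebind forward on matching submodules
        for submodule in model.modules():
            if isinstance(submodule, target_cls):
                submodule.forward = types.MethodType(new_forward, submodule)

    model = torch.nn.Sequential(ToyAttention())
    convert_forward_sketch(model, ToyAttention, patched_forward)
    print(model(torch.ones(2)))               # tensor([2., 2.])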


@@ -23,6 +23,8 @@ from typing import Optional, Tuple, List
 import torch.nn.functional as F
 from transformers.modeling_outputs import BaseModelOutputWithPast
 from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
+from bigdl.llm.transformers.models.utils import init_fp8_kv_cache, extend_fp8_kv_cache, \
+    append_fp8_kv_cache, restore_fp8_kv_cache, quantize_kv_cache
 from bigdl.llm.transformers.models.utils import use_flash_attention
 from bigdl.llm.transformers.models.llama import get_ipex_version
@@ -78,6 +80,21 @@ def apply_rotary_pos_emb_chatglm(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:
     return torch.cat((x_out2, x_pass), dim=-1)
 
 
+def repeat_kv(key: torch.Tensor, value: torch.Tensor, n_head: int) -> (torch.Tensor, torch.Tensor):
+    # key, value's shape: [bs, n_kv_head, seq_len, head_dim] -> [bs, n_head, seq_len, head_dim]
+    batch_size, n_kv_head, seq_len, head_dim = key.shape
+
+    key = key.unsqueeze(2)
+    key = key.expand(-1, -1, n_head // n_kv_head, -1, -1)
+    key = key.contiguous().view(batch_size, n_head, seq_len, head_dim)
+
+    value = value.unsqueeze(2)
+    value = value.expand(-1, -1, n_head // n_kv_head, -1, -1)
+    value = value.contiguous().view(batch_size, n_head, seq_len, head_dim)
+
+    return key, value
+
+
 def chatglm_rms_norm_forward(self, hidden_states):
     if hidden_states.device.type == "xpu" and not (self.training and hidden_states.requires_grad):
         import linear_q4_0
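
The new repeat_kv helper broadcasts the grouped (multi-query) KV heads up to the full query-head count before dense matmuls. A quick standalone sanity check: the unsqueeze/expand/view sequence it uses is equivalent to torch.repeat_interleave along the head dimension.

    import torch

    bs, n_kv_head, n_head, seq_len, head_dim = 1, 2, 8, 5, 16
    key = torch.randn(bs, n_kv_head, seq_len, head_dim)

    expanded = (key.unsqueeze(2)                                   # same steps as repeat_kv above
                   .expand(-1, -1, n_head // n_kv_head, -1, -1)
                   .contiguous()
                   .view(bs, n_head, seq_len, head_dim))
    reference = key.repeat_interleave(n_head // n_kv_head, dim=1)  # each kv head serves 4 query heads
    print(torch.equal(expanded, reference))                        # True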
@@ -169,6 +186,142 @@ def chatglm2_model_forward(
     )
 
 
+def chatglm2_attention_forward(
+    self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True
+):
+    if quantize_kv_cache(self.query_key_value, hidden_states):
+        forward_function = chatglm2_quantized_attention_forward_8eb45c
+    else:
+        forward_function = chatglm2_attention_forward_8eb45c
+    return forward_function(
+        self=self,
+        hidden_states=hidden_states,
+        attention_mask=attention_mask,
+        rotary_pos_emb=rotary_pos_emb,
+        kv_cache=kv_cache,
+        use_cache=use_cache
+    )
+
+
+def chatglm2_quantized_attention_forward_8eb45c(
+    self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True
+):
+    # hidden_states: [seq_len, bs, head_dim]
+    mixed_x_layer = self.query_key_value(hidden_states)
+
+    n_head = self.num_attention_heads_per_partition
+    n_kv_head = self.num_multi_query_groups_per_partition if self.multi_query_attention else n_head
+    head_dim = self.hidden_size_per_attention_head
+
+    query_layer, key_layer, value_layer = mixed_x_layer.split(
+        [n_head * head_dim, n_kv_head * head_dim, n_kv_head * head_dim],
+        dim=-1,
+    )
+    query_layer = query_layer.view(query_layer.shape[:-1] + (n_head, head_dim))
+    key_layer = key_layer.view(key_layer.shape[:-1] + (n_kv_head, head_dim))
+    value_layer = value_layer.view(value_layer.shape[:-1] + (n_kv_head, head_dim))
+    # query, key, value's shape: [seq_len, bs, n_head/n_kv_head, head_dim]
+
+    # apply relative positional encoding (rotary embedding)
+    if rotary_pos_emb is not None:
+        if len(rotary_pos_emb) == 2 and isinstance(rotary_pos_emb, tuple):
+            # use_fuse_rope, see chatglm2_model_forward
+            cos, sin = rotary_pos_emb
+            rot_dim = cos.shape[-1]
+            query_layer = query_layer.transpose(0, 1)
+            key_layer = key_layer.transpose(0, 1)
+            query_layer_cur = query_layer[..., :rot_dim]
+            key_layer_cur = key_layer[..., :rot_dim]
+            # ipex's apply_rotary_embedding can change the origin storage, so query_layer will get
+            # the result directly.
+            torch.ops.torch_ipex.apply_rotary_embedding(query_layer_cur, sin, cos, query_layer_cur)
+            torch.ops.torch_ipex.apply_rotary_embedding(key_layer_cur, sin, cos, key_layer_cur)
+            query_layer = query_layer.transpose(0, 1)
+            key_layer = key_layer.transpose(0, 1)
+        else:
+            query_layer = apply_rotary_pos_emb_chatglm(query_layer, rotary_pos_emb)
+            key_layer = apply_rotary_pos_emb_chatglm(key_layer, rotary_pos_emb)
+
+    query_layer = query_layer.permute(1, 2, 0, 3)
+    key_layer = key_layer.permute(1, 2, 0, 3)
+    value_layer = value_layer.permute(1, 2, 0, 3)
+    # query, key, value's shape: [bs, n_head/n_kv_head, seq_len, head_dim]
+    batch_size, _, seq_len, _ = query_layer.shape
+
+    if kv_cache is None:
+        # first token
+        if self.multi_query_attention:
+            key, value = repeat_kv(key_layer, value_layer, n_head)
+        else:
+            key, value = key_layer, value_layer
+
+        if attention_mask is None:
+            context_layer = F.scaled_dot_product_attention(query_layer, key, value, is_causal=True)
+        else:
+            context_layer = F.scaled_dot_product_attention(query_layer, key, value, attention_mask)
+
+        if use_cache:
+            k_cache, v_cache = init_fp8_kv_cache(batch_size,
+                                                 n_kv_head,
+                                                 head_dim,
+                                                 0,
+                                                 seq_len + KV_CACHE_ALLOC_MIN_LENGTH,
+                                                 query_layer.device)
+            k_cache, v_cache = append_fp8_kv_cache(k_cache, v_cache, key_layer, value_layer)
+    else:
+        k_cache, v_cache = kv_cache
+        k_cache = k_cache.permute(1, 2, 0, 3)
+        v_cache = v_cache.permute(1, 2, 0, 3)
+        # k_cache, v_cache's shape: [bs, n_kv_head, seq_len, head_dim]
+
+        kv_seq_len = seq_len + k_cache.size(2)
+        if k_cache.stride(1) < kv_seq_len * k_cache.size(3):
+            k_cache, v_cache = extend_fp8_kv_cache(
+                k_cache, v_cache,
+                kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH,
+                device=query_layer.device,
+            )
+            if query_layer.device.type == 'xpu':
+                torch.xpu.empty_cache()
+        k_cache, v_cache = append_fp8_kv_cache(k_cache, v_cache, key_layer, value_layer)
+
+        if seq_len != 1:
+            key, value = restore_fp8_kv_cache(k_cache, v_cache, query_layer.dtype)
+            key, value = repeat_kv(key, value, n_head)
+            attn = torch.matmul(query_layer, key.transpose(2, 3)) / math.sqrt(head_dim)
+        else:
+            key, value = k_cache, v_cache
+            import linear_q4_0
+            attn = linear_q4_0.query_key_fp8_matmul(query_layer, key) / math.sqrt(head_dim)
+        if attention_mask is not None:
+            attention_mask = ~attention_mask
+            attn_bias = torch.zeros(attention_mask.shape, dtype=query_layer.dtype,
+                                    device=query_layer.device)
+            if attention_mask.dtype == torch.bool:
+                attn_bias.masked_fill_(attention_mask.logical_not(), float("-inf"))
+            else:
+                attn_bias += attention_mask
+            attn += attn_bias
+        attn = F.softmax(attn, dim=-1, dtype=torch.float32)
+        if seq_len != 1:
+            context_layer = torch.matmul(attn.to(value.dtype), value)
+        else:
+            import linear_q4_0
+            context_layer = linear_q4_0.attn_value_fp8_matmul(attn, value.transpose(-1, -2))
+
+    # context_layer's shape: [bs, n_head, seq_len, head_dim] -> [seq_len, bs, n_head * head_dim]
+    context_layer = context_layer.permute(2, 0, 1, 3).contiguous().view(seq_len, batch_size, -1)
+
+    if use_cache:
+        kv_cache = (k_cache.permute(2, 0, 1, 3), v_cache.permute(2, 0, 1, 3))
+    else:
+        kv_cache = None
+
+    output = self.dense(context_layer)
+
+    return output, kv_cache
+
+
 def chatglm2_attention_forward_8eb45c(
     self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True
 ):
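
Summarizing the new quantized path: on the first token (kv_cache is None) attention runs through F.scaled_dot_product_attention and the fp8 cache is only initialized and filled afterwards; on later tokens the new keys/values are appended to the fp8 cache, which is either dequantized as a whole (restore_fp8_kv_cache, when seq_len != 1) or consumed directly by the fused linear_q4_0 fp8 matmul kernels (single-token decode). The plain-PyTorch sketch below mimics a single decode step in full precision against an already-restored cache; it is a plain-precision stand-in for what the fused fp8 matmuls compute, not their implementation.

    import math
    import torch
    import torch.nn.functional as F

    bs, n_head, n_kv_head, head_dim, cached_len = 1, 8, 2, 64, 32
    q = torch.randn(bs, n_head, 1, head_dim)                    # one new query token
    k_cache = torch.randn(bs, n_kv_head, cached_len, head_dim)  # pretend these were restored
    v_cache = torch.randn(bs, n_kv_head, cached_len, head_dim)  # from the 8-bit cache

    k = k_cache.repeat_interleave(n_head // n_kv_head, dim=1)   # expand grouped KV heads
    v = v_cache.repeat_interleave(n_head // n_kv_head, dim=1)
    attn = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(head_dim)
    attn = F.softmax(attn, dim=-1, dtype=torch.float32)
    out = torch.matmul(attn.to(v.dtype), v)                     # [bs, n_head, 1, head_dim]
    print(out.shape)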
@@ -354,7 +507,7 @@ def chatglm2_attention_forward_8eb45c(
                                              save_length,
                                              self.hidden_size_per_attention_head))
 
-    context_layer = self.core_attention(query_layer, key_layer, value_layer, attention_mask)
+    context_layer = core_attn_forward_8eb45c(query_layer, key_layer, value_layer, attention_mask)
 
     # =================
     # Output. [sq, b, h]
@@ -365,7 +518,7 @@ def chatglm2_attention_forward_8eb45c(
     return output, (cache_key_layer.permute(2, 0, 1, 3), cache_value_layer.permute(2, 0, 1, 3))
 
 
-def core_attn_forward_8eb45c(self, query_layer, key_layer, value_layer, attention_mask):
+def core_attn_forward_8eb45c(query_layer, key_layer, value_layer, attention_mask):
     pytorch_major_version = int(torch.__version__.split('.')[0])
     if pytorch_major_version >= 2:
         query_layer = query_layer.permute(1, 2, 0, 3)
@@ -392,7 +545,7 @@ def core_attn_forward_8eb45c(self, query_layer, key_layer, value_layer, attention_mask):
                                  dtype=torch.float32).to(value_layer.dtype)
             context_layer = torch.matmul(attn, value_layer)
         context_layer = context_layer.permute(2, 0, 1, 3)
-        new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
+        new_context_layer_shape = context_layer.size()[:-2] + (-1,)
         context_layer = context_layer.reshape(*new_context_layer_shape)
     else:
         # Raw attention scores
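
One detail worth noting on the last hunk: since core_attn_forward_8eb45c no longer receives self, it cannot read self.hidden_size_per_partition, so the reshape target uses -1 and lets PyTorch infer the flattened n_head * head_dim dimension. A minimal check of that equivalence:

    import torch

    seq_len, bs, n_head, head_dim = 5, 1, 8, 64
    context_layer = torch.randn(seq_len, bs, n_head, head_dim)  # layout after permute(2, 0, 1, 3)

    new_shape = context_layer.size()[:-2] + (-1,)               # (seq_len, bs, -1)
    print(context_layer.reshape(*new_shape).shape)              # torch.Size([5, 1, 512])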