diff --git a/python/llm/src/bigdl/llm/transformers/convert.py b/python/llm/src/bigdl/llm/transformers/convert.py index 45a23ed4..d969ac05 100644 --- a/python/llm/src/bigdl/llm/transformers/convert.py +++ b/python/llm/src/bigdl/llm/transformers/convert.py @@ -637,17 +637,12 @@ def _optimize_post(model, lightweight_bmm=False): # chatglm2-6b modeling_module_name = model.__class__.__module__ module = importlib.import_module(modeling_module_name) - from bigdl.llm.transformers.models.chatglm2 import chatglm2_attention_forward_8eb45c - from bigdl.llm.transformers.models.chatglm2 import core_attn_forward_8eb45c + from bigdl.llm.transformers.models.chatglm2 import chatglm2_attention_forward from bigdl.llm.transformers.models.chatglm2 import chatglm_rms_norm_forward from bigdl.llm.transformers.models.chatglm2 import chatglm2_model_forward convert_forward(model, module.SelfAttention, - chatglm2_attention_forward_8eb45c - ) - convert_forward(model, - module.CoreAttention, - core_attn_forward_8eb45c) + chatglm2_attention_forward) convert_forward(model, module.ChatGLMModel, chatglm2_model_forward) diff --git a/python/llm/src/bigdl/llm/transformers/models/chatglm2.py b/python/llm/src/bigdl/llm/transformers/models/chatglm2.py index f410b8fc..8944d807 100644 --- a/python/llm/src/bigdl/llm/transformers/models/chatglm2.py +++ b/python/llm/src/bigdl/llm/transformers/models/chatglm2.py @@ -23,6 +23,8 @@ from typing import Optional, Tuple, List import torch.nn.functional as F from transformers.modeling_outputs import BaseModelOutputWithPast from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache +from bigdl.llm.transformers.models.utils import init_fp8_kv_cache, extend_fp8_kv_cache, \ + append_fp8_kv_cache, restore_fp8_kv_cache, quantize_kv_cache from bigdl.llm.transformers.models.utils import use_flash_attention from bigdl.llm.transformers.models.llama import get_ipex_version @@ -78,6 +80,21 @@ def apply_rotary_pos_emb_chatglm(x: torch.Tensor, rope_cache: torch.Tensor) -> t return torch.cat((x_out2, x_pass), dim=-1) +def repeat_kv(key: torch.Tensor, value: torch.Tensor, n_head: int) -> (torch.Tensor, torch.Tensor): + # key, value's shape: [bs, n_kv_head, seq_len, head_dim] -> [bs, n_head, seq_len, head_dim] + batch_size, n_kv_head, seq_len, head_dim = key.shape + + key = key.unsqueeze(2) + key = key.expand(-1, -1, n_head // n_kv_head, -1, -1) + key = key.contiguous().view(batch_size, n_head, seq_len, head_dim) + + value = value.unsqueeze(2) + value = value.expand(-1, -1, n_head // n_kv_head, -1, -1) + value = value.contiguous().view(batch_size, n_head, seq_len, head_dim) + + return key, value + + def chatglm_rms_norm_forward(self, hidden_states): if hidden_states.device.type == "xpu" and not (self.training and hidden_states.requires_grad): import linear_q4_0 @@ -169,6 +186,142 @@ def chatglm2_model_forward( ) +def chatglm2_attention_forward( + self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True +): + if quantize_kv_cache(self.query_key_value, hidden_states): + forward_function = chatglm2_quantized_attention_forward_8eb45c + else: + forward_function = chatglm2_attention_forward_8eb45c + return forward_function( + self=self, + hidden_states=hidden_states, + attention_mask=attention_mask, + rotary_pos_emb=rotary_pos_emb, + kv_cache=kv_cache, + use_cache=use_cache + ) + + +def chatglm2_quantized_attention_forward_8eb45c( + self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True +): + # hidden_states: [seq_len, bs, 
head_dim] + mixed_x_layer = self.query_key_value(hidden_states) + + n_head = self.num_attention_heads_per_partition + n_kv_head = self.num_multi_query_groups_per_partition if self.multi_query_attention else n_head + head_dim = self.hidden_size_per_attention_head + + query_layer, key_layer, value_layer = mixed_x_layer.split( + [n_head * head_dim, n_kv_head * head_dim, n_kv_head * head_dim], + dim=-1, + ) + query_layer = query_layer.view(query_layer.shape[:-1] + (n_head, head_dim)) + key_layer = key_layer.view(key_layer.shape[:-1] + (n_kv_head, head_dim)) + value_layer = value_layer.view(value_layer.shape[:-1] + (n_kv_head, head_dim)) + # query, key, value's shape: [seq_len, bs, n_head/n_kv_head, head_dim] + + # apply relative positional encoding (rotary embedding) + if rotary_pos_emb is not None: + if len(rotary_pos_emb) == 2 and isinstance(rotary_pos_emb, tuple): + # use_fuse_rope, see chatglm2_model_forward + cos, sin = rotary_pos_emb + rot_dim = cos.shape[-1] + query_layer = query_layer.transpose(0, 1) + key_layer = key_layer.transpose(0, 1) + query_layer_cur = query_layer[..., :rot_dim] + key_layer_cur = key_layer[..., :rot_dim] + # ipex's apply_rotary_embedding can change the origin storage, so query_layer will get + # the result directly. + torch.ops.torch_ipex.apply_rotary_embedding(query_layer_cur, sin, cos, query_layer_cur) + torch.ops.torch_ipex.apply_rotary_embedding(key_layer_cur, sin, cos, key_layer_cur) + query_layer = query_layer.transpose(0, 1) + key_layer = key_layer.transpose(0, 1) + else: + query_layer = apply_rotary_pos_emb_chatglm(query_layer, rotary_pos_emb) + key_layer = apply_rotary_pos_emb_chatglm(key_layer, rotary_pos_emb) + + query_layer = query_layer.permute(1, 2, 0, 3) + key_layer = key_layer.permute(1, 2, 0, 3) + value_layer = value_layer.permute(1, 2, 0, 3) + # query, key, value's shape: [bs, n_head/n_kv_head, seq_len, head_dim] + batch_size, _, seq_len, _ = query_layer.shape + + if kv_cache is None: + # first token + if self.multi_query_attention: + key, value = repeat_kv(key_layer, value_layer, n_head) + else: + key, value = key_layer, value_layer + + if attention_mask is None: + context_layer = F.scaled_dot_product_attention(query_layer, key, value, is_causal=True) + else: + context_layer = F.scaled_dot_product_attention(query_layer, key, value, attention_mask) + + if use_cache: + k_cache, v_cache = init_fp8_kv_cache(batch_size, + n_kv_head, + head_dim, + 0, + seq_len + KV_CACHE_ALLOC_MIN_LENGTH, + query_layer.device) + k_cache, v_cache = append_fp8_kv_cache(k_cache, v_cache, key_layer, value_layer) + else: + k_cache, v_cache = kv_cache + k_cache = k_cache.permute(1, 2, 0, 3) + v_cache = v_cache.permute(1, 2, 0, 3) + # k_cache, v_cache's shape: [bs, n_kv_head, seq_len, head_dim] + + kv_seq_len = seq_len + k_cache.size(2) + if k_cache.stride(1) < kv_seq_len * k_cache.size(3): + k_cache, v_cache = extend_fp8_kv_cache( + k_cache, v_cache, + kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH, + device=query_layer.device, + ) + if query_layer.device.type == 'xpu': + torch.xpu.empty_cache() + k_cache, v_cache = append_fp8_kv_cache(k_cache, v_cache, key_layer, value_layer) + + if seq_len != 1: + key, value = restore_fp8_kv_cache(k_cache, v_cache, query_layer.dtype) + key, value = repeat_kv(key, value, n_head) + attn = torch.matmul(query_layer, key.transpose(2, 3)) / math.sqrt(head_dim) + else: + key, value = k_cache, v_cache + import linear_q4_0 + attn = linear_q4_0.query_key_fp8_matmul(query_layer, key) / math.sqrt(head_dim) + if attention_mask is not None: + 
attention_mask = ~attention_mask + attn_bias = torch.zeros(attention_mask.shape, dtype=query_layer.dtype, + device=query_layer.device) + if attention_mask.dtype == torch.bool: + attn_bias.masked_fill_(attention_mask.logical_not(), float("-inf")) + else: + attn_bias += attention_mask + attn += attn_bias + attn = F.softmax(attn, dim=-1, dtype=torch.float32) + if seq_len != 1: + context_layer = torch.matmul(attn.to(value.dtype), value) + else: + import linear_q4_0 + context_layer = linear_q4_0.attn_value_fp8_matmul(attn, value.transpose(-1, -2)) + + # context_layer's shape: [bs, n_head, seq_len, head_dim] -> [seq_len, bs, n_head * head_dim] + context_layer = context_layer.permute(2, 0, 1, 3).contiguous().view(seq_len, batch_size, -1) + + if use_cache: + kv_cache = (k_cache.permute(2, 0, 1, 3), v_cache.permute(2, 0, 1, 3)) + else: + kv_cache = None + + output = self.dense(context_layer) + + return output, kv_cache + + def chatglm2_attention_forward_8eb45c( self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True ): @@ -354,7 +507,7 @@ def chatglm2_attention_forward_8eb45c( save_length, self.hidden_size_per_attention_head)) - context_layer = self.core_attention(query_layer, key_layer, value_layer, attention_mask) + context_layer = core_attn_forward_8eb45c(query_layer, key_layer, value_layer, attention_mask) # ================= # Output. [sq, b, h] @@ -365,7 +518,7 @@ def chatglm2_attention_forward_8eb45c( return output, (cache_key_layer.permute(2, 0, 1, 3), cache_value_layer.permute(2, 0, 1, 3)) -def core_attn_forward_8eb45c(self, query_layer, key_layer, value_layer, attention_mask): +def core_attn_forward_8eb45c(query_layer, key_layer, value_layer, attention_mask): pytorch_major_version = int(torch.__version__.split('.')[0]) if pytorch_major_version >= 2: query_layer = query_layer.permute(1, 2, 0, 3) @@ -392,7 +545,7 @@ def core_attn_forward_8eb45c(self, query_layer, key_layer, value_layer, attentio dtype=torch.float32).to(value_layer.dtype) context_layer = torch.matmul(attn, value_layer) context_layer = context_layer.permute(2, 0, 1, 3) - new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,) + new_context_layer_shape = context_layer.size()[:-2] + (-1,) context_layer = context_layer.reshape(*new_context_layer_shape) else: # Raw attention scores
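
Note on the new `repeat_kv` helper: it expands ChatGLM2's grouped KV heads (`multi_query_attention`) to the full query-head count before dense attention. The minimal sketch below, with illustrative shapes (ChatGLM2-6B itself uses 32 query heads over 2 KV groups with head_dim 128), checks that the unsqueeze/expand/view sequence is equivalent to `torch.repeat_interleave` on the head dimension.

```python
import torch

def repeat_kv(key: torch.Tensor, value: torch.Tensor, n_head: int):
    # [bs, n_kv_head, seq_len, head_dim] -> [bs, n_head, seq_len, head_dim]:
    # each KV head is shared by n_head // n_kv_head query heads.
    batch_size, n_kv_head, seq_len, head_dim = key.shape
    key = key.unsqueeze(2).expand(-1, -1, n_head // n_kv_head, -1, -1)
    key = key.contiguous().view(batch_size, n_head, seq_len, head_dim)
    value = value.unsqueeze(2).expand(-1, -1, n_head // n_kv_head, -1, -1)
    value = value.contiguous().view(batch_size, n_head, seq_len, head_dim)
    return key, value

k = torch.randn(2, 2, 5, 128)   # 2 KV groups, as in ChatGLM2-6B
v = torch.randn(2, 2, 5, 128)
k32, v32 = repeat_kv(k, v, 32)  # expand to the 32 query heads
assert torch.equal(k32, k.repeat_interleave(16, dim=1))
assert torch.equal(v32, v.repeat_interleave(16, dim=1))
```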
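
The `k_cache.stride(1) < kv_seq_len * k_cache.size(3)` test in the quantized forward relies on the cache being a narrow view into an over-allocated buffer, so most decode steps append in place and only occasionally trigger a reallocation. The sketch below reproduces that pattern in plain PyTorch; `init_cache`, `append_cache`, `extend_cache`, and the block length are illustrative stand-ins, not BigDL's `init_fp8_kv_cache` / `append_fp8_kv_cache` / `extend_fp8_kv_cache`.

```python
import torch

KV_CACHE_ALLOC_BLOCK_LENGTH = 256   # illustrative block size, not necessarily BigDL's value

def init_cache(batch, n_kv_head, head_dim, current_len, max_len, device="cpu"):
    # Over-allocate max_len rows, hand back a narrow view of the first current_len rows.
    # The view's stride(1) still spans max_len * head_dim elements, which is what the
    # `stride(1) < kv_seq_len * size(3)` check in the diff inspects.
    buf = torch.empty(batch, n_kv_head, max_len, head_dim, device=device)
    return buf[:, :, :current_len, :]

def append_cache(cache, new_kv):
    # Grow the view over the same storage and write the new rows in place.
    bsz, n_kv_head, cur_len, head_dim = cache.shape
    new_len = cur_len + new_kv.size(2)
    grown = cache.as_strided((bsz, n_kv_head, new_len, head_dim), cache.stride())
    grown[:, :, cur_len:, :] = new_kv
    return grown

def extend_cache(cache, new_max_len):
    # Buffer exhausted: allocate a bigger one and copy the existing rows over.
    bsz, n_kv_head, cur_len, head_dim = cache.shape
    bigger = init_cache(bsz, n_kv_head, head_dim, cur_len, new_max_len, cache.device)
    bigger.copy_(cache)
    return bigger

cache = init_cache(1, 2, 128, 0, 16)
for _ in range(20):                      # simulate 20 decode steps of one token each
    token_kv = torch.randn(1, 2, 1, 128)
    kv_seq_len = cache.size(2) + token_kv.size(2)
    if cache.stride(1) < kv_seq_len * cache.size(3):     # same check as in the diff
        cache = extend_cache(cache, kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH)
    cache = append_cache(cache, token_kv)
print(cache.shape)   # torch.Size([1, 2, 20, 128])
```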
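
On the fp8 path the cache is kept quantized and only dequantized when a full attention matrix is needed (the `seq_len != 1` branch calls `restore_fp8_kv_cache`); single-token decode instead goes through the fused `linear_q4_0` fp8 matmuls. BigDL's fp8 cache layout is internal to its kernels, so the following is only a conceptual round trip using PyTorch's `float8_e4m3fn` dtype (requires PyTorch >= 2.1) to show the storage saving and the quantization error the restore path accepts.

```python
import torch

# NOT BigDL's cache format -- just the cast-down / cast-up round trip that
# motivates restore_fp8_kv_cache on the prefill-style (seq_len != 1) branch.
k = torch.randn(1, 2, 16, 128, dtype=torch.float32)
k_fp8 = k.to(torch.float8_e4m3fn)        # stored form: 1 byte per element
k_restored = k_fp8.to(torch.float32)     # what the matmul path actually multiplies
print(k_fp8.element_size(), k.element_size())   # 1 4
print((k - k_restored).abs().max())             # small but nonzero quantization error
```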
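
The boolean-mask handling above first inverts the mask and then re-negates it when building the additive bias, which places `-inf` exactly where the original mask was `True` (the masked-out positions, matching upstream ChatGLM2's `masked_fill` convention). The sketch below, with assumed toy shapes and that mask convention, walks through the double negation and checks that the manual score/softmax/matmul path agrees with `F.scaled_dot_product_attention` given the same additive bias.

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
bs, n_head, q_len, kv_len, head_dim = 1, 2, 1, 4, 8
q = torch.randn(bs, n_head, q_len, head_dim)
k = torch.randn(bs, n_head, kv_len, head_dim)
v = torch.randn(bs, n_head, kv_len, head_dim)

# Assumed convention (consistent with the diff's masked_fill usage): True = do not attend.
mask = torch.tensor([False, False, True, False]).view(1, 1, 1, kv_len)

# Manual path, mirroring the decode/prefill branches above.
inverted = ~mask                                    # True = attend
bias = torch.zeros(mask.shape, dtype=q.dtype)
bias.masked_fill_(inverted.logical_not(), float("-inf"))
attn = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(head_dim) + bias
attn = F.softmax(attn, dim=-1, dtype=torch.float32)
manual = torch.matmul(attn.to(v.dtype), v)

# Reference: SDPA with the same additive bias.
reference = F.scaled_dot_product_attention(q, k, v, attn_mask=bias)
assert torch.allclose(manual, reference, atol=1e-5)
assert attn[..., 2].max().item() == 0.0             # masked position gets zero weight
```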