From 5e25766855d87febc736b43cbefb52b6cf98e309 Mon Sep 17 00:00:00 2001 From: Yishuo Wang Date: Thu, 13 Jun 2024 17:37:58 +0800 Subject: [PATCH] fix and optimize chatglm2-32k and chatglm3-128k (#11306) --- .../llm/src/ipex_llm/transformers/convert.py | 17 +- .../ipex_llm/transformers/models/chatglm2.py | 51 +++++ .../transformers/models/chatglm2_32k.py | 206 ------------------ 3 files changed, 57 insertions(+), 217 deletions(-) delete mode 100644 python/llm/src/ipex_llm/transformers/models/chatglm2_32k.py diff --git a/python/llm/src/ipex_llm/transformers/convert.py b/python/llm/src/ipex_llm/transformers/convert.py index e631b674..746e2cac 100644 --- a/python/llm/src/ipex_llm/transformers/convert.py +++ b/python/llm/src/ipex_llm/transformers/convert.py @@ -1008,26 +1008,21 @@ def _optimize_post(model, lightweight_bmm=False): if model.config.architectures is not None \ and model.config.architectures[0] in ["ChatGLMModel", "ChatGLMForConditionalGeneration"]: - if (model.config.num_layers == 28 and hasattr(model.config, 'rope_ratio') - and model.config.rope_ratio == 16): - # chatglm2-6b-32k - modeling_module_name = model.__class__.__module__ - module = importlib.import_module(modeling_module_name) - from ipex_llm.transformers.models.chatglm2_32k import chatglm2_32k_attention_forward - convert_forward(model, - module.SelfAttention, - chatglm2_32k_attention_forward) - elif hasattr(model.config, 'padded_vocab_size') and \ + if hasattr(model.config, 'padded_vocab_size') and \ model.config.padded_vocab_size == 65024: - # chatglm2-6b + # chatglm2-6b, chatglm2-6b-32k, chatglm3-6b, chatglm3-6b-32k, chatglm3-6b-128k modeling_module_name = model.__class__.__module__ module = importlib.import_module(modeling_module_name) from ipex_llm.transformers.models.chatglm2 import chatglm2_attention_forward from ipex_llm.transformers.models.chatglm2 import chatglm_rms_norm_forward + from ipex_llm.transformers.models.chatglm2 import chatglm2_encoder_forward from ipex_llm.transformers.models.chatglm2 import chatglm2_model_forward convert_forward(model, module.SelfAttention, chatglm2_attention_forward) + convert_forward(model, + module.GLMTransformer, + chatglm2_encoder_forward) convert_forward(model, module.ChatGLMModel, chatglm2_model_forward) diff --git a/python/llm/src/ipex_llm/transformers/models/chatglm2.py b/python/llm/src/ipex_llm/transformers/models/chatglm2.py index 747b7ddd..375bfa47 100644 --- a/python/llm/src/ipex_llm/transformers/models/chatglm2.py +++ b/python/llm/src/ipex_llm/transformers/models/chatglm2.py @@ -145,6 +145,57 @@ def chatglm2_model_forward( ) +# remove code which stores first token's kv cache by tensor format +# to fix chatglm2-32k and chatglm3-128k +def chatglm2_encoder_forward( + self, hidden_states, attention_mask, rotary_pos_emb, kv_caches=None, + use_cache: Optional[bool] = True, + output_hidden_states: Optional[bool] = False, +): + if not kv_caches: + kv_caches = [None for _ in range(self.num_layers)] + presents = () if use_cache else None + if self.gradient_checkpointing and self.training: + use_cache = False + + all_self_attentions = None + all_hidden_states = () if output_hidden_states else None + for index in range(self.num_layers): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer = self._get_layer(index) + if self.gradient_checkpointing and self.training: + layer_ret = torch.utils.checkpoint.checkpoint( + layer, + hidden_states, + attention_mask, + rotary_pos_emb, + kv_caches[index], + use_cache + ) + else: + layer_ret = layer( + 
hidden_states, + attention_mask, + rotary_pos_emb, + kv_cache=kv_caches[index], + use_cache=use_cache + ) + hidden_states, kv_cache = layer_ret + if use_cache: + presents = presents + (kv_cache,) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + # Final layer norm. + if self.post_layer_norm: + hidden_states = self.final_layernorm(hidden_states) + + return hidden_states, presents, all_hidden_states, all_self_attentions + + def chatglm2_attention_forward( self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True ): diff --git a/python/llm/src/ipex_llm/transformers/models/chatglm2_32k.py b/python/llm/src/ipex_llm/transformers/models/chatglm2_32k.py deleted file mode 100644 index 0df6ad1b..00000000 --- a/python/llm/src/ipex_llm/transformers/models/chatglm2_32k.py +++ /dev/null @@ -1,206 +0,0 @@ -# -# Copyright 2016 The BigDL Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -# This file is adapted from -# https://huggingface.co/THUDM/chatglm2-6b-32k/blob/main/configuration_chatglm.py -# - -import torch -from typing import Optional, Tuple, Union, List, Callable, Dict, Any -import torch.nn.functional as F -from ipex_llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache -from ipex_llm.transformers.models.chatglm2 import core_attn_forward_8eb45c - - -import os - -KV_CACHE_ALLOC_BLOCK_LENGTH = int(os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256)) -KV_CACHE_ALLOC_MIN_LENGTH = 512 - - -def split_tensor_along_last_dim( - tensor: torch.Tensor, - num_partitions: int, - contiguous_split_chunks: bool = False, -) -> List[torch.Tensor]: - """Split a tensor along its last dimension. - Arguments: - tensor: input tensor. - num_partitions: number of partitions to split the tensor - contiguous_split_chunks: If True, make each chunk contiguous - in memory. - Returns: - A list of Tensors - """ - # Get the size and dimension. - last_dim = tensor.dim() - 1 - last_dim_size = tensor.size()[last_dim] // num_partitions - # Split. - tensor_list = torch.split(tensor, last_dim_size, dim=last_dim) - # Note: torch.split does not create contiguous tensors by default. 
- if contiguous_split_chunks: - return tuple(chunk.contiguous() for chunk in tensor_list) - - return tensor_list - - -@torch.jit.script -def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor: - # x: [sq, b, np, hn] - sq, b, np, hn = x.size(0), x.size(1), x.size(2), x.size(3) - rot_dim = rope_cache.shape[-2] * 2 - x, x_pass = x[..., :rot_dim], x[..., rot_dim:] - # truncate to support variable sizes - rope_cache = rope_cache[:sq] - xshaped = x.reshape(sq, -1, np, rot_dim // 2, 2) - rope_cache = rope_cache.view(sq, -1, 1, xshaped.size(3), 2) - x_out2 = torch.stack( - [ - xshaped[..., 0] * rope_cache[..., 0] - xshaped[..., 1] * rope_cache[..., 1], - xshaped[..., 1] * rope_cache[..., 0] + xshaped[..., 0] * rope_cache[..., 1], - ], - -1, - ) - x_out2 = x_out2.flatten(3) - return torch.cat((x_out2, x_pass), dim=-1) - - -def chatglm2_32k_attention_forward( - self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True -): - # hidden_states: [sq, b, h] - - # ================================================= - # Pre-allocate memory for key-values for inference. - # ================================================= - # ===================== - # Query, Key, and Value - # ===================== - - # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)] - device = hidden_states.device - mixed_x_layer = self.query_key_value(hidden_states) - - if self.multi_query_attention: - (query_layer, key_layer, value_layer) = mixed_x_layer.split( - [ - self.num_attention_heads_per_partition * self.hidden_size_per_attention_head, - self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head, - self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head, - ], - dim=-1, - ) - query_layer = query_layer.view( - query_layer.size()[:-1] + (self.num_attention_heads_per_partition, - self.hidden_size_per_attention_head) - ) - key_layer = key_layer.view( - key_layer.size()[:-1] + (self.num_multi_query_groups_per_partition, - self.hidden_size_per_attention_head) - ) - value_layer = value_layer.view( - value_layer.size()[:-1] - + (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head) - ) - else: - new_tensor_shape = mixed_x_layer.size()[:-1] + (self.num_attention_heads_per_partition, - 3 * self.hidden_size_per_attention_head) - mixed_x_layer = mixed_x_layer.view(*new_tensor_shape) - - # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn] - (query_layer, key_layer, value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3) - - # apply relative positional encoding (rotary embedding) - if rotary_pos_emb is not None: - query_layer = apply_rotary_pos_emb(query_layer, rotary_pos_emb) - key_layer = apply_rotary_pos_emb(key_layer, rotary_pos_emb) - - cur_length, batch_size = query_layer.shape[0], query_layer.shape[1] - - if self.multi_query_attention: - key_length = key_layer.size(0) - query_group_size = self.num_attention_heads_per_partition // \ - self.num_multi_query_groups_per_partition - key_layer = key_layer.permute(1, 2, 0, 3).unsqueeze(-3) # [bs, nh/k, sl, hn] - key_layer = key_layer.expand(-1, -1, query_group_size, -1, -1) - key_layer = key_layer.contiguous().view((batch_size, - self.num_attention_heads_per_partition, - key_length, - self.hidden_size_per_attention_head)) - value_layer = value_layer.permute(1, 2, 0, 3).unsqueeze(-3) - value_layer = value_layer.expand(-1, -1, query_group_size, -1, -1) - value_layer = value_layer.contiguous().view((batch_size, - self.num_attention_heads_per_partition, - key_length, - 
self.hidden_size_per_attention_head)) - - # adjust key and value for inference - if kv_cache is not None: - cache_k, cache_v = kv_cache - cache_k = cache_k.permute(1, 2, 0, 3) - cache_v = cache_v.permute(1, 2, 0, 3) - past_length = cache_k.size(2) - - if cache_k.stride()[1] < (past_length + cur_length) * cache_k.size(3): - max_cache_length = past_length + cur_length + KV_CACHE_ALLOC_BLOCK_LENGTH - new_cache_k, new_cache_v = extend_kv_cache(batch_size, - self.num_attention_heads_per_partition, - self.hidden_size_per_attention_head, - past_length, - max_cache_length, - dtype=query_layer.dtype, - device=device) - new_cache_k[:] = cache_k - new_cache_v[:] = cache_v - cache_k = new_cache_k - cache_v = new_cache_v - - key_layer, value_layer = append_kv_cache(cache_k, cache_v, key_layer, value_layer) - elif use_cache: - max_cache_length = max(KV_CACHE_ALLOC_MIN_LENGTH, cur_length) \ - + KV_CACHE_ALLOC_BLOCK_LENGTH - key_cache, value_cache = init_kv_cache(batch_size, self.num_attention_heads_per_partition, - self.hidden_size_per_attention_head, cur_length, - max_cache_length, - dtype=query_layer.dtype, device=device) - key_cache[:] = key_layer - value_cache[:] = value_layer - key_layer = key_cache - value_layer = value_cache - - if use_cache: - key_layer = key_layer.permute(2, 0, 1, 3) - value_layer = value_layer.permute(2, 0, 1, 3) - if kv_cache is None: - kv_cache = torch.cat((key_layer.unsqueeze(0).unsqueeze(0), - value_layer.unsqueeze(0).unsqueeze(0)), dim=1) - else: - kv_cache = (key_layer, value_layer) - else: - kv_cache = None - - # ================================== - # core attention computation - # ================================== - - context_layer = core_attn_forward_8eb45c(query_layer, key_layer, value_layer, attention_mask) - - # ================= - # Output. [sq, b, h] - # ================= - - output = self.dense(context_layer) - - return output, kv_cache
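
Note on the mechanism (reviewer addendum, not part of the patch): with this change,
chatglm2-6b-32k and chatglm3-6b-128k are routed through the same optimized path as the
other ChatGLM checkpoints, and convert_forward additionally swaps GLMTransformer.forward
for the new chatglm2_encoder_forward. As the comment above that function says, the
replacement drops the upstream code that stored the first token's kv cache in a packed
tensor format (the part that broke the 32k/128k variants) and instead always accumulates
per-layer (key, value) tuples in `presents`. Below is a minimal sketch of the kind of
class-level forward replacement that convert_forward performs; the helper name
replace_forward is illustrative only, and the exact ipex-llm implementation may differ.

    import types
    import torch.nn as nn

    def replace_forward(model: nn.Module, target_cls: type, new_forward) -> None:
        # Rebind `forward` on every submodule that is an instance of `target_cls`,
        # so later calls dispatch to `new_forward(self, ...)`.
        for module in model.modules():
            if isinstance(module, target_cls):
                module.forward = types.MethodType(new_forward, module)

    # Usage, mirroring the calls added in convert.py:
    #   replace_forward(model, module.GLMTransformer, chatglm2_encoder_forward)
    #   replace_forward(model, module.SelfAttention, chatglm2_attention_forward)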