fix and optimize chatglm2-32k and chatglm3-128k (#11306)
This commit is contained in:
parent
60cb1dac7c
commit
5e25766855
3 changed files with 57 additions and 217 deletions
|
|
@ -1008,26 +1008,21 @@ def _optimize_post(model, lightweight_bmm=False):
|
|||
|
||||
if model.config.architectures is not None \
|
||||
and model.config.architectures[0] in ["ChatGLMModel", "ChatGLMForConditionalGeneration"]:
|
||||
if (model.config.num_layers == 28 and hasattr(model.config, 'rope_ratio')
|
||||
and model.config.rope_ratio == 16):
|
||||
# chatglm2-6b-32k
|
||||
modeling_module_name = model.__class__.__module__
|
||||
module = importlib.import_module(modeling_module_name)
|
||||
from ipex_llm.transformers.models.chatglm2_32k import chatglm2_32k_attention_forward
|
||||
convert_forward(model,
|
||||
module.SelfAttention,
|
||||
chatglm2_32k_attention_forward)
|
||||
elif hasattr(model.config, 'padded_vocab_size') and \
|
||||
if hasattr(model.config, 'padded_vocab_size') and \
|
||||
model.config.padded_vocab_size == 65024:
|
||||
# chatglm2-6b
|
||||
# chatglm2-6b, chatglm2-6b-32k, chatglm3-6b, chatglm3-6b-32k, chatglm3-6b-128k
|
||||
modeling_module_name = model.__class__.__module__
|
||||
module = importlib.import_module(modeling_module_name)
|
||||
from ipex_llm.transformers.models.chatglm2 import chatglm2_attention_forward
|
||||
from ipex_llm.transformers.models.chatglm2 import chatglm_rms_norm_forward
|
||||
from ipex_llm.transformers.models.chatglm2 import chatglm2_encoder_forward
|
||||
from ipex_llm.transformers.models.chatglm2 import chatglm2_model_forward
|
||||
convert_forward(model,
|
||||
module.SelfAttention,
|
||||
chatglm2_attention_forward)
|
||||
convert_forward(model,
|
||||
module.GLMTransformer,
|
||||
chatglm2_encoder_forward)
|
||||
convert_forward(model,
|
||||
module.ChatGLMModel,
|
||||
chatglm2_model_forward)
|
||||
|
|
|
|||
|
|
@ -145,6 +145,57 @@ def chatglm2_model_forward(
|
|||
)
|
||||
|
||||
|
||||
# remove code which stores first token's kv cache by tensor format
|
||||
# to fix chatglm2-32k and chatglm3-128k
|
||||
def chatglm2_encoder_forward(
|
||||
self, hidden_states, attention_mask, rotary_pos_emb, kv_caches=None,
|
||||
use_cache: Optional[bool] = True,
|
||||
output_hidden_states: Optional[bool] = False,
|
||||
):
|
||||
if not kv_caches:
|
||||
kv_caches = [None for _ in range(self.num_layers)]
|
||||
presents = () if use_cache else None
|
||||
if self.gradient_checkpointing and self.training:
|
||||
use_cache = False
|
||||
|
||||
all_self_attentions = None
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
for index in range(self.num_layers):
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
layer = self._get_layer(index)
|
||||
if self.gradient_checkpointing and self.training:
|
||||
layer_ret = torch.utils.checkpoint.checkpoint(
|
||||
layer,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
rotary_pos_emb,
|
||||
kv_caches[index],
|
||||
use_cache
|
||||
)
|
||||
else:
|
||||
layer_ret = layer(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
rotary_pos_emb,
|
||||
kv_cache=kv_caches[index],
|
||||
use_cache=use_cache
|
||||
)
|
||||
hidden_states, kv_cache = layer_ret
|
||||
if use_cache:
|
||||
presents = presents + (kv_cache,)
|
||||
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
# Final layer norm.
|
||||
if self.post_layer_norm:
|
||||
hidden_states = self.final_layernorm(hidden_states)
|
||||
|
||||
return hidden_states, presents, all_hidden_states, all_self_attentions
|
||||
|
||||
|
||||
def chatglm2_attention_forward(
|
||||
self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True
|
||||
):
|
||||
|
|
|
|||
|
|
@ -1,206 +0,0 @@
|
|||
#
|
||||
# Copyright 2016 The BigDL Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
# This file is adapted from
|
||||
# https://huggingface.co/THUDM/chatglm2-6b-32k/blob/main/configuration_chatglm.py
|
||||
#
|
||||
|
||||
import torch
|
||||
from typing import Optional, Tuple, Union, List, Callable, Dict, Any
|
||||
import torch.nn.functional as F
|
||||
from ipex_llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
|
||||
from ipex_llm.transformers.models.chatglm2 import core_attn_forward_8eb45c
|
||||
|
||||
|
||||
import os
|
||||
|
||||
KV_CACHE_ALLOC_BLOCK_LENGTH = int(os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256))
|
||||
KV_CACHE_ALLOC_MIN_LENGTH = 512
|
||||
|
||||
|
||||
def split_tensor_along_last_dim(
|
||||
tensor: torch.Tensor,
|
||||
num_partitions: int,
|
||||
contiguous_split_chunks: bool = False,
|
||||
) -> List[torch.Tensor]:
|
||||
"""Split a tensor along its last dimension.
|
||||
Arguments:
|
||||
tensor: input tensor.
|
||||
num_partitions: number of partitions to split the tensor
|
||||
contiguous_split_chunks: If True, make each chunk contiguous
|
||||
in memory.
|
||||
Returns:
|
||||
A list of Tensors
|
||||
"""
|
||||
# Get the size and dimension.
|
||||
last_dim = tensor.dim() - 1
|
||||
last_dim_size = tensor.size()[last_dim] // num_partitions
|
||||
# Split.
|
||||
tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
|
||||
# Note: torch.split does not create contiguous tensors by default.
|
||||
if contiguous_split_chunks:
|
||||
return tuple(chunk.contiguous() for chunk in tensor_list)
|
||||
|
||||
return tensor_list
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:
|
||||
# x: [sq, b, np, hn]
|
||||
sq, b, np, hn = x.size(0), x.size(1), x.size(2), x.size(3)
|
||||
rot_dim = rope_cache.shape[-2] * 2
|
||||
x, x_pass = x[..., :rot_dim], x[..., rot_dim:]
|
||||
# truncate to support variable sizes
|
||||
rope_cache = rope_cache[:sq]
|
||||
xshaped = x.reshape(sq, -1, np, rot_dim // 2, 2)
|
||||
rope_cache = rope_cache.view(sq, -1, 1, xshaped.size(3), 2)
|
||||
x_out2 = torch.stack(
|
||||
[
|
||||
xshaped[..., 0] * rope_cache[..., 0] - xshaped[..., 1] * rope_cache[..., 1],
|
||||
xshaped[..., 1] * rope_cache[..., 0] + xshaped[..., 0] * rope_cache[..., 1],
|
||||
],
|
||||
-1,
|
||||
)
|
||||
x_out2 = x_out2.flatten(3)
|
||||
return torch.cat((x_out2, x_pass), dim=-1)
|
||||
|
||||
|
||||
def chatglm2_32k_attention_forward(
|
||||
self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True
|
||||
):
|
||||
# hidden_states: [sq, b, h]
|
||||
|
||||
# =================================================
|
||||
# Pre-allocate memory for key-values for inference.
|
||||
# =================================================
|
||||
# =====================
|
||||
# Query, Key, and Value
|
||||
# =====================
|
||||
|
||||
# Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
|
||||
device = hidden_states.device
|
||||
mixed_x_layer = self.query_key_value(hidden_states)
|
||||
|
||||
if self.multi_query_attention:
|
||||
(query_layer, key_layer, value_layer) = mixed_x_layer.split(
|
||||
[
|
||||
self.num_attention_heads_per_partition * self.hidden_size_per_attention_head,
|
||||
self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head,
|
||||
self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head,
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
query_layer = query_layer.view(
|
||||
query_layer.size()[:-1] + (self.num_attention_heads_per_partition,
|
||||
self.hidden_size_per_attention_head)
|
||||
)
|
||||
key_layer = key_layer.view(
|
||||
key_layer.size()[:-1] + (self.num_multi_query_groups_per_partition,
|
||||
self.hidden_size_per_attention_head)
|
||||
)
|
||||
value_layer = value_layer.view(
|
||||
value_layer.size()[:-1]
|
||||
+ (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head)
|
||||
)
|
||||
else:
|
||||
new_tensor_shape = mixed_x_layer.size()[:-1] + (self.num_attention_heads_per_partition,
|
||||
3 * self.hidden_size_per_attention_head)
|
||||
mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
|
||||
|
||||
# [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
|
||||
(query_layer, key_layer, value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3)
|
||||
|
||||
# apply relative positional encoding (rotary embedding)
|
||||
if rotary_pos_emb is not None:
|
||||
query_layer = apply_rotary_pos_emb(query_layer, rotary_pos_emb)
|
||||
key_layer = apply_rotary_pos_emb(key_layer, rotary_pos_emb)
|
||||
|
||||
cur_length, batch_size = query_layer.shape[0], query_layer.shape[1]
|
||||
|
||||
if self.multi_query_attention:
|
||||
key_length = key_layer.size(0)
|
||||
query_group_size = self.num_attention_heads_per_partition // \
|
||||
self.num_multi_query_groups_per_partition
|
||||
key_layer = key_layer.permute(1, 2, 0, 3).unsqueeze(-3) # [bs, nh/k, sl, hn]
|
||||
key_layer = key_layer.expand(-1, -1, query_group_size, -1, -1)
|
||||
key_layer = key_layer.contiguous().view((batch_size,
|
||||
self.num_attention_heads_per_partition,
|
||||
key_length,
|
||||
self.hidden_size_per_attention_head))
|
||||
value_layer = value_layer.permute(1, 2, 0, 3).unsqueeze(-3)
|
||||
value_layer = value_layer.expand(-1, -1, query_group_size, -1, -1)
|
||||
value_layer = value_layer.contiguous().view((batch_size,
|
||||
self.num_attention_heads_per_partition,
|
||||
key_length,
|
||||
self.hidden_size_per_attention_head))
|
||||
|
||||
# adjust key and value for inference
|
||||
if kv_cache is not None:
|
||||
cache_k, cache_v = kv_cache
|
||||
cache_k = cache_k.permute(1, 2, 0, 3)
|
||||
cache_v = cache_v.permute(1, 2, 0, 3)
|
||||
past_length = cache_k.size(2)
|
||||
|
||||
if cache_k.stride()[1] < (past_length + cur_length) * cache_k.size(3):
|
||||
max_cache_length = past_length + cur_length + KV_CACHE_ALLOC_BLOCK_LENGTH
|
||||
new_cache_k, new_cache_v = extend_kv_cache(batch_size,
|
||||
self.num_attention_heads_per_partition,
|
||||
self.hidden_size_per_attention_head,
|
||||
past_length,
|
||||
max_cache_length,
|
||||
dtype=query_layer.dtype,
|
||||
device=device)
|
||||
new_cache_k[:] = cache_k
|
||||
new_cache_v[:] = cache_v
|
||||
cache_k = new_cache_k
|
||||
cache_v = new_cache_v
|
||||
|
||||
key_layer, value_layer = append_kv_cache(cache_k, cache_v, key_layer, value_layer)
|
||||
elif use_cache:
|
||||
max_cache_length = max(KV_CACHE_ALLOC_MIN_LENGTH, cur_length) \
|
||||
+ KV_CACHE_ALLOC_BLOCK_LENGTH
|
||||
key_cache, value_cache = init_kv_cache(batch_size, self.num_attention_heads_per_partition,
|
||||
self.hidden_size_per_attention_head, cur_length,
|
||||
max_cache_length,
|
||||
dtype=query_layer.dtype, device=device)
|
||||
key_cache[:] = key_layer
|
||||
value_cache[:] = value_layer
|
||||
key_layer = key_cache
|
||||
value_layer = value_cache
|
||||
|
||||
if use_cache:
|
||||
key_layer = key_layer.permute(2, 0, 1, 3)
|
||||
value_layer = value_layer.permute(2, 0, 1, 3)
|
||||
if kv_cache is None:
|
||||
kv_cache = torch.cat((key_layer.unsqueeze(0).unsqueeze(0),
|
||||
value_layer.unsqueeze(0).unsqueeze(0)), dim=1)
|
||||
else:
|
||||
kv_cache = (key_layer, value_layer)
|
||||
else:
|
||||
kv_cache = None
|
||||
|
||||
# ==================================
|
||||
# core attention computation
|
||||
# ==================================
|
||||
|
||||
context_layer = core_attn_forward_8eb45c(query_layer, key_layer, value_layer, attention_mask)
|
||||
|
||||
# =================
|
||||
# Output. [sq, b, h]
|
||||
# =================
|
||||
|
||||
output = self.dense(context_layer)
|
||||
|
||||
return output, kv_cache
|
||||
Loading…
Reference in a new issue