optimize chatglm2 long sequence (#8662)
* add chatglm2 * optimize a little * optimize chatglm long sequence * fix style * address comments and fix style * fix bug
This commit is contained in:
parent
3407f87075
commit
b6468bac43
3 changed files with 335 additions and 24 deletions
|
|
@ -40,6 +40,8 @@ import torch.nn as nn
|
||||||
from accelerate import init_empty_weights
|
from accelerate import init_empty_weights
|
||||||
from bigdl.llm.transformers.linear_quant import LinearQuant, ParamsQuant
|
from bigdl.llm.transformers.linear_quant import LinearQuant, ParamsQuant
|
||||||
import warnings
|
import warnings
|
||||||
|
import transformers
|
||||||
|
import importlib
|
||||||
|
|
||||||
|
|
||||||
def _replace_with_quant_linear(model, qtype, modules_to_not_convert=None,
|
def _replace_with_quant_linear(model, qtype, modules_to_not_convert=None,
|
||||||
|
|
@ -124,4 +126,42 @@ def ggml_convert_quant(model, qtype, convert_shape_only=False):
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
model.to(torch.float32)
|
model.to(torch.float32)
|
||||||
|
|
||||||
|
model = optimize(model)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def convert_forward(m, target_m, new_forward):
|
||||||
|
for _, sub_m in m.named_children():
|
||||||
|
if isinstance(sub_m, target_m):
|
||||||
|
bound_method = new_forward.__get__(sub_m, sub_m.__class__)
|
||||||
|
setattr(sub_m, "forward", bound_method)
|
||||||
|
convert_forward(sub_m, target_m, new_forward)
|
||||||
|
|
||||||
|
|
||||||
|
def optimize(model):
|
||||||
|
from packaging import version
|
||||||
|
from bigdl.llm.transformers.models.llama import llama_attention_forward_4_31
|
||||||
|
trans_version = transformers.__version__
|
||||||
|
if version.parse(trans_version) >= version.parse("4.31.0"):
|
||||||
|
convert_forward(
|
||||||
|
model,
|
||||||
|
transformers.models.llama.modeling_llama.LlamaAttention,
|
||||||
|
llama_attention_forward_4_31,)
|
||||||
|
else:
|
||||||
|
# todo implement 4.28.0 ~ 4.30.2
|
||||||
|
pass
|
||||||
|
|
||||||
|
if "chatglm2" in model.config._name_or_path:
|
||||||
|
modeling_module_name = model.__class__.__module__
|
||||||
|
module = importlib.import_module(modeling_module_name)
|
||||||
|
from bigdl.llm.transformers.models.chatglm import chatglm_attention_forward_8eb45c
|
||||||
|
from bigdl.llm.transformers.models.chatglm import core_attn_forward_8eb45c
|
||||||
|
convert_forward(model,
|
||||||
|
module.SelfAttention,
|
||||||
|
chatglm_attention_forward_8eb45c
|
||||||
|
)
|
||||||
|
convert_forward(model,
|
||||||
|
module.CoreAttention,
|
||||||
|
core_attn_forward_8eb45c)
|
||||||
return model
|
return model
|
||||||
|
|
|
||||||
|
|
@ -24,6 +24,8 @@ from .utils import extract_local_archive_file, \
|
||||||
fix_key
|
fix_key
|
||||||
from bigdl.llm.ggml.quantize import ggml_tensor_qtype
|
from bigdl.llm.ggml.quantize import ggml_tensor_qtype
|
||||||
from bigdl.llm.utils.common import invalidInputError, MuteHFLogger
|
from bigdl.llm.utils.common import invalidInputError, MuteHFLogger
|
||||||
|
import sys
|
||||||
|
import importlib
|
||||||
|
|
||||||
|
|
||||||
def save_low_bit(self, *args, **kwargs):
|
def save_low_bit(self, *args, **kwargs):
|
||||||
|
|
@ -33,14 +35,6 @@ def save_low_bit(self, *args, **kwargs):
|
||||||
self.save_pretrained(*args, **kwargs)
|
self.save_pretrained(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
def convert_forward(m, target_m, new_forward):
|
|
||||||
for _, sub_m in m.named_children():
|
|
||||||
if isinstance(sub_m, target_m):
|
|
||||||
bound_method = new_forward.__get__(sub_m, sub_m.__class__)
|
|
||||||
setattr(sub_m, "forward", bound_method)
|
|
||||||
convert_forward(sub_m, target_m, new_forward)
|
|
||||||
|
|
||||||
|
|
||||||
class _BaseAutoModelClass:
|
class _BaseAutoModelClass:
|
||||||
|
|
||||||
HF_MODEL = None
|
HF_MODEL = None
|
||||||
|
|
@ -91,20 +85,6 @@ class _BaseAutoModelClass:
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def optimize(cls, model):
|
|
||||||
from packaging import version
|
|
||||||
from bigdl.llm.transformers.models.llama import llama_attention_forward_4_31
|
|
||||||
trans_version = transformers.__version__
|
|
||||||
if version.parse(trans_version) >= version.parse("4.31.0"):
|
|
||||||
convert_forward(
|
|
||||||
model,
|
|
||||||
transformers.models.llama.modeling_llama.LlamaAttention,
|
|
||||||
llama_attention_forward_4_31,)
|
|
||||||
else:
|
|
||||||
# todo implement 4.28.0 ~ 4.30.2
|
|
||||||
pass
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def load_convert(cls, q_k, *args, **kwargs):
|
def load_convert(cls, q_k, *args, **kwargs):
|
||||||
from .convert import ggml_convert_quant
|
from .convert import ggml_convert_quant
|
||||||
|
|
@ -117,8 +97,6 @@ class _BaseAutoModelClass:
|
||||||
model = ggml_convert_quant(model, qtype)
|
model = ggml_convert_quant(model, qtype)
|
||||||
model.config.update({"bigdl_transformers_low_bit": q_k})
|
model.config.update({"bigdl_transformers_low_bit": q_k})
|
||||||
|
|
||||||
cls.optimize(model)
|
|
||||||
|
|
||||||
# add save_low_bit to pretrained model dynamically
|
# add save_low_bit to pretrained model dynamically
|
||||||
import types
|
import types
|
||||||
model.save_low_bit = types.MethodType(save_low_bit, model)
|
model.save_low_bit = types.MethodType(save_low_bit, model)
|
||||||
|
|
|
||||||
293
python/llm/src/bigdl/llm/transformers/models/chatglm.py
Normal file
293
python/llm/src/bigdl/llm/transformers/models/chatglm.py
Normal file
|
|
@ -0,0 +1,293 @@
|
||||||
|
#
|
||||||
|
# Copyright 2016 The BigDL Authors.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
#
|
||||||
|
# This file is adapted from
|
||||||
|
# https://huggingface.co/THUDM/chatglm2-6b/blob/8eb45c842594b8473f291d0f94e7bbe86ffc67d8/modeling_chatglm.py
|
||||||
|
#
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from typing import Optional, Tuple, Union, List, Callable, Dict, Any
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
|
||||||
|
KV_CACHE_ALLOC_BLOCK_LENGTH = 256
|
||||||
|
KV_CACHE_ALLOC_MIN_LENGTH = 512
|
||||||
|
|
||||||
|
|
||||||
|
def split_tensor_along_last_dim(
|
||||||
|
tensor: torch.Tensor,
|
||||||
|
num_partitions: int,
|
||||||
|
contiguous_split_chunks: bool = False,
|
||||||
|
) -> List[torch.Tensor]:
|
||||||
|
"""Split a tensor along its last dimension.
|
||||||
|
Arguments:
|
||||||
|
tensor: input tensor.
|
||||||
|
num_partitions: number of partitions to split the tensor
|
||||||
|
contiguous_split_chunks: If True, make each chunk contiguous
|
||||||
|
in memory.
|
||||||
|
Returns:
|
||||||
|
A list of Tensors
|
||||||
|
"""
|
||||||
|
# Get the size and dimension.
|
||||||
|
last_dim = tensor.dim() - 1
|
||||||
|
last_dim_size = tensor.size()[last_dim] // num_partitions
|
||||||
|
# Split.
|
||||||
|
tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
|
||||||
|
# Note: torch.split does not create contiguous tensors by default.
|
||||||
|
if contiguous_split_chunks:
|
||||||
|
return tuple(chunk.contiguous() for chunk in tensor_list)
|
||||||
|
|
||||||
|
return tensor_list
|
||||||
|
|
||||||
|
|
||||||
|
@torch.jit.script
|
||||||
|
def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:
|
||||||
|
# x: [sq, b, np, hn]
|
||||||
|
sq, b, np, hn = x.size(0), x.size(1), x.size(2), x.size(3)
|
||||||
|
rot_dim = rope_cache.shape[-2] * 2
|
||||||
|
x, x_pass = x[..., :rot_dim], x[..., rot_dim:]
|
||||||
|
# truncate to support variable sizes
|
||||||
|
rope_cache = rope_cache[:sq]
|
||||||
|
xshaped = x.reshape(sq, -1, np, rot_dim // 2, 2)
|
||||||
|
rope_cache = rope_cache.view(sq, -1, 1, xshaped.size(3), 2)
|
||||||
|
x_out2 = torch.stack(
|
||||||
|
[
|
||||||
|
xshaped[..., 0] * rope_cache[..., 0] - xshaped[..., 1] * rope_cache[..., 1],
|
||||||
|
xshaped[..., 1] * rope_cache[..., 0] + xshaped[..., 0] * rope_cache[..., 1],
|
||||||
|
],
|
||||||
|
-1,
|
||||||
|
)
|
||||||
|
x_out2 = x_out2.flatten(3)
|
||||||
|
return torch.cat((x_out2, x_pass), dim=-1)
|
||||||
|
|
||||||
|
|
||||||
|
def chatglm_attention_forward_8eb45c(
|
||||||
|
self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True
|
||||||
|
):
|
||||||
|
# hidden_states: [sq, b, h]
|
||||||
|
|
||||||
|
# =================================================
|
||||||
|
# Pre-allocate memory for key-values for inference.
|
||||||
|
# =================================================
|
||||||
|
# =====================
|
||||||
|
# Query, Key, and Value
|
||||||
|
# =====================
|
||||||
|
|
||||||
|
# Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
|
||||||
|
mixed_x_layer = self.query_key_value(hidden_states)
|
||||||
|
|
||||||
|
if self.multi_query_attention:
|
||||||
|
(query_layer, key_layer, value_layer) = mixed_x_layer.split(
|
||||||
|
[
|
||||||
|
self.num_attention_heads_per_partition * self.hidden_size_per_attention_head,
|
||||||
|
self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head,
|
||||||
|
self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head,
|
||||||
|
],
|
||||||
|
dim=-1,
|
||||||
|
)
|
||||||
|
query_layer = query_layer.view(
|
||||||
|
query_layer.size()[:-1] + (self.num_attention_heads_per_partition,
|
||||||
|
self.hidden_size_per_attention_head)
|
||||||
|
)
|
||||||
|
key_layer = key_layer.view(
|
||||||
|
key_layer.size()[:-1] + (self.num_multi_query_groups_per_partition,
|
||||||
|
self.hidden_size_per_attention_head)
|
||||||
|
)
|
||||||
|
value_layer = value_layer.view(
|
||||||
|
value_layer.size()[:-1]
|
||||||
|
+ (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
new_tensor_shape = mixed_x_layer.size()[:-1] + (self.num_attention_heads_per_partition,
|
||||||
|
3 * self.hidden_size_per_attention_head)
|
||||||
|
mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
|
||||||
|
|
||||||
|
# [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
|
||||||
|
(query_layer, key_layer, value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3)
|
||||||
|
|
||||||
|
# apply relative positional encoding (rotary embedding)
|
||||||
|
if rotary_pos_emb is not None:
|
||||||
|
query_layer = apply_rotary_pos_emb(query_layer, rotary_pos_emb)
|
||||||
|
key_layer = apply_rotary_pos_emb(key_layer, rotary_pos_emb)
|
||||||
|
|
||||||
|
cur_length, batch_size = query_layer.shape[0], query_layer.shape[1]
|
||||||
|
|
||||||
|
if self.multi_query_attention:
|
||||||
|
key_length = key_layer.size(0)
|
||||||
|
query_group_size = self.num_attention_heads_per_partition // \
|
||||||
|
self.num_multi_query_groups_per_partition
|
||||||
|
key_layer = key_layer.permute(1, 2, 0, 3).unsqueeze(-3) # [bs, nh/k, sl, hn]
|
||||||
|
key_layer = key_layer.expand(-1, -1, query_group_size, -1, -1)
|
||||||
|
key_layer = key_layer.contiguous().view((batch_size,
|
||||||
|
self.num_attention_heads_per_partition,
|
||||||
|
key_length,
|
||||||
|
self.hidden_size_per_attention_head))
|
||||||
|
value_layer = value_layer.permute(1, 2, 0, 3).unsqueeze(-3)
|
||||||
|
value_layer = value_layer.expand(-1, -1, query_group_size, -1, -1)
|
||||||
|
value_layer = value_layer.contiguous().view((batch_size,
|
||||||
|
self.num_attention_heads_per_partition,
|
||||||
|
key_length,
|
||||||
|
self.hidden_size_per_attention_head))
|
||||||
|
|
||||||
|
# adjust key and value for inference
|
||||||
|
if kv_cache is not None:
|
||||||
|
cache_k, cache_v = kv_cache
|
||||||
|
past_length = cache_k.size(2)
|
||||||
|
|
||||||
|
if past_length + cur_length > self.max_cache_length:
|
||||||
|
self.max_cache_length = past_length + cur_length + KV_CACHE_ALLOC_BLOCK_LENGTH
|
||||||
|
self.kv_cache = (torch.empty(batch_size,
|
||||||
|
self.num_attention_heads_per_partition,
|
||||||
|
self.max_cache_length,
|
||||||
|
self.hidden_size_per_attention_head,),
|
||||||
|
torch.empty(batch_size,
|
||||||
|
self.num_attention_heads_per_partition,
|
||||||
|
self.max_cache_length,
|
||||||
|
self.hidden_size_per_attention_head,))
|
||||||
|
self.kv_cache[0][:, :, :past_length, :] = cache_k
|
||||||
|
self.kv_cache[1][:, :, :past_length, :] = cache_v
|
||||||
|
self.kv_cache[0][:, :, past_length:past_length + cur_length, :] = key_layer
|
||||||
|
self.kv_cache[1][:, :, past_length:past_length + cur_length, :] = value_layer
|
||||||
|
|
||||||
|
key_layer = self.kv_cache[0][:, :, :past_length + cur_length, :]
|
||||||
|
value_layer = self.kv_cache[1][:, :, :past_length + cur_length, :]
|
||||||
|
|
||||||
|
elif use_cache:
|
||||||
|
self.max_cache_length = max(KV_CACHE_ALLOC_MIN_LENGTH, cur_length) \
|
||||||
|
+ KV_CACHE_ALLOC_BLOCK_LENGTH
|
||||||
|
self.kv_cache = (torch.empty(batch_size, self.num_attention_heads_per_partition,
|
||||||
|
self.max_cache_length, self.hidden_size_per_attention_head,),
|
||||||
|
torch.empty(batch_size, self.num_attention_heads_per_partition,
|
||||||
|
self.max_cache_length, self.hidden_size_per_attention_head,))
|
||||||
|
self.kv_cache[0][:, :, :cur_length, :] = key_layer
|
||||||
|
self.kv_cache[1][:, :, :cur_length, :] = value_layer
|
||||||
|
|
||||||
|
if use_cache:
|
||||||
|
kv_cache = (key_layer, value_layer)
|
||||||
|
else:
|
||||||
|
kv_cache = None
|
||||||
|
|
||||||
|
# ==================================
|
||||||
|
# core attention computation
|
||||||
|
# ==================================
|
||||||
|
|
||||||
|
context_layer = self.core_attention(query_layer, key_layer, value_layer, attention_mask)
|
||||||
|
|
||||||
|
# =================
|
||||||
|
# Output. [sq, b, h]
|
||||||
|
# =================
|
||||||
|
|
||||||
|
output = self.dense(context_layer)
|
||||||
|
|
||||||
|
return output, kv_cache
|
||||||
|
|
||||||
|
|
||||||
|
def core_attn_forward_8eb45c(self, query_layer, key_layer, value_layer, attention_mask):
|
||||||
|
pytorch_major_version = int(torch.__version__.split('.')[0])
|
||||||
|
if query_layer.size(0) > 1 and pytorch_major_version >= 2:
|
||||||
|
query_layer = query_layer.permute(1, 2, 0, 3)
|
||||||
|
if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]:
|
||||||
|
context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer,
|
||||||
|
key_layer,
|
||||||
|
value_layer,
|
||||||
|
is_causal=True)
|
||||||
|
else:
|
||||||
|
if attention_mask is not None:
|
||||||
|
attention_mask = ~attention_mask
|
||||||
|
context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer,
|
||||||
|
key_layer,
|
||||||
|
value_layer,
|
||||||
|
attention_mask)
|
||||||
|
context_layer = context_layer.permute(2, 0, 1, 3)
|
||||||
|
new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
|
||||||
|
context_layer = context_layer.reshape(*new_context_layer_shape)
|
||||||
|
else:
|
||||||
|
# Raw attention scores
|
||||||
|
|
||||||
|
# [b, np, sq, sk]
|
||||||
|
output_size = (query_layer.size(1), query_layer.size(2),
|
||||||
|
query_layer.size(0), key_layer.size(2))
|
||||||
|
|
||||||
|
# [sq, b, np, hn] -> [sq, b * np, hn]
|
||||||
|
query_layer = query_layer.view(output_size[2], output_size[0] * output_size[1], -1)
|
||||||
|
# [sk, b, np, hn] -> [sk, b * np, hn]
|
||||||
|
key_layer = key_layer.view(output_size[0] * output_size[1], output_size[3], -1)
|
||||||
|
|
||||||
|
# preallocting input tensor: [b * np, sq, sk]
|
||||||
|
matmul_input_buffer = torch.empty(
|
||||||
|
output_size[0] * output_size[1],
|
||||||
|
output_size[2], output_size[3], dtype=query_layer.dtype,
|
||||||
|
device=query_layer.device
|
||||||
|
)
|
||||||
|
|
||||||
|
# Raw attention scores. [b * np, sq, sk]
|
||||||
|
matmul_result = torch.baddbmm(
|
||||||
|
matmul_input_buffer,
|
||||||
|
query_layer.transpose(0, 1), # [b * np, sq, hn]
|
||||||
|
key_layer.transpose(1, 2), # [b * np, hn, sk]
|
||||||
|
beta=0.0,
|
||||||
|
alpha=(1.0 / self.norm_factor),
|
||||||
|
)
|
||||||
|
|
||||||
|
# change view to [b, np, sq, sk]
|
||||||
|
attention_scores = matmul_result.view(*output_size)
|
||||||
|
|
||||||
|
# ===========================
|
||||||
|
# Attention probs and dropout
|
||||||
|
# ===========================
|
||||||
|
|
||||||
|
# attention scores and attention mask [b, np, sq, sk]
|
||||||
|
if self.attention_softmax_in_fp32:
|
||||||
|
attention_scores = attention_scores.float()
|
||||||
|
if self.coeff is not None:
|
||||||
|
attention_scores = attention_scores * self.coeff
|
||||||
|
if attention_mask is None and attention_scores.shape[2] == attention_scores.shape[3]:
|
||||||
|
attention_mask = torch.ones(output_size[0], 1, output_size[2], output_size[3],
|
||||||
|
device=attention_scores.device, dtype=torch.bool)
|
||||||
|
attention_mask.tril_()
|
||||||
|
attention_mask = ~attention_mask
|
||||||
|
if attention_mask is not None:
|
||||||
|
attention_scores = attention_scores.masked_fill(attention_mask, float("-inf"))
|
||||||
|
attention_probs = F.softmax(attention_scores, dim=-1)
|
||||||
|
attention_probs = attention_probs.type_as(value_layer)
|
||||||
|
|
||||||
|
# This is actually dropping out entire tokens to attend to, which might
|
||||||
|
# seem a bit unusual, but is taken from the original Transformer paper.
|
||||||
|
attention_probs = self.attention_dropout(attention_probs)
|
||||||
|
# =========================
|
||||||
|
# Context layer. [sq, b, hp]
|
||||||
|
# =========================
|
||||||
|
|
||||||
|
# value_layer -> context layer.
|
||||||
|
# [sk, b, np, hn] --> [b, np, sq, hn]
|
||||||
|
|
||||||
|
# context layer shape: [b, np, sq, hn]
|
||||||
|
output_size = (value_layer.size(0), value_layer.size(1),
|
||||||
|
query_layer.size(0), value_layer.size(3))
|
||||||
|
# change view [sk, b * np, hn]
|
||||||
|
value_layer = value_layer.view(output_size[0] * output_size[1], value_layer.size(2), -1)
|
||||||
|
# change view [b * np, sq, sk]
|
||||||
|
attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1)
|
||||||
|
# matmul: [b * np, sq, hn]
|
||||||
|
context_layer = torch.bmm(attention_probs, value_layer)
|
||||||
|
# change view [b, np, sq, hn]
|
||||||
|
context_layer = context_layer.view(*output_size)
|
||||||
|
# [b, np, sq, hn] --> [sq, b, np, hn]
|
||||||
|
context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
|
||||||
|
# [sq, b, np, hn] --> [sq, b, hp]
|
||||||
|
new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
|
||||||
|
context_layer = context_layer.view(*new_context_layer_shape)
|
||||||
|
|
||||||
|
return context_layer
|
||||||
Loading…
Reference in a new issue