From b6468bac433baaafce4a3b93b821e64a95e9f97c Mon Sep 17 00:00:00 2001
From: Yang Wang
Date: Fri, 4 Aug 2023 08:56:24 +0800
Subject: [PATCH] optimize chatglm2 long sequence (#8662)

* add chatglm2
* optimize a little
* optimize chatglm long sequence
* fix style
* address comments and fix style
* fix bug
---
 .../llm/src/bigdl/llm/transformers/convert.py |  40 +++
 .../llm/src/bigdl/llm/transformers/model.py   |  26 +-
 .../bigdl/llm/transformers/models/chatglm.py  | 293 ++++++++++++++++++
 3 files changed, 335 insertions(+), 24 deletions(-)
 create mode 100644 python/llm/src/bigdl/llm/transformers/models/chatglm.py

diff --git a/python/llm/src/bigdl/llm/transformers/convert.py b/python/llm/src/bigdl/llm/transformers/convert.py
index df31020c..07fba239 100644
--- a/python/llm/src/bigdl/llm/transformers/convert.py
+++ b/python/llm/src/bigdl/llm/transformers/convert.py
@@ -40,6 +40,8 @@ import torch.nn as nn
 from accelerate import init_empty_weights
 from bigdl.llm.transformers.linear_quant import LinearQuant, ParamsQuant
 import warnings
+import transformers
+import importlib
 
 
 def _replace_with_quant_linear(model, qtype, modules_to_not_convert=None,
@@ -124,4 +126,42 @@ def ggml_convert_quant(model, qtype, convert_shape_only=False):
         )
     else:
         model.to(torch.float32)
+
+    model = optimize(model)
     return model
+
+
+def convert_forward(m, target_m, new_forward):
+    for _, sub_m in m.named_children():
+        if isinstance(sub_m, target_m):
+            bound_method = new_forward.__get__(sub_m, sub_m.__class__)
+            setattr(sub_m, "forward", bound_method)
+        convert_forward(sub_m, target_m, new_forward)
+
+
+def optimize(model):
+    from packaging import version
+    from bigdl.llm.transformers.models.llama import llama_attention_forward_4_31
+    trans_version = transformers.__version__
+    if version.parse(trans_version) >= version.parse("4.31.0"):
+        convert_forward(
+            model,
+            transformers.models.llama.modeling_llama.LlamaAttention,
+            llama_attention_forward_4_31,)
+    else:
+        # todo implement 4.28.0 ~ 4.30.2
+        pass
+
+    if "chatglm2" in model.config._name_or_path:
+        modeling_module_name = model.__class__.__module__
+        module = importlib.import_module(modeling_module_name)
+        from bigdl.llm.transformers.models.chatglm import chatglm_attention_forward_8eb45c
+        from bigdl.llm.transformers.models.chatglm import core_attn_forward_8eb45c
+        convert_forward(model,
+                        module.SelfAttention,
+                        chatglm_attention_forward_8eb45c
+                        )
+        convert_forward(model,
+                        module.CoreAttention,
+                        core_attn_forward_8eb45c)
+    return model

diff --git a/python/llm/src/bigdl/llm/transformers/model.py b/python/llm/src/bigdl/llm/transformers/model.py
index d417203b..8c198fe6 100644
--- a/python/llm/src/bigdl/llm/transformers/model.py
+++ b/python/llm/src/bigdl/llm/transformers/model.py
@@ -24,6 +24,8 @@ from .utils import extract_local_archive_file, \
     fix_key
 from bigdl.llm.ggml.quantize import ggml_tensor_qtype
 from bigdl.llm.utils.common import invalidInputError, MuteHFLogger
+import sys
+import importlib
 
 
 def save_low_bit(self, *args, **kwargs):
@@ -33,14 +35,6 @@ def save_low_bit(self, *args, **kwargs):
     self.save_pretrained(*args, **kwargs)
 
 
-def convert_forward(m, target_m, new_forward):
-    for _, sub_m in m.named_children():
-        if isinstance(sub_m, target_m):
-            bound_method = new_forward.__get__(sub_m, sub_m.__class__)
-            setattr(sub_m, "forward", bound_method)
-        convert_forward(sub_m, target_m, new_forward)
-
-
 class _BaseAutoModelClass:
     HF_MODEL = None
 
@@ -91,20 +85,6 @@ class _BaseAutoModelClass:
 
         return model
 
-    @classmethod
-    def optimize(cls, model):
-        from packaging import version
-        from bigdl.llm.transformers.models.llama import llama_attention_forward_4_31
-        trans_version = transformers.__version__
-        if version.parse(trans_version) >= version.parse("4.31.0"):
-            convert_forward(
-                model,
-                transformers.models.llama.modeling_llama.LlamaAttention,
-                llama_attention_forward_4_31,)
-        else:
-            # todo implement 4.28.0 ~ 4.30.2
-            pass
-
     @classmethod
     def load_convert(cls, q_k, *args, **kwargs):
         from .convert import ggml_convert_quant
@@ -117,8 +97,6 @@ class _BaseAutoModelClass:
         model = ggml_convert_quant(model, qtype)
         model.config.update({"bigdl_transformers_low_bit": q_k})
 
-        cls.optimize(model)
-
         # add save_low_bit to pretrained model dynamically
         import types
         model.save_low_bit = types.MethodType(save_low_bit, model)
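
Note: convert_forward above (and the identical helper removed from model.py)
replaces each matching sub-module's forward by binding the new function to
the instance through the descriptor protocol, so `self` resolves to that
sub-module. A minimal standalone sketch of the pattern (Dummy and
fast_forward are illustrative names, not part of this patch):

    import torch
    import torch.nn as nn

    class Dummy(nn.Module):
        def forward(self, x):
            return x + 1

    def fast_forward(self, x):
        # replacement forward; `self` is the Dummy instance it was bound to
        return x + 2

    m = nn.Sequential(Dummy(), Dummy())
    for _, sub_m in m.named_children():
        if isinstance(sub_m, Dummy):
            # __get__ produces a bound method; assigning it to the instance
            # shadows the class-level forward for this sub-module only
            setattr(sub_m, "forward", fast_forward.__get__(sub_m, sub_m.__class__))

    print(m(torch.zeros(1)))  # tensor([4.]) -- both replaced forwards ran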
diff --git a/python/llm/src/bigdl/llm/transformers/models/chatglm.py b/python/llm/src/bigdl/llm/transformers/models/chatglm.py
new file mode 100644
index 00000000..d8a55f19
--- /dev/null
+++ b/python/llm/src/bigdl/llm/transformers/models/chatglm.py
@@ -0,0 +1,293 @@
+#
+# Copyright 2016 The BigDL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# This file is adapted from
+# https://huggingface.co/THUDM/chatglm2-6b/blob/8eb45c842594b8473f291d0f94e7bbe86ffc67d8/modeling_chatglm.py
+#
+
+import torch
+from typing import Optional, Tuple, Union, List, Callable, Dict, Any
+import torch.nn.functional as F
+
+
+KV_CACHE_ALLOC_BLOCK_LENGTH = 256
+KV_CACHE_ALLOC_MIN_LENGTH = 512
+
+
+def split_tensor_along_last_dim(
+        tensor: torch.Tensor,
+        num_partitions: int,
+        contiguous_split_chunks: bool = False,
+) -> List[torch.Tensor]:
+    """Split a tensor along its last dimension.
+
+    Arguments:
+        tensor: input tensor.
+        num_partitions: number of partitions to split the tensor
+        contiguous_split_chunks: If True, make each chunk contiguous
+                                 in memory.
+
+    Returns:
+        A list of Tensors
+    """
+    # Get the size and dimension.
+    last_dim = tensor.dim() - 1
+    last_dim_size = tensor.size()[last_dim] // num_partitions
+    # Split.
+    tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
+    # Note: torch.split does not create contiguous tensors by default.
+    if contiguous_split_chunks:
+        return tuple(chunk.contiguous() for chunk in tensor_list)
+
+    return tensor_list
+
+
+@torch.jit.script
+def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:
+    # x: [sq, b, np, hn]
+    sq, b, np, hn = x.size(0), x.size(1), x.size(2), x.size(3)
+    rot_dim = rope_cache.shape[-2] * 2
+    x, x_pass = x[..., :rot_dim], x[..., rot_dim:]
+    # truncate to support variable sizes
+    rope_cache = rope_cache[:sq]
+    xshaped = x.reshape(sq, -1, np, rot_dim // 2, 2)
+    rope_cache = rope_cache.view(sq, -1, 1, xshaped.size(3), 2)
+    x_out2 = torch.stack(
+        [
+            xshaped[..., 0] * rope_cache[..., 0] - xshaped[..., 1] * rope_cache[..., 1],
+            xshaped[..., 1] * rope_cache[..., 0] + xshaped[..., 0] * rope_cache[..., 1],
+        ],
+        -1,
+    )
+    x_out2 = x_out2.flatten(3)
+    return torch.cat((x_out2, x_pass), dim=-1)
+
+
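+# Illustrative shapes for ChatGLM2-6B (hn=128 head dim, np=32 query heads,
+# ng=2 key/value groups): query_key_value maps [sq, b, 4096] to
+# [sq, b, 32*128 + 2*2*128] = [sq, b, 4608]; the multi-query branch below
+# splits this into a [sq, b, 32, 128] query and [sq, b, 2, 128] key/value.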
+def chatglm_attention_forward_8eb45c(
+        self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True
+):
+    # hidden_states: [sq, b, h]
+
+    # =================================================
+    # Pre-allocate memory for key-values for inference.
+    # =================================================
+    # =====================
+    # Query, Key, and Value
+    # =====================
+
+    # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
+    mixed_x_layer = self.query_key_value(hidden_states)
+
+    if self.multi_query_attention:
+        (query_layer, key_layer, value_layer) = mixed_x_layer.split(
+            [
+                self.num_attention_heads_per_partition * self.hidden_size_per_attention_head,
+                self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head,
+                self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head,
+            ],
+            dim=-1,
+        )
+        query_layer = query_layer.view(
+            query_layer.size()[:-1] + (self.num_attention_heads_per_partition,
+                                       self.hidden_size_per_attention_head)
+        )
+        key_layer = key_layer.view(
+            key_layer.size()[:-1] + (self.num_multi_query_groups_per_partition,
+                                     self.hidden_size_per_attention_head)
+        )
+        value_layer = value_layer.view(
+            value_layer.size()[:-1]
+            + (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head)
+        )
+    else:
+        new_tensor_shape = mixed_x_layer.size()[:-1] + (self.num_attention_heads_per_partition,
+                                                        3 * self.hidden_size_per_attention_head)
+        mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
+
+        # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
+        (query_layer, key_layer, value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3)
+
+    # apply relative positional encoding (rotary embedding)
+    if rotary_pos_emb is not None:
+        query_layer = apply_rotary_pos_emb(query_layer, rotary_pos_emb)
+        key_layer = apply_rotary_pos_emb(key_layer, rotary_pos_emb)
+
+    cur_length, batch_size = query_layer.shape[0], query_layer.shape[1]
+
+    if self.multi_query_attention:
+        key_length = key_layer.size(0)
+        query_group_size = self.num_attention_heads_per_partition // \
+            self.num_multi_query_groups_per_partition
+        key_layer = key_layer.permute(1, 2, 0, 3).unsqueeze(-3)  # [b, ng, sk, hn] -> [b, ng, 1, sk, hn]
+        key_layer = key_layer.expand(-1, -1, query_group_size, -1, -1)
+        key_layer = key_layer.contiguous().view((batch_size,
+                                                 self.num_attention_heads_per_partition,
+                                                 key_length,
+                                                 self.hidden_size_per_attention_head))
+        value_layer = value_layer.permute(1, 2, 0, 3).unsqueeze(-3)
+        value_layer = value_layer.expand(-1, -1, query_group_size, -1, -1)
+        value_layer = value_layer.contiguous().view((batch_size,
+                                                     self.num_attention_heads_per_partition,
+                                                     key_length,
+                                                     self.hidden_size_per_attention_head))
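+
+    # The preallocated cache below is the long-sequence optimization of this
+    # patch: rather than torch.cat-ing past and current key/values on every
+    # generated token (as the upstream modeling_chatglm.py does), a buffer
+    # with KV_CACHE_ALLOC_BLOCK_LENGTH spare positions is reserved once and
+    # new entries are written in place, reallocating only when it runs out.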
+
+    # adjust key and value for inference
+    if kv_cache is not None:
+        cache_k, cache_v = kv_cache
+        past_length = cache_k.size(2)
+
+        if past_length + cur_length > self.max_cache_length:
+            self.max_cache_length = past_length + cur_length + KV_CACHE_ALLOC_BLOCK_LENGTH
+            self.kv_cache = (torch.empty(batch_size,
+                                         self.num_attention_heads_per_partition,
+                                         self.max_cache_length,
+                                         self.hidden_size_per_attention_head,),
+                             torch.empty(batch_size,
+                                         self.num_attention_heads_per_partition,
+                                         self.max_cache_length,
+                                         self.hidden_size_per_attention_head,))
+            self.kv_cache[0][:, :, :past_length, :] = cache_k
+            self.kv_cache[1][:, :, :past_length, :] = cache_v
+        self.kv_cache[0][:, :, past_length:past_length + cur_length, :] = key_layer
+        self.kv_cache[1][:, :, past_length:past_length + cur_length, :] = value_layer
+
+        key_layer = self.kv_cache[0][:, :, :past_length + cur_length, :]
+        value_layer = self.kv_cache[1][:, :, :past_length + cur_length, :]
+
+    elif use_cache:
+        self.max_cache_length = max(KV_CACHE_ALLOC_MIN_LENGTH, cur_length) \
+            + KV_CACHE_ALLOC_BLOCK_LENGTH
+        self.kv_cache = (torch.empty(batch_size, self.num_attention_heads_per_partition,
+                                     self.max_cache_length, self.hidden_size_per_attention_head,),
+                         torch.empty(batch_size, self.num_attention_heads_per_partition,
+                                     self.max_cache_length, self.hidden_size_per_attention_head,))
+        self.kv_cache[0][:, :, :cur_length, :] = key_layer
+        self.kv_cache[1][:, :, :cur_length, :] = value_layer
+
+    if use_cache:
+        kv_cache = (key_layer, value_layer)
+    else:
+        kv_cache = None
+
+    # ==================================
+    # core attention computation
+    # ==================================
+
+    context_layer = self.core_attention(query_layer, key_layer, value_layer, attention_mask)
+
+    # =================
+    # Output. [sq, b, h]
+    # =================
+
+    output = self.dense(context_layer)
+
+    return output, kv_cache
+
+
+def core_attn_forward_8eb45c(self, query_layer, key_layer, value_layer, attention_mask):
+    pytorch_major_version = int(torch.__version__.split('.')[0])
+    if query_layer.size(0) > 1 and pytorch_major_version >= 2:
+        query_layer = query_layer.permute(1, 2, 0, 3)
+        if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]:
+            context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer,
+                                                                             key_layer,
+                                                                             value_layer,
+                                                                             is_causal=True)
+        else:
+            if attention_mask is not None:
+                attention_mask = ~attention_mask
+            context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer,
+                                                                             key_layer,
+                                                                             value_layer,
+                                                                             attention_mask)
+        context_layer = context_layer.permute(2, 0, 1, 3)
+        new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
+        context_layer = context_layer.reshape(*new_context_layer_shape)
+    else:
+        # Raw attention scores
+
+        # [b, np, sq, sk]
+        output_size = (query_layer.size(1), query_layer.size(2),
+                       query_layer.size(0), key_layer.size(2))
+
+        # [sq, b, np, hn] -> [sq, b * np, hn]
+        query_layer = query_layer.view(output_size[2], output_size[0] * output_size[1], -1)
+        # [b, np, sk, hn] -> [b * np, sk, hn]
+        key_layer = key_layer.view(output_size[0] * output_size[1], output_size[3], -1)
+
+        # preallocating input tensor: [b * np, sq, sk]
+        matmul_input_buffer = torch.empty(
+            output_size[0] * output_size[1],
+            output_size[2], output_size[3], dtype=query_layer.dtype,
+            device=query_layer.device
+        )
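+        # With beta=0.0, torch.baddbmm never reads the uninitialized contents
+        # of matmul_input_buffer; the preallocated tensor only pins down the
+        # output's shape, dtype and device.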
+
+        # Raw attention scores. [b * np, sq, sk]
+        matmul_result = torch.baddbmm(
+            matmul_input_buffer,
+            query_layer.transpose(0, 1),  # [b * np, sq, hn]
+            key_layer.transpose(1, 2),  # [b * np, hn, sk]
+            beta=0.0,
+            alpha=(1.0 / self.norm_factor),
+        )
+
+        # change view to [b, np, sq, sk]
+        attention_scores = matmul_result.view(*output_size)
+
+        # ===========================
+        # Attention probs and dropout
+        # ===========================
+
+        # attention scores and attention mask [b, np, sq, sk]
+        if self.attention_softmax_in_fp32:
+            attention_scores = attention_scores.float()
+        if self.coeff is not None:
+            attention_scores = attention_scores * self.coeff
+        if attention_mask is None and attention_scores.shape[2] == attention_scores.shape[3]:
+            attention_mask = torch.ones(output_size[0], 1, output_size[2], output_size[3],
+                                        device=attention_scores.device, dtype=torch.bool)
+            attention_mask.tril_()
+            attention_mask = ~attention_mask
+        if attention_mask is not None:
+            attention_scores = attention_scores.masked_fill(attention_mask, float("-inf"))
+        attention_probs = F.softmax(attention_scores, dim=-1)
+        attention_probs = attention_probs.type_as(value_layer)
+
+        # This is actually dropping out entire tokens to attend to, which might
+        # seem a bit unusual, but is taken from the original Transformer paper.
+        attention_probs = self.attention_dropout(attention_probs)
+
+        # =========================
+        # Context layer. [sq, b, hp]
+        # =========================
+
+        # value_layer -> context layer.
+        # [b, np, sk, hn] --> [b, np, sq, hn]
+
+        # context layer shape: [b, np, sq, hn]
+        output_size = (value_layer.size(0), value_layer.size(1),
+                       query_layer.size(0), value_layer.size(3))
+        # change view [b * np, sk, hn]
+        value_layer = value_layer.view(output_size[0] * output_size[1], value_layer.size(2), -1)
+        # change view [b * np, sq, sk]
+        attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1)
+        # matmul: [b * np, sq, hn]
+        context_layer = torch.bmm(attention_probs, value_layer)
+        # change view [b, np, sq, hn]
+        context_layer = context_layer.view(*output_size)
+        # [b, np, sq, hn] --> [sq, b, np, hn]
+        context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
+        # [sq, b, np, hn] --> [sq, b, hp]
+        new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
+        context_layer = context_layer.view(*new_context_layer_shape)
+
+    return context_layer
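
Usage note: the patched forwards are applied automatically inside
ggml_convert_quant via optimize(model), so no user-facing API changes are
needed. A minimal sketch, assuming the bigdl-llm transformers-style API of
this release (model id, prompt, and generation arguments are illustrative):

    from transformers import AutoTokenizer
    from bigdl.llm.transformers import AutoModel

    model = AutoModel.from_pretrained("THUDM/chatglm2-6b",
                                      load_in_4bit=True,
                                      trust_remote_code=True)
    tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm2-6b",
                                              trust_remote_code=True)

    inputs = tokenizer("What is AI?", return_tensors="pt")
    output = model.generate(**inputs, max_new_tokens=64)
    print(tokenizer.decode(output[0], skip_special_tokens=True))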