From b4a1266ef093d67cb7ea20e02ec6702f30159920 Mon Sep 17 00:00:00 2001 From: Cengguang Zhang Date: Mon, 25 Sep 2023 14:16:59 +0800 Subject: [PATCH] [WIP] LLM: add kv cache support for internlm. (#9036) * LLM: add kv cache support for internlm * add internlm apply_rotary_pos_emb * fix. * fix style. --- .../llm/src/bigdl/llm/transformers/convert.py | 8 + .../bigdl/llm/transformers/models/internlm.py | 168 ++++++++++++++++++ .../bigdl/llm/transformers/models/utils.py | 2 +- 3 files changed, 177 insertions(+), 1 deletion(-) create mode 100644 python/llm/src/bigdl/llm/transformers/models/internlm.py diff --git a/python/llm/src/bigdl/llm/transformers/convert.py b/python/llm/src/bigdl/llm/transformers/convert.py index 518c8a5d..07f929c3 100644 --- a/python/llm/src/bigdl/llm/transformers/convert.py +++ b/python/llm/src/bigdl/llm/transformers/convert.py @@ -272,4 +272,12 @@ def optimize(model): transformers.models.gpt_neox.modeling_gpt_neox.GPTNeoXAttention, gptneox_attention_forward ) + elif model.config.model_type == "internlm": + modeling_module_name = model.__class__.__module__ + module = importlib.import_module(modeling_module_name) + from bigdl.llm.transformers.models.internlm import internlm_attention_forward + convert_forward(model, + module.InternLMAttention, + internlm_attention_forward + ) return model diff --git a/python/llm/src/bigdl/llm/transformers/models/internlm.py b/python/llm/src/bigdl/llm/transformers/models/internlm.py new file mode 100644 index 00000000..7f3b7f7d --- /dev/null +++ b/python/llm/src/bigdl/llm/transformers/models/internlm.py @@ -0,0 +1,168 @@ +# +# Copyright 2016 The BigDL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Some parts of this file is adapted from +# https://huggingface.co/internlm/internlm-chat-7b/blob/659ed911eec1e26810f9854f19c5ec27854e9cf3/modeling_internlm.py +# which is licensed under Apache License 2.0: +# +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch InternLM model.""" +import math +from typing import Optional, Tuple + +import torch +import torch.utils.checkpoint +from torch import nn +from bigdl.llm.utils.common import invalidInputError +from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache +from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb + + +KV_CACHE_ALLOC_BLOCK_LENGTH = 256 + + +def internlm_attention_forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor]=None, + position_ids: Optional[torch.LongTensor]=None, + past_key_value: Optional[Tuple[torch.Tensor]]=None, + output_attentions: bool=False, + use_cache: bool=False, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + device = hidden_states.device + query_states = self.q_proj(hidden_states) \ + .view(bsz, q_len, self.num_heads, self.head_dim) \ + .transpose(1, 2) + key_states = self.k_proj(hidden_states) \ + .view(bsz, q_len, self.num_heads, self.head_dim) \ + .transpose(1, 2) + value_states = self.v_proj(hidden_states) \ + .view(bsz, q_len, self.num_heads, self.head_dim) \ + .transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb( + query_states, + key_states, + cos, + sin, + position_ids, + "internlm" + ) + # [bsz, nh, t, hd] + + if past_key_value is not None: + # reuse k, v, self_attention + cache_k = past_key_value[0] + cache_v = past_key_value[1] + if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3): + # allocate new + new_cache_k, new_cache_v = extend_kv_cache( + bsz, + self.num_heads, + self.head_dim, + cache_k.size(2), + kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH, + dtype=cache_k.dtype, + device=device + ) + new_cache_k[:] = cache_k + new_cache_v[:] = cache_v + cache_k = new_cache_k + cache_v = new_cache_v + + key_states, value_states = append_kv_cache(cache_k, cache_v, key_states, value_states) + + elif use_cache: + max_cache_length = kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH + new_key_states, new_value_states = init_kv_cache( + bsz, + self.num_heads, + self.head_dim, + kv_seq_len, + max_cache_length, + dtype=key_states.dtype, + device=device + ) + new_key_states[:] = key_states + new_value_states[:] = value_states + key_states = new_key_states + value_states = new_value_states + + past_key_value = (key_states, value_states) if use_cache else None + + attn_weights = torch.matmul(query_states, + key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): + invalidInputError( + False, + f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, " + f"but is {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + invalidInputError( + False, + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, " + f"but is {attention_mask.size()}" + ) + attn_weights = attn_weights + attention_mask + attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)) + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, + dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + invalidInputError( + False, + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, " + f"but is {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value diff --git a/python/llm/src/bigdl/llm/transformers/models/utils.py b/python/llm/src/bigdl/llm/transformers/models/utils.py index b47ad8e7..4489b268 100644 --- a/python/llm/src/bigdl/llm/transformers/models/utils.py +++ b/python/llm/src/bigdl/llm/transformers/models/utils.py @@ -71,7 +71,7 @@ def rotate_every_two(x): def apply_rotary_pos_emb(q, k, cos, sin, position_ids, model_family): - if model_family in ["llama", "baichuan"]: + if model_family in ["llama", "baichuan", "internlm"]: # The first two dimensions of cos and sin are always 1, so we can `squeeze` them. cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]