[WIP] LLM: add kv cache support for internlm. (#9036)

* LLM: add kv cache support for internlm

* add internlm apply_rotary_pos_emb

* fix.

* fix style.
This commit is contained in:
Cengguang Zhang 2023-09-25 14:16:59 +08:00 committed by GitHub
parent 975da86e00
commit b4a1266ef0
3 changed files with 177 additions and 1 deletions

View file

@ -272,4 +272,12 @@ def optimize(model):
transformers.models.gpt_neox.modeling_gpt_neox.GPTNeoXAttention, transformers.models.gpt_neox.modeling_gpt_neox.GPTNeoXAttention,
gptneox_attention_forward gptneox_attention_forward
) )
elif model.config.model_type == "internlm":
modeling_module_name = model.__class__.__module__
module = importlib.import_module(modeling_module_name)
from bigdl.llm.transformers.models.internlm import internlm_attention_forward
convert_forward(model,
module.InternLMAttention,
internlm_attention_forward
)
return model return model

View file

@ -0,0 +1,168 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Some parts of this file is adapted from
# https://huggingface.co/internlm/internlm-chat-7b/blob/659ed911eec1e26810f9854f19c5ec27854e9cf3/modeling_internlm.py
# which is licensed under Apache License 2.0:
#
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch InternLM model."""
import math
from typing import Optional, Tuple
import torch
import torch.utils.checkpoint
from torch import nn
from bigdl.llm.utils.common import invalidInputError
from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb
KV_CACHE_ALLOC_BLOCK_LENGTH = 256
def internlm_attention_forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor]=None,
position_ids: Optional[torch.LongTensor]=None,
past_key_value: Optional[Tuple[torch.Tensor]]=None,
output_attentions: bool=False,
use_cache: bool=False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()
device = hidden_states.device
query_states = self.q_proj(hidden_states) \
.view(bsz, q_len, self.num_heads, self.head_dim) \
.transpose(1, 2)
key_states = self.k_proj(hidden_states) \
.view(bsz, q_len, self.num_heads, self.head_dim) \
.transpose(1, 2)
value_states = self.v_proj(hidden_states) \
.view(bsz, q_len, self.num_heads, self.head_dim) \
.transpose(1, 2)
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(
query_states,
key_states,
cos,
sin,
position_ids,
"internlm"
)
# [bsz, nh, t, hd]
if past_key_value is not None:
# reuse k, v, self_attention
cache_k = past_key_value[0]
cache_v = past_key_value[1]
if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
# allocate new
new_cache_k, new_cache_v = extend_kv_cache(
bsz,
self.num_heads,
self.head_dim,
cache_k.size(2),
kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH,
dtype=cache_k.dtype,
device=device
)
new_cache_k[:] = cache_k
new_cache_v[:] = cache_v
cache_k = new_cache_k
cache_v = new_cache_v
key_states, value_states = append_kv_cache(cache_k, cache_v, key_states, value_states)
elif use_cache:
max_cache_length = kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH
new_key_states, new_value_states = init_kv_cache(
bsz,
self.num_heads,
self.head_dim,
kv_seq_len,
max_cache_length,
dtype=key_states.dtype,
device=device
)
new_key_states[:] = key_states
new_value_states[:] = value_states
key_states = new_key_states
value_states = new_value_states
past_key_value = (key_states, value_states) if use_cache else None
attn_weights = torch.matmul(query_states,
key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
invalidInputError(
False,
f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, "
f"but is {attn_weights.size()}"
)
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
invalidInputError(
False,
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, "
f"but is {attention_mask.size()}"
)
attn_weights = attn_weights + attention_mask
attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights,
dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_output = torch.matmul(attn_weights, value_states)
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
invalidInputError(
False,
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, "
f"but is {attn_output.size()}"
)
attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
attn_output = self.o_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value

View file

@ -71,7 +71,7 @@ def rotate_every_two(x):
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, model_family): def apply_rotary_pos_emb(q, k, cos, sin, position_ids, model_family):
if model_family in ["llama", "baichuan"]: if model_family in ["llama", "baichuan", "internlm"]:
# The first two dimensions of cos and sin are always 1, so we can `squeeze` them. # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
sin = sin.squeeze(1).squeeze(0) # [seq_len, dim] sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]