diff --git a/python/llm/src/bigdl/llm/transformers/convert.py b/python/llm/src/bigdl/llm/transformers/convert.py index f416f161..3c45d8b1 100644 --- a/python/llm/src/bigdl/llm/transformers/convert.py +++ b/python/llm/src/bigdl/llm/transformers/convert.py @@ -287,4 +287,12 @@ def optimize(model): module.InternLMAttention, internlm_attention_forward ) + elif model.config.model_type == "aquila": + modeling_module_name = model.__class__.__module__ + module = importlib.import_module(modeling_module_name) + from bigdl.llm.transformers.models.aquila import aquila_attention_forward + convert_forward(model, + module.AquilaAttention, + aquila_attention_forward + ) return model diff --git a/python/llm/src/bigdl/llm/transformers/models/aquila.py b/python/llm/src/bigdl/llm/transformers/models/aquila.py new file mode 100644 index 00000000..84abb6b8 --- /dev/null +++ b/python/llm/src/bigdl/llm/transformers/models/aquila.py @@ -0,0 +1,157 @@ +# +# Copyright 2016 The BigDL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Some parts of this file is adapted from +# https://huggingface.co/BAAI/AquilaChat-7B/blob/main/modeling_aquila.py +# +# Copyright 2023 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import List, Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn + +from bigdl.llm.transformers.models.utils import extend_kv_cache, init_kv_cache, append_kv_cache +from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb +from bigdl.dllib.utils import log4Error + +KV_CACHE_ALLOC_BLOCK_LENGTH = 256 + + +def aquila_attention_forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states)\ + .view(bsz, q_len, self.num_heads, self.head_dim)\ + .transpose(1, 2) + key_states = self.k_proj(hidden_states)\ + .view(bsz, q_len, self.num_heads, self.head_dim)\ + .transpose(1, 2) + value_states = self.v_proj(hidden_states)\ + .view(bsz, q_len, self.num_heads, self.head_dim)\ + .transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, + cos, sin, position_ids, "aquila") + # [bsz, nh, t, hd] + + if past_key_value is not None: + # reuse k, v, self_attention + cache_k = past_key_value[0] + cache_v = past_key_value[1] + if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3): + # allocate new + new_cache_k, new_cache_v = extend_kv_cache(bsz, + self.num_heads, # Support GQA + self.head_dim, + cache_k.size(2), + kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH, + dtype=cache_k.dtype, + device=hidden_states.device) + new_cache_k[:] = cache_k + new_cache_v[:] = cache_v + cache_k = new_cache_k + cache_v = new_cache_v + + key_states, value_states = append_kv_cache(cache_k, cache_v, key_states, value_states) + + elif use_cache: + max_cache_length = kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH + new_key_states, new_value_states = init_kv_cache(bsz, + self.num_heads, + self.head_dim, + kv_seq_len, + max_cache_length, + dtype=key_states.dtype, + device=hidden_states.device) + new_key_states[:] = key_states + new_value_states[:] = value_states + key_states = new_key_states + value_states = new_value_states + + past_key_value = (key_states, value_states) if use_cache else None + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + attn_weights = torch.clamp(attn_weights, min=-1024., max=1024.) + if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): + log4Error.invalidInputError( + f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, " + f"but is {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + log4Error.invalidInputError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, " + f"but is {attention_mask.size()}" + ) + attn_weights = attn_weights + attention_mask + attn_weights = torch.max( + attn_weights, + torch.tensor(torch.finfo(attn_weights.dtype).min, device=attn_weights.device) + ) + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32)\ + .to(query_states.dtype) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + log4Error.invalidInputError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, " + f"but is {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value diff --git a/python/llm/src/bigdl/llm/transformers/models/utils.py b/python/llm/src/bigdl/llm/transformers/models/utils.py index 4489b268..f4f1372b 100644 --- a/python/llm/src/bigdl/llm/transformers/models/utils.py +++ b/python/llm/src/bigdl/llm/transformers/models/utils.py @@ -71,7 +71,7 @@ def rotate_every_two(x): def apply_rotary_pos_emb(q, k, cos, sin, position_ids, model_family): - if model_family in ["llama", "baichuan", "internlm"]: + if model_family in ["llama", "baichuan", "internlm", "aquila"]: # The first two dimensions of cos and sin are always 1, so we can `squeeze` them. cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]