diff --git a/python/llm/src/bigdl/llm/transformers/model.py b/python/llm/src/bigdl/llm/transformers/model.py
index 29670b1d..5cd679cf 100644
--- a/python/llm/src/bigdl/llm/transformers/model.py
+++ b/python/llm/src/bigdl/llm/transformers/model.py
@@ -33,6 +33,14 @@ def save_low_bit(self, *args, **kwargs):
     self.save_pretrained(*args, **kwargs)
 
 
+def convert_forward(m, target_m, new_forward):
+    for _, sub_m in m.named_children():
+        if isinstance(sub_m, target_m):
+            bound_method = new_forward.__get__(sub_m, sub_m.__class__)
+            setattr(sub_m, "forward", bound_method)
+        convert_forward(sub_m, target_m, new_forward)
+
+
 class _BaseAutoModelClass:
 
     HF_MODEL = None
@@ -80,6 +88,20 @@ class _BaseAutoModelClass:
 
         return model
 
+    @classmethod
+    def optimize(cls, model):
+        from packaging import version
+        from bigdl.llm.transformers.models.llama import llama_attention_forward_4_31
+        trans_version = transformers.__version__
+        if version.parse(trans_version) >= version.parse("4.31.0"):
+            convert_forward(
+                model,
+                transformers.models.llama.modeling_llama.LlamaAttention,
+                llama_attention_forward_4_31)
+        else:
+            # TODO: implement for transformers 4.28.0 ~ 4.30.2
+            pass
+
     @classmethod
     def load_convert(cls, q_k, *args, **kwargs):
         from .convert import ggml_convert_quant
@@ -92,6 +114,8 @@ class _BaseAutoModelClass:
         model = ggml_convert_quant(model, qtype)
         model.config.update({"bigdl_transformers_low_bit": q_k})
 
+        cls.optimize(model)
+
         # add save_low_bit to pretrained model dynamically
         import types
         model.save_low_bit = types.MethodType(save_low_bit, model)
diff --git a/python/llm/src/bigdl/llm/transformers/models/__init__.py b/python/llm/src/bigdl/llm/transformers/models/__init__.py
new file mode 100644
index 00000000..9c4dd6d6
--- /dev/null
+++ b/python/llm/src/bigdl/llm/transformers/models/__init__.py
@@ -0,0 +1,14 @@
+#
+# Copyright 2016 The BigDL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/python/llm/src/bigdl/llm/transformers/models/llama.py b/python/llm/src/bigdl/llm/transformers/models/llama.py
new file mode 100644
index 00000000..674db7e0
--- /dev/null
+++ b/python/llm/src/bigdl/llm/transformers/models/llama.py
@@ -0,0 +1,202 @@
+#
+# Copyright 2016 The BigDL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Some parts of this file are adapted from
+# https://github.com/huggingface/transformers/blob/v4.31.0/src/transformers/models/llama/modeling_llama.py
+# which is licensed under Apache License 2.0:
+#
+# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.nn as nn +from typing import Optional, Tuple +import math +import torch.nn.functional as F +from bigdl.llm.utils.common import invalidInputError + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., :x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2:] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids): + # The first two dimensions of cos and sin are always 1, so we can `squeeze` them. + cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] + sin = sin.squeeze(1).squeeze(0) # [seq_len, dim] + cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] + sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states + go from (batch, num_key_value_heads, seqlen, head_dim) to + (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, + n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + +KV_CACHE_ALLOC_BLOCK_LENGTH = 256 + + +def llama_attention_forward_4_31( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + if self.pretraining_tp > 1: + key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.pretraining_tp + query_slices = self.q_proj.weight.split((self.num_heads * self.head_dim) + // self.pretraining_tp, dim=0) + key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) + value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) + + query_states = [F.linear(hidden_states, query_slices[i]) + for i in range(self.pretraining_tp)] + query_states = torch.cat(query_states, dim=-1) + + key_states = [F.linear(hidden_states, key_slices[i]) + for i in range(self.pretraining_tp)] + key_states = torch.cat(key_states, dim=-1) + + value_states = [F.linear(hidden_states, value_slices[i]) + for i in range(self.pretraining_tp)] + value_states = torch.cat(value_states, dim=-1) + + else: + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, + self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, + 
self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, + self.num_key_value_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, + cos, sin, position_ids) + + if past_key_value is not None: + # reuse k, v, self_attention + # key_states = torch.cat([past_key_value[0], key_states], dim=2) + # value_states = torch.cat([past_key_value[1], value_states], dim=2) + if kv_seq_len > self.max_cache_length: + new_cache_key = torch.empty(bsz, self.num_heads, + kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH, self.head_dim) + new_cache_key[:, :, :kv_seq_len-1, :] = self.kv_cache[0][:, :, :kv_seq_len-1, :] + self.kv_cache[0] = new_cache_key + + new_cache_value = torch.empty(bsz, self.num_heads, + kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH, self.head_dim) + new_cache_value[:, :, :kv_seq_len-1, :] = self.kv_cache[1][:, :, :kv_seq_len-1, :] + self.kv_cache[1] = new_cache_value + self.max_cache_length = kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH + + self.kv_cache[0][:, :, kv_seq_len-1:kv_seq_len, :] = key_states + self.kv_cache[1][:, :, kv_seq_len-1:kv_seq_len, :] = value_states + key_states = self.kv_cache[0][:, :, :kv_seq_len, :] + value_states = self.kv_cache[1][:, :, :kv_seq_len, :] + elif use_cache: + # first token case + self.max_cache_length = max(min(self.max_position_embeddings, 2 * kv_seq_len), + kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH) + self.kv_cache = (torch.empty(bsz, self.num_heads, self.max_cache_length, self.head_dim), + torch.empty(bsz, self.num_heads, self.max_cache_length, self.head_dim)) + self.kv_cache[0][:, :, :kv_seq_len, :] = key_states + self.kv_cache[1][:, :, :kv_seq_len, :] = value_states + + past_key_value = (key_states, value_states) if use_cache else None + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, + key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + attn_weights_size = (bsz, self.num_heads, q_len, kv_seq_len) + if attn_weights.size() != attn_weights_size: + invalidInputError(False, + f"Attention weights should be of size {attn_weights_size}, " + f"but is {attn_weights.size()}") + + if attention_mask is not None: + attn_mask_size = (bsz, 1, q_len, kv_seq_len) + if attention_mask.size() != attn_mask_size: + invalidInputError(False, + f"Attention mask should be of size {attn_mask_size}, " + f"but is {attention_mask.size()}") + attn_weights = attn_weights + attention_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, + dtype=torch.float32).to(query_states.dtype) + attn_output = torch.matmul(attn_weights, value_states) + + attn_output_size = (bsz, self.num_heads, q_len, self.head_dim) + if attn_output.size() != attn_output_size: + invalidInputError(False, + f"`attn_output` should be of size {attn_output_size}," + f" but is {attn_output.size()}") + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + if self.pretraining_tp > 1: + attn_output = attn_output.split(self.hidden_size // self.pretraining_tp, dim=2) + o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.pretraining_tp, dim=1) + 
attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) + for i in range(self.pretraining_tp)]) + else: + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value
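
Reviewer note: the new convert_forward helper in model.py patches modules in place by rebinding the replacement function through the descriptor protocol (new_forward.__get__(sub_m, sub_m.__class__)), so only LlamaAttention instances receive the optimized forward and every other layer keeps its stock implementation. A minimal, self-contained sketch of that binding pattern outside the patch (the toy Dummy module and patched_forward below are illustrative only, not part of this change):

import torch
import torch.nn as nn

class Dummy(nn.Module):
    def forward(self, x):
        return x

def patched_forward(self, x):
    # `self` is the Dummy instance: __get__ turns the plain function into a
    # method bound to that instance, the same trick convert_forward applies
    # to LlamaAttention.forward.
    return x + 1

m = Dummy()
# equivalent to: setattr(sub_m, "forward", new_forward.__get__(sub_m, sub_m.__class__))
m.forward = patched_forward.__get__(m, m.__class__)
print(m(torch.zeros(2)))  # tensor([1., 1.]) -- nn.Module.__call__ now dispatches to the patched forward

Because the binding is per instance, nothing is monkey-patched at the class level: other Llama models created in the same process keep the stock forward unless optimize() is called on them.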
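
The repeat_kv helper mirrors the grouped-query attention support in transformers 4.31: key/value projections carry num_key_value_heads heads and are expanded to num_attention_heads before the attention matmul. A quick shape check of that expansion (the head counts below are example values, not any model's config):

import torch
from bigdl.llm.transformers.models.llama import repeat_kv  # added by this patch

# (batch, num_key_value_heads, seq_len, head_dim)
k = torch.randn(1, 8, 16, 128)
# 8 kv heads shared across 32 query heads -> n_rep = 4
k_expanded = repeat_kv(k, n_rep=4)
print(k_expanded.shape)  # torch.Size([1, 32, 16, 128])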
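
The main behavioral change in llama_attention_forward_4_31 is the kv cache handling: the stock 4.31 forward torch.cat's one new key/value slice per generated token, reallocating and copying the whole cache every step, whereas this version pre-allocates a buffer of length max(min(max_position_embeddings, 2 * kv_seq_len), kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH) on the first token and then writes each new position into a slice, growing the buffer by another KV_CACHE_ALLOC_BLOCK_LENGTH only when it fills up. A stripped-down sketch of that allocation strategy outside the attention module (the function name, shapes, and the omission of the re-grow branch are mine, not the patch's API):

import torch

KV_CACHE_ALLOC_BLOCK_LENGTH = 256

def toy_kv_append(cache, new_k, new_v, kv_seq_len, max_position_embeddings=4096):
    if cache is None:
        # first token: allocate room for the prompt plus a block of future tokens,
        # mirroring the `elif use_cache` branch in the patch
        bsz, heads, _, head_dim = new_k.shape
        max_len = max(min(max_position_embeddings, 2 * kv_seq_len),
                      kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH)
        cache = (torch.empty(bsz, heads, max_len, head_dim, dtype=new_k.dtype),
                 torch.empty(bsz, heads, max_len, head_dim, dtype=new_v.dtype))
        cache[0][:, :, :kv_seq_len, :] = new_k
        cache[1][:, :, :kv_seq_len, :] = new_v
    else:
        # decoding: write the single new position in place -- no torch.cat and
        # no reallocation until max_len is exhausted (re-grow branch omitted here)
        cache[0][:, :, kv_seq_len - 1:kv_seq_len, :] = new_k
        cache[1][:, :, kv_seq_len - 1:kv_seq_len, :] = new_v
    return cache, cache[0][:, :, :kv_seq_len, :], cache[1][:, :, :kv_seq_len, :]

The trade-off is holding up to KV_CACHE_ALLOC_BLOCK_LENGTH extra positions of memory per layer in exchange for removing the per-token copy of the whole cache.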