diff --git a/python/llm/src/bigdl/llm/transformers/convert.py b/python/llm/src/bigdl/llm/transformers/convert.py
index daf4588b..f0e20881 100644
--- a/python/llm/src/bigdl/llm/transformers/convert.py
+++ b/python/llm/src/bigdl/llm/transformers/convert.py
@@ -173,4 +173,24 @@ def optimize(model):
                             chatglm_attention_forward
                             )
+    elif model.config.model_type == "baichuan" and model.config.vocab_size == 125696:
+        # baichuan2
+        modeling_module_name = model.__class__.__module__
+        module = importlib.import_module(modeling_module_name)
+        from bigdl.llm.transformers.models.baichuan2 import baichuan_attention_forward
+        convert_forward(model,
+                        module.Attention,
+                        baichuan_attention_forward
+                        )
+
+    elif model.config.model_type == "baichuan":
+        # baichuan1
+        modeling_module_name = model.__class__.__module__
+        module = importlib.import_module(modeling_module_name)
+        from bigdl.llm.transformers.models.baichuan import baichuan_attention_forward
+        convert_forward(model,
+                        module.Attention,
+                        baichuan_attention_forward
+                        )
+
     return model
diff --git a/python/llm/src/bigdl/llm/transformers/models/baichuan.py b/python/llm/src/bigdl/llm/transformers/models/baichuan.py
new file mode 100644
index 00000000..fb3f55d1
--- /dev/null
+++ b/python/llm/src/bigdl/llm/transformers/models/baichuan.py
@@ -0,0 +1,135 @@
+#
+# Copyright 2016 The BigDL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# This file is adapted from
+# https://huggingface.co/baichuan-inc/Baichuan-7B/blob/c1a5c7d5b7f50ecc51bb0e08150a9f12e5656756/modeling_baichuan.py
+
+
+import math
+from typing import List, Optional, Tuple, Union
+import torch
+import torch.utils.checkpoint
+from torch import nn
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+from bigdl.llm.utils.common import invalidInputError
+from bigdl.llm.transformers.models.utils import create_kv_cache, append_kv_cache
+from bigdl.llm.transformers.models.utils import rotate_half, apply_rotary_pos_emb
+
+KV_CACHE_ALLOC_BLOCK_LENGTH = 256
+
+
+def baichuan_attention_forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        output_attentions: bool = False,
+        use_cache: bool = False,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+    bsz, q_len, _ = hidden_states.size()
+    device = hidden_states.device
+
+    proj = self.W_pack(hidden_states)
+    proj = proj.unflatten(-1, (3, self.hidden_size)).unsqueeze(0).transpose(0, -2).squeeze(-2)
+    # batch_size x source_len x hidden_size
+    query_states = proj[0].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+    # batch_size x target_len x head_size
+    key_states = proj[1].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+    # batch_size x source_len x hidden_size
+    value_states = proj[2].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+
+    kv_seq_len = key_states.shape[-2]
+    if past_key_value is not None:
+        kv_seq_len += past_key_value[0].shape[-2]
+    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+    query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
+                                                    cos, sin, position_ids, "baichuan")
+    # [bsz, nh, t, hd]
+
+    # if past_key_value is not None:
+    #     # reuse k, v, self_attention
+    #     key_states = torch.cat([past_key_value[0], key_states], dim=2)
+    #     value_states = torch.cat([past_key_value[1], value_states], dim=2)
+    if past_key_value is not None:
+        # reuse k, v, self_attention
+        cache_k = past_key_value[0]
+        cache_v = past_key_value[1]
+        if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
+            # allocate new
+            new_cache_k, new_cache_v = create_kv_cache(bsz,
+                                                       self.num_heads,
+                                                       self.head_dim,
+                                                       cache_k.size(2),
+                                                       kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH,
+                                                       dtype=cache_k.dtype,
+                                                       device=device)
+            new_cache_k[:] = cache_k
+            new_cache_v[:] = cache_v
+            cache_k = new_cache_k
+            cache_v = new_cache_v
+
+        key_states, value_states = append_kv_cache(cache_k, cache_v, key_states, value_states)
+
+    elif use_cache:
+        max_cache_length = kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH
+        new_key_states, new_value_states = create_kv_cache(bsz,
+                                                           self.num_heads,
+                                                           self.head_dim,
+                                                           kv_seq_len,
+                                                           max_cache_length,
+                                                           dtype=key_states.dtype,
+                                                           device=device)
+        new_key_states[:] = key_states
+        new_value_states[:] = value_states
+        key_states = new_key_states
+        value_states = new_value_states
+
+    past_key_value = (key_states, value_states) if use_cache else None
+
+    attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+
+    if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
+        invalidInputError(False,
+                          f"Attention weights should be of size "
+                          f"{(bsz, self.num_heads, q_len, kv_seq_len)}"
+                          f", but is {attn_weights.size()}")
+
+    if attention_mask is not None:
+        invalidInputError(attention_mask.size() == (bsz, 1, q_len, kv_seq_len),
+                          f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, "
+                          f"but is {attention_mask.size()}")
+        attn_weights = attn_weights + attention_mask
+        attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
+
+    # upcast attention to fp32
+    attn_weights = nn.functional.softmax(attn_weights, dim=-1,
+                                         dtype=torch.float32).to(query_states.dtype)
+    attn_output = torch.matmul(attn_weights, value_states)
+
+    invalidInputError(attn_output.size() == (bsz, self.num_heads, q_len, self.head_dim),
+                      f"`attn_output` should be of size "
+                      f"{(bsz, self.num_heads, q_len, self.head_dim)},"
+                      f"but is {attn_output.size()}")
+
+    attn_output = attn_output.transpose(1, 2)
+    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+
+    attn_output = self.o_proj(attn_output)
+
+    if not output_attentions:
+        attn_weights = None
+
+    return attn_output, attn_weights, past_key_value
diff --git a/python/llm/src/bigdl/llm/transformers/models/baichuan2.py b/python/llm/src/bigdl/llm/transformers/models/baichuan2.py
new file mode 100644
index 00000000..4bc0410a
--- /dev/null
+++ b/python/llm/src/bigdl/llm/transformers/models/baichuan2.py
@@ -0,0 +1,135 @@
+#
+# Copyright 2016 The BigDL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# This file is adapted from
+# https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/cb7fc748b78b7ea99772e4cf76db155729ce774e/modeling_baichuan.py
+
+
+import math
+from typing import List, Optional, Tuple, Union
+import torch
+import torch.utils.checkpoint
+from torch import nn
+from torch.nn import functional as F
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+from bigdl.llm.utils.common import invalidInputError
+from bigdl.llm.transformers.models.utils import create_kv_cache, append_kv_cache
+from bigdl.llm.transformers.models.utils import rotate_half, apply_rotary_pos_emb
+from transformers.utils import logging, ContextManagers
+logger = logging.get_logger(__name__)
+
+try:
+    from xformers import ops as xops
+except ImportError:
+    xops = None
+    logger.warning(
+        "Xformers is not installed correctly. If you want to use memory_efficient_attention to "
+        "accelerate training use the following command to install Xformers\npip install xformers."
+    )
+
+
+KV_CACHE_ALLOC_BLOCK_LENGTH = 256
+
+
+def baichuan_attention_forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        output_attentions: bool = False,
+        use_cache: bool = False,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+    bsz, q_len, _ = hidden_states.size()
+    device = hidden_states.device
+
+    proj = self.W_pack(hidden_states)
+    proj = proj.unflatten(-1, (3, self.hidden_size)).unsqueeze(0).transpose(0, -2).squeeze(-2)
+    # batch_size x source_len x hidden_size
+    query_states = proj[0].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+    # batch_size x target_len x head_size
+    key_states = proj[1].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+    # batch_size x source_len x hidden_size
+    value_states = proj[2].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+
+    kv_seq_len = key_states.shape[-2]
+    if past_key_value is not None:
+        kv_seq_len += past_key_value[0].shape[-2]
+    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+    query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
+                                                    cos, sin, position_ids, "baichuan")
+    # [bsz, nh, t, hd]
+
+    # if past_key_value is not None:
+    #     # reuse k, v, self_attention
+    #     key_states = torch.cat([past_key_value[0], key_states], dim=2)
+    #     value_states = torch.cat([past_key_value[1], value_states], dim=2)
+    if past_key_value is not None:
+        # reuse k, v, self_attention
+        cache_k = past_key_value[0]
+        cache_v = past_key_value[1]
+        if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
+            # allocate new
+            new_cache_k, new_cache_v = create_kv_cache(bsz,
+                                                       self.num_heads,
+                                                       self.head_dim,
+                                                       cache_k.size(2),
+                                                       kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH,
+                                                       dtype=cache_k.dtype,
+                                                       device=device)
+            new_cache_k[:] = cache_k
+            new_cache_v[:] = cache_v
+            cache_k = new_cache_k
+            cache_v = new_cache_v
+
+        key_states, value_states = append_kv_cache(cache_k, cache_v, key_states, value_states)
+
+    elif use_cache:
+        max_cache_length = kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH
+        new_key_states, new_value_states = create_kv_cache(bsz,
+                                                           self.num_heads,
+                                                           self.head_dim,
+                                                           kv_seq_len,
+                                                           max_cache_length,
+                                                           dtype=key_states.dtype,
+                                                           device=device)
+        new_key_states[:] = key_states
+        new_value_states[:] = value_states
+        key_states = new_key_states
+        value_states = new_value_states
+
+    past_key_value = (key_states, value_states) if use_cache else None
+
+    if xops is not None and self.training:
+        attn_weights = None
+        query_states = query_states.transpose(1, 2)
+        key_states = key_states.transpose(1, 2)
+        value_states = value_states.transpose(1, 2)
+        attn_output = xops.memory_efficient_attention(
+            query_states, key_states, value_states, attn_bias=xops.LowerTriangularMask()
+        )
+    else:
+        with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True,
+                                            enable_mem_efficient=True):
+            attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states,
+                                                         attn_mask=attention_mask)
+        attn_output = attn_output.transpose(1, 2)
+    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+    attn_output = self.o_proj(attn_output)
+
+    if not output_attentions:
+        attn_weights = None
+
+    return attn_output, attn_weights, past_key_value
diff --git a/python/llm/src/bigdl/llm/transformers/models/chatglm.py b/python/llm/src/bigdl/llm/transformers/models/chatglm.py
index 89525697..6c1a0a8a 100644
--- a/python/llm/src/bigdl/llm/transformers/models/chatglm.py
+++ b/python/llm/src/bigdl/llm/transformers/models/chatglm.py
@@ -67,8 +67,6 @@ def attention_fn(
         cache_v = cache_v.permute(1, 2, 0, 3)
         past_length = cache_k.size(2)
         if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
-            if device.type == 'xpu':
-                torch.xpu.empty_cache()
             max_cache_length = past_length + cur_length + KV_CACHE_ALLOC_BLOCK_LENGTH
             new_cache_k, new_cache_v = create_kv_cache(batch_size,
                                                        self.num_attention_heads_per_partition,
diff --git a/python/llm/src/bigdl/llm/transformers/models/chatglm2.py b/python/llm/src/bigdl/llm/transformers/models/chatglm2.py
index 5de558e9..d43452cb 100644
--- a/python/llm/src/bigdl/llm/transformers/models/chatglm2.py
+++ b/python/llm/src/bigdl/llm/transformers/models/chatglm2.py
@@ -151,8 +151,6 @@ def chatglm2_attention_forward_8eb45c(
         past_length = cache_k.size(2)

         if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
-            if device.type == 'xpu':
-                torch.xpu.empty_cache()
             max_cache_length = past_length + cur_length + KV_CACHE_ALLOC_BLOCK_LENGTH
             new_cache_k, new_cache_v = create_kv_cache(batch_size,
                                                        self.num_attention_heads_per_partition,
diff --git a/python/llm/src/bigdl/llm/transformers/models/llama.py b/python/llm/src/bigdl/llm/transformers/models/llama.py
index 415ca4e0..212abc2a 100644
--- a/python/llm/src/bigdl/llm/transformers/models/llama.py
+++ b/python/llm/src/bigdl/llm/transformers/models/llama.py
@@ -38,24 +38,7 @@ import math
 import torch.nn.functional as F
 from bigdl.llm.utils.common import invalidInputError
 from bigdl.llm.transformers.models.utils import create_kv_cache, append_kv_cache
-
-
-def rotate_half(x):
-    """Rotates half the hidden dims of the input."""
-    x1 = x[..., :x.shape[-1] // 2]
-    x2 = x[..., x.shape[-1] // 2:]
-    return torch.cat((-x2, x1), dim=-1)
-
-
-def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
-    # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
-    cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
-    sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
-    cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
-    sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
-    q_embed = (q * cos) + (rotate_half(q) * sin)
-    k_embed = (k * cos) + (rotate_half(k) * sin)
-    return q_embed, k_embed
+from bigdl.llm.transformers.models.utils import rotate_half, apply_rotary_pos_emb


 def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
@@ -122,15 +105,13 @@ def llama_attention_forward_4_31(
         kv_seq_len += past_key_value[0].shape[-2]
     cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
     query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
-                                                    cos, sin, position_ids)
+                                                    cos, sin, position_ids, "llama")

     if past_key_value is not None:
         # reuse k, v, self_attention
         cache_k = past_key_value[0]
         cache_v = past_key_value[1]
         if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
-            if device.type == 'xpu':
-                torch.xpu.empty_cache()
             # allocate new
             new_cache_k, new_cache_v = create_kv_cache(bsz,
                                                        self.num_key_value_heads,  # Support GQA
diff --git a/python/llm/src/bigdl/llm/transformers/models/utils.py b/python/llm/src/bigdl/llm/transformers/models/utils.py
index 58765e2a..6f2464c6 100644
--- a/python/llm/src/bigdl/llm/transformers/models/utils.py
+++ b/python/llm/src/bigdl/llm/transformers/models/utils.py
@@ -15,9 +15,12 @@
 #

 import torch
+from bigdl.llm.utils.common import invalidInputError


 def create_kv_cache(batch_size, num_heads, head_dim, current_length, max_length, dtype, device):
+    if device.type == 'xpu':
+        torch.xpu.empty_cache()
     key_cache_storage = torch.empty(batch_size, num_heads, max_length,
                                     head_dim, dtype=dtype, device=device)
@@ -46,3 +49,25 @@ def append_kv_cache(cache_k, cache_v, key_states, value_states):
     new_cache_v = cache_v.as_strided(new_size, cache_v.stride(), storage_offset=0)
     new_cache_v[:, :, cache_v.size(2):cache_k.size(2) + key_states.size(2), :] = value_states
     return new_cache_k, new_cache_v
+
+
+def rotate_half(x):
+    """Rotates half the hidden dims of the input."""
+    x1 = x[..., :x.shape[-1] // 2]
+    x2 = x[..., x.shape[-1] // 2:]
+    return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids, model_family):
+    if model_family in ["llama", "baichuan"]:
+        # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
+        cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
+        sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
+        cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
+        sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
+        q_embed = (q * cos) + (rotate_half(q) * sin)
+        k_embed = (k * cos) + (rotate_half(k) * sin)
+        return q_embed, k_embed
+    else:
+        invalidInputError(False,
+                          f"{model_family} is not supported.")
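Note on the `create_kv_cache`/`append_kv_cache` pattern the new forwards rely on: the cache is allocated `KV_CACHE_ALLOC_BLOCK_LENGTH` tokens larger than currently needed and exposed as a narrowed `as_strided` view, so `cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3)` becomes true exactly when the pre-allocated buffer is full and a bigger block must be created. A simplified, standalone sketch of that pattern (hypothetical helper names, not the BigDL implementation itself):

```python
import torch

KV_CACHE_ALLOC_BLOCK_LENGTH = 256  # same constant used by the patched forwards


def make_kv_cache(bsz, n_heads, head_dim, current_length, max_length, dtype, device):
    # Hypothetical stand-in for create_kv_cache: allocate max_length rows up front
    # and hand back views narrowed to current_length.
    k_storage = torch.empty(bsz, n_heads, max_length, head_dim, dtype=dtype, device=device)
    v_storage = torch.empty(bsz, n_heads, max_length, head_dim, dtype=dtype, device=device)
    k = k_storage.as_strided((bsz, n_heads, current_length, head_dim), k_storage.stride())
    v = v_storage.as_strided((bsz, n_heads, current_length, head_dim), v_storage.stride())
    return k, v


def append_kv(cache_k, cache_v, key_states, value_states):
    # Hypothetical stand-in for append_kv_cache: widen the view and copy the new
    # tokens into the tail of the pre-allocated storage, with no reallocation.
    new_len = cache_k.size(2) + key_states.size(2)
    new_size = (cache_k.size(0), cache_k.size(1), new_len, cache_k.size(3))
    new_k = cache_k.as_strided(new_size, cache_k.stride())
    new_v = cache_v.as_strided(new_size, cache_v.stride())
    new_k[:, :, cache_k.size(2):new_len, :] = key_states
    new_v[:, :, cache_v.size(2):new_len, :] = value_states
    return new_k, new_v


k, v = make_kv_cache(1, 32, 128, current_length=0,
                     max_length=KV_CACHE_ALLOC_BLOCK_LENGTH,
                     dtype=torch.float32, device=torch.device("cpu"))
for _ in range(3):  # decode a few tokens
    step = torch.randn(1, 32, 1, 128)
    k, v = append_kv(k, v, step, step)

# The "cache is full, allocate a bigger block" test used in the forwards:
# stride(1) is max_length * head_dim, size(2) * size(3) is current_length * head_dim,
# so the condition only fires once current_length has caught up with max_length.
print(k.stride()[1] <= k.size(2) * k.size(3))  # False: 253 slots still free
```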
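With these changes, the Baichuan paths are selected by `optimize(model)` during bigdl-llm's low-bit conversion, keyed on `model_type == "baichuan"` and the vocab size. A minimal usage sketch, assuming the standard bigdl-llm 4-bit loading flow; the model id, prompt, and generation settings are placeholders:

```python
# Sketch: load a Baichuan2 checkpoint through bigdl-llm so that the low-bit
# conversion step applies optimize(model) and swaps module.Attention.forward
# for the patched baichuan_attention_forward above.
import torch
from transformers import AutoTokenizer
from bigdl.llm.transformers import AutoModelForCausalLM

model_path = "baichuan-inc/Baichuan2-7B-Chat"  # placeholder checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_path,
                                             load_in_4bit=True,
                                             trust_remote_code=True)

with torch.inference_mode():
    input_ids = tokenizer("What is AI?", return_tensors="pt").input_ids
    output = model.generate(input_ids, max_new_tokens=32)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```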