LLM: Support optimized kv_cache for baichuan family (#8997)

* add initial support for baichuan attention

* support baichuan1

* update based on comment

* update based on comment

* support baichuan2

* update link, change how to judge baichuan2

* fix style

* add model parameter for pos emb

* update based on comment
Ruonan Wang 2023-09-19 15:38:54 +08:00 committed by GitHub
parent 37bb0cbf8f
commit 004c45c2be
7 changed files with 317 additions and 25 deletions

View file

@@ -173,4 +173,24 @@ def optimize(model):
                         chatglm_attention_forward
                         )
+    elif model.config.model_type == "baichuan" and model.config.vocab_size == 125696:
+        # baichuan2
+        modeling_module_name = model.__class__.__module__
+        module = importlib.import_module(modeling_module_name)
+        from bigdl.llm.transformers.models.baichuan2 import baichuan_attention_forward
+        convert_forward(model,
+                        module.Attention,
+                        baichuan_attention_forward
+                        )
+    elif model.config.model_type == "baichuan":
+        # baichuan1
+        modeling_module_name = model.__class__.__module__
+        module = importlib.import_module(modeling_module_name)
+        from bigdl.llm.transformers.models.baichuan import baichuan_attention_forward
+        convert_forward(model,
+                        module.Attention,
+                        baichuan_attention_forward
+                        )
     return model
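For reference, a minimal usage sketch (not part of this commit; the checkpoint name, prompt, and generation settings are placeholders) showing how the dispatch above is reached: loading a Baichuan checkpoint through bigdl-llm's transformers-style API runs the optimize() pass, which swaps in baichuan_attention_forward based on model_type, with vocab_size == 125696 distinguishing Baichuan2 from Baichuan1.

# Minimal usage sketch, assuming bigdl-llm's transformers-style loading API;
# the checkpoint name and prompt below are illustrative placeholders.
import torch
from transformers import AutoTokenizer
from bigdl.llm.transformers import AutoModelForCausalLM

model_path = "baichuan-inc/Baichuan2-7B-Chat"
# load_in_4bit=True runs the conversion/optimization pass above, so the
# model's Attention.forward is replaced by baichuan_attention_forward.
model = AutoModelForCausalLM.from_pretrained(model_path,
                                             load_in_4bit=True,
                                             trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

input_ids = tokenizer("What is AI?", return_tensors="pt").input_ids
with torch.inference_mode():
    output = model.generate(input_ids, max_new_tokens=32, use_cache=True)
print(tokenizer.decode(output[0], skip_special_tokens=True))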

View file

@@ -0,0 +1,135 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is adapted from
# https://huggingface.co/baichuan-inc/Baichuan-7B/blob/c1a5c7d5b7f50ecc51bb0e08150a9f12e5656756/modeling_baichuan.py
import math
from typing import List, Optional, Tuple, Union
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from bigdl.llm.utils.common import invalidInputError
from bigdl.llm.transformers.models.utils import create_kv_cache, append_kv_cache
from bigdl.llm.transformers.models.utils import rotate_half, apply_rotary_pos_emb
KV_CACHE_ALLOC_BLOCK_LENGTH = 256
def baichuan_attention_forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()
device = hidden_states.device
proj = self.W_pack(hidden_states)
proj = proj.unflatten(-1, (3, self.hidden_size)).unsqueeze(0).transpose(0, -2).squeeze(-2)
# batch_size x source_len x hidden_size
query_states = proj[0].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
# batch_size x target_len x head_size
key_states = proj[1].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
# batch_size x source_len x hidden_size
value_states = proj[2].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
cos, sin, position_ids, "baichuan")
# [bsz, nh, t, hd]
# if past_key_value is not None:
# # reuse k, v, self_attention
# key_states = torch.cat([past_key_value[0], key_states], dim=2)
# value_states = torch.cat([past_key_value[1], value_states], dim=2)
if past_key_value is not None:
# reuse k, v, self_attention
cache_k = past_key_value[0]
cache_v = past_key_value[1]
if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
# allocate new
new_cache_k, new_cache_v = create_kv_cache(bsz,
self.num_heads,
self.head_dim,
cache_k.size(2),
kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH,
dtype=cache_k.dtype,
device=device)
new_cache_k[:] = cache_k
new_cache_v[:] = cache_v
cache_k = new_cache_k
cache_v = new_cache_v
key_states, value_states = append_kv_cache(cache_k, cache_v, key_states, value_states)
elif use_cache:
max_cache_length = kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH
new_key_states, new_value_states = create_kv_cache(bsz,
self.num_heads,
self.head_dim,
kv_seq_len,
max_cache_length,
dtype=key_states.dtype,
device=device)
new_key_states[:] = key_states
new_value_states[:] = value_states
key_states = new_key_states
value_states = new_value_states
past_key_value = (key_states, value_states) if use_cache else None
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
invalidInputError(False,
f"Attention weights should be of size "
f"{(bsz, self.num_heads, q_len, kv_seq_len)}"
f", but is {attn_weights.size()}")
if attention_mask is not None:
invalidInputError(attention_mask.size() == (bsz, 1, q_len, kv_seq_len),
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, "
f"but is {attention_mask.size()}")
attn_weights = attn_weights + attention_mask
attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1,
dtype=torch.float32).to(query_states.dtype)
attn_output = torch.matmul(attn_weights, value_states)
invalidInputError(attn_output.size() == (bsz, self.num_heads, q_len, self.head_dim),
f"`attn_output` should be of size "
f"{(bsz, self.num_heads, q_len, self.head_dim)},"
f"but is {attn_output.size()}")
attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
attn_output = self.o_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value
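The gain over the stock implementation comes from how the cache grows: instead of torch.cat-ing the new key/value onto the old cache every decoding step (which reallocates and copies the whole cache), a buffer with KV_CACHE_ALLOC_BLOCK_LENGTH spare token slots is allocated up front and new tokens are written into it in place; a larger buffer is only allocated once the spare room runs out. Below is a standalone sketch of that idea; the helper names and bodies reflect the assumed semantics of create_kv_cache/append_kv_cache, not the actual bigdl-llm implementations.

# Standalone sketch of pre-allocated KV-cache growth (assumed semantics only).
import torch

ALLOC_BLOCK = 256  # mirrors KV_CACHE_ALLOC_BLOCK_LENGTH


def alloc_cache(bsz, heads, head_dim, cur_len, max_len, dtype, device):
    # Reserve max_len token slots, but expose only the first cur_len as a view.
    storage = torch.empty(bsz, heads, max_len, head_dim, dtype=dtype, device=device)
    return storage.as_strided((bsz, heads, cur_len, head_dim),
                              storage.stride(), storage_offset=0)


def append(cache, new_states):
    # Grow the view by new_states.size(2) tokens and copy them in, no realloc.
    bsz, heads, cur_len, head_dim = cache.shape
    new_len = cur_len + new_states.size(2)
    grown = cache.as_strided((bsz, heads, new_len, head_dim),
                             cache.stride(), storage_offset=0)
    grown[:, :, cur_len:new_len, :] = new_states
    return grown


k_cache = alloc_cache(1, 32, 128, cur_len=0, max_len=ALLOC_BLOCK,
                      dtype=torch.float32, device=torch.device("cpu"))
for _ in range(8):  # decode 8 tokens without a single reallocation
    k_cache = append(k_cache, torch.randn(1, 32, 1, 128))
print(k_cache.shape)  # torch.Size([1, 32, 8, 128])

Read this way, the check cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3) in the forward simply asks whether the current view already fills its underlying block, i.e. whether any spare slots remain before a bigger buffer must be created.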

View file

@@ -0,0 +1,135 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is adapted from
# https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/cb7fc748b78b7ea99772e4cf76db155729ce774e/modeling_baichuan.py
import math
from typing import List, Optional, Tuple, Union
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import functional as F
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from bigdl.llm.utils.common import invalidInputError
from bigdl.llm.transformers.models.utils import create_kv_cache, append_kv_cache
from bigdl.llm.transformers.models.utils import rotate_half, apply_rotary_pos_emb
from transformers.utils import logging, ContextManagers
logger = logging.get_logger(__name__)
try:
from xformers import ops as xops
except ImportError:
xops = None
logger.warning(
"Xformers is not installed correctly. If you want to use memory_efficient_attention to "
"accelerate training use the following command to install Xformers\npip install xformers."
)
KV_CACHE_ALLOC_BLOCK_LENGTH = 256
def baichuan_attention_forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()
device = hidden_states.device
proj = self.W_pack(hidden_states)
proj = proj.unflatten(-1, (3, self.hidden_size)).unsqueeze(0).transpose(0, -2).squeeze(-2)
# batch_size x source_len x hidden_size
query_states = proj[0].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
# batch_size x target_len x head_size
key_states = proj[1].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
# batch_size x source_len x hidden_size
value_states = proj[2].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
cos, sin, position_ids, "baichuan")
# [bsz, nh, t, hd]
# if past_key_value is not None:
# # reuse k, v, self_attention
# key_states = torch.cat([past_key_value[0], key_states], dim=2)
# value_states = torch.cat([past_key_value[1], value_states], dim=2)
if past_key_value is not None:
# reuse k, v, self_attention
cache_k = past_key_value[0]
cache_v = past_key_value[1]
if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
# allocate new
new_cache_k, new_cache_v = create_kv_cache(bsz,
self.num_heads,
self.head_dim,
cache_k.size(2),
kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH,
dtype=cache_k.dtype,
device=device)
new_cache_k[:] = cache_k
new_cache_v[:] = cache_v
cache_k = new_cache_k
cache_v = new_cache_v
key_states, value_states = append_kv_cache(cache_k, cache_v, key_states, value_states)
elif use_cache:
max_cache_length = kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH
new_key_states, new_value_states = create_kv_cache(bsz,
self.num_heads,
self.head_dim,
kv_seq_len,
max_cache_length,
dtype=key_states.dtype,
device=device)
new_key_states[:] = key_states
new_value_states[:] = value_states
key_states = new_key_states
value_states = new_value_states
past_key_value = (key_states, value_states) if use_cache else None
if xops is not None and self.training:
attn_weights = None
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
attn_output = xops.memory_efficient_attention(
query_states, key_states, value_states, attn_bias=xops.LowerTriangularMask()
)
else:
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True,
enable_mem_efficient=True):
attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states,
attn_mask=attention_mask)
attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
attn_output = self.o_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value
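Unlike the Baichuan1 forward above, the inference path here hands the softmax(QKᵀ/√d)·V computation to torch's fused F.scaled_dot_product_attention instead of materializing the attention matrix by hand. A small self-contained equivalence check (shapes and tensors are arbitrary, not taken from this commit):

# Sanity sketch: F.scaled_dot_product_attention matches the manual
# softmax(Q @ K^T / sqrt(d)) @ V computation used in the Baichuan1 path.
import math
import torch
import torch.nn.functional as F

bsz, heads, q_len, kv_len, head_dim = 1, 4, 5, 5, 16
q = torch.randn(bsz, heads, q_len, head_dim)
k = torch.randn(bsz, heads, kv_len, head_dim)
v = torch.randn(bsz, heads, kv_len, head_dim)
mask = torch.zeros(bsz, 1, q_len, kv_len)  # additive mask, all positions visible

manual = torch.softmax(q @ k.transpose(2, 3) / math.sqrt(head_dim) + mask, dim=-1) @ v
fused = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
print(torch.allclose(manual, fused, atol=1e-5))  # True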

View file

@@ -67,8 +67,6 @@ def attention_fn(
         cache_v = cache_v.permute(1, 2, 0, 3)
         past_length = cache_k.size(2)
         if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
-            if device.type == 'xpu':
-                torch.xpu.empty_cache()
             max_cache_length = past_length + cur_length + KV_CACHE_ALLOC_BLOCK_LENGTH
             new_cache_k, new_cache_v = create_kv_cache(batch_size,
                                                        self.num_attention_heads_per_partition,

View file

@@ -151,8 +151,6 @@ def chatglm2_attention_forward_8eb45c(
         past_length = cache_k.size(2)
         if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
-            if device.type == 'xpu':
-                torch.xpu.empty_cache()
             max_cache_length = past_length + cur_length + KV_CACHE_ALLOC_BLOCK_LENGTH
             new_cache_k, new_cache_v = create_kv_cache(batch_size,
                                                        self.num_attention_heads_per_partition,

View file

@@ -38,24 +38,7 @@ import math
 import torch.nn.functional as F
 from bigdl.llm.utils.common import invalidInputError
 from bigdl.llm.transformers.models.utils import create_kv_cache, append_kv_cache
+from bigdl.llm.transformers.models.utils import rotate_half, apply_rotary_pos_emb
-def rotate_half(x):
-    """Rotates half the hidden dims of the input."""
-    x1 = x[..., :x.shape[-1] // 2]
-    x2 = x[..., x.shape[-1] // 2:]
-    return torch.cat((-x2, x1), dim=-1)
-def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
-    # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
-    cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
-    sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
-    cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
-    sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
-    q_embed = (q * cos) + (rotate_half(q) * sin)
-    k_embed = (k * cos) + (rotate_half(k) * sin)
-    return q_embed, k_embed

 def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
@@ -122,15 +105,13 @@ def llama_attention_forward_4_31(
         kv_seq_len += past_key_value[0].shape[-2]
     cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
     query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
-                                                    cos, sin, position_ids)
+                                                    cos, sin, position_ids, "llama")

     if past_key_value is not None:
         # reuse k, v, self_attention
         cache_k = past_key_value[0]
         cache_v = past_key_value[1]
         if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
-            if device.type == 'xpu':
-                torch.xpu.empty_cache()
             # allocate new
             new_cache_k, new_cache_v = create_kv_cache(bsz,
                                                        self.num_key_value_heads,  # Support GQA

View file

@@ -15,9 +15,12 @@
 #
 import torch
+from bigdl.llm.utils.common import invalidInputError

 def create_kv_cache(batch_size, num_heads, head_dim, current_length, max_length, dtype, device):
+    if device.type == 'xpu':
+        torch.xpu.empty_cache()
     key_cache_storage = torch.empty(batch_size, num_heads,
                                     max_length, head_dim,
                                     dtype=dtype, device=device)
@@ -46,3 +49,25 @@ def append_kv_cache(cache_k, cache_v, key_states, value_states):
     new_cache_v = cache_v.as_strided(new_size, cache_v.stride(), storage_offset=0)
     new_cache_v[:, :, cache_v.size(2):cache_k.size(2) + key_states.size(2), :] = value_states
     return new_cache_k, new_cache_v
+
+
+def rotate_half(x):
+    """Rotates half the hidden dims of the input."""
+    x1 = x[..., :x.shape[-1] // 2]
+    x2 = x[..., x.shape[-1] // 2:]
+    return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids, model_family):
+    if model_family in ["llama", "baichuan"]:
+        # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
+        cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
+        sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
+        cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
+        sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
+        q_embed = (q * cos) + (rotate_half(q) * sin)
+        k_embed = (k * cos) + (rotate_half(k) * sin)
+        return q_embed, k_embed
+    else:
+        invalidInputError(False,
+                          f"{model_family} is not supported.")