LLM: add bloom kv cache support (#9012)

* LLM: add bloom kv cache support

* fix style.
Cengguang Zhang 2023-09-20 21:10:53 +08:00 committed by GitHub
parent 156af15d1e
commit b3cad7de57
2 changed files with 223 additions and 0 deletions

View file

@@ -173,6 +173,14 @@ def optimize(model):
                        module.SelfAttention,
                        chatglm_attention_forward
                        )
    elif "bloom" in model.config._name_or_path:
        modeling_module_name = model.__class__.__module__
        module = importlib.import_module(modeling_module_name)
        from bigdl.llm.transformers.models.bloom import bloom_attention_forward
        convert_forward(model,
                        module.BloomAttention,
                        bloom_attention_forward
                        )
    elif "falcon" in model.config._name_or_path:
        modeling_module_name = model.__class__.__module__
        module = importlib.import_module(modeling_module_name)
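
Note: convert_forward is the helper already used by the other branches of this dispatch and is not part of this diff. As a rough, illustrative sketch only (the _patch_forward name below is invented for illustration, not the actual BigDL implementation), a replacement of this kind can be done by rebinding the forward method on every matching submodule:

# illustrative sketch only -- not the actual BigDL convert_forward implementation
from types import MethodType

def _patch_forward(model, target_class, new_forward):
    for module in model.modules():
        if isinstance(module, target_class):
            # bind the replacement function as this module's forward method
            module.forward = MethodType(new_forward, module)
    return model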

View file

@@ -0,0 +1,215 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Some parts of this file are adapted from
# https://github.com/huggingface/transformers/blob/v4.31.0/src/transformers/models/bloom/modeling_bloom.py
# which is licensed under Apache License 2.0:
#
# Copyright 2022 HuggingFace Inc. team and BigScience workshop.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch BLOOM model."""
from typing import Optional, Tuple

import torch
import torch.utils.checkpoint
from torch.nn import functional as F

from bigdl.llm.transformers.models.utils import create_kv_cache, append_kv_cache

# number of extra token slots reserved whenever a new KV cache buffer is allocated
KV_CACHE_ALLOC_BLOCK_LENGTH = 256


def dropout_add(x: torch.Tensor, residual: torch.Tensor, prob: float, training: bool):
    """
    Dropout add function

    Args:
        x (`torch.tensor`, *required*):
            input tensor
        residual (`torch.tensor`, *required*):
            residual tensor
        prob (`float`, *required*):
            dropout probability
        training (`bool`, *required*):
            training mode
    """
    out = F.dropout(x, p=prob, training=training)
    out = residual + out
    return out


def bloom_attention_forward(
    self,
    hidden_states: torch.Tensor,
    residual: torch.Tensor,
    alibi: torch.Tensor,
    attention_mask: torch.Tensor,
    layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]]=None,
    head_mask: Optional[torch.Tensor]=None,
    use_cache: bool=False,
    output_attentions: bool=False,
):
    fused_qkv = self.query_key_value(hidden_states)  # [batch_size, seq_length, 3 x hidden_size]

    # 3 x [batch_size, seq_length, num_heads, head_dim]
    (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)

    batch_size, q_length, _, _ = query_layer.shape

    query_layer = query_layer.transpose(1, 2).reshape(
        batch_size * self.num_heads,
        q_length,
        self.head_dim
    )
    key_layer = key_layer.permute(0, 2, 3, 1).reshape(
        batch_size * self.num_heads,
        self.head_dim,
        q_length
    )
    value_layer = value_layer.transpose(1, 2).reshape(
        batch_size * self.num_heads,
        q_length,
        self.head_dim
    )
    _, _, kv_length = key_layer.shape

    # reshape back to [batch_size, num_heads, seq_len, head_dim] so the new keys/values
    # can be written into (or appended to) the pre-allocated KV cache below
    query_layer = query_layer.view(batch_size, self.num_heads, q_length, self.head_dim)
    key_layer = key_layer.transpose(1, 2).view(batch_size, self.num_heads, q_length, self.head_dim)
    value_layer = value_layer.view(batch_size, self.num_heads, q_length, self.head_dim)
    device = hidden_states.device
    if layer_past is not None:
        # reuse k, v, self_attention
        cache_k = layer_past[0].transpose(1, 2).view(batch_size, self.num_heads, -1, self.head_dim)
        cache_v = layer_past[1].view(batch_size, self.num_heads, -1, self.head_dim)
        if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
            # no free slot left in the pre-allocated cache (or the past was not
            # pre-allocated): allocate a larger buffer and copy the old cache into it
            new_cache_k, new_cache_v = create_kv_cache(
                batch_size,
                self.num_heads,
                self.head_dim,
                cache_k.size(2),
                kv_length + KV_CACHE_ALLOC_BLOCK_LENGTH,
                dtype=cache_k.dtype,
                device=device
            )
            new_cache_k[:] = cache_k
            new_cache_v[:] = cache_v
            cache_k = new_cache_k
            cache_v = new_cache_v
        key_layer, value_layer = append_kv_cache(cache_k, cache_v, key_layer, value_layer)
    elif use_cache:
        # first forward pass: pre-allocate a cache with spare room for future tokens
        max_cache_length = kv_length + KV_CACHE_ALLOC_BLOCK_LENGTH
        new_key_states, new_value_states = create_kv_cache(
            batch_size,
            self.num_heads,
            self.head_dim,
            kv_length,
            max_cache_length,
            dtype=key_layer.dtype,
            device=device
        )
        new_key_states[:] = key_layer
        new_value_states[:] = value_layer
        key_layer = new_key_states
        value_layer = new_value_states
    query_layer = query_layer.view(batch_size*self.num_heads, -1, self.head_dim)
    key_layer = key_layer.view(batch_size*self.num_heads, -1, self.head_dim).transpose(1, 2)
    value_layer = value_layer.view(batch_size*self.num_heads, -1, self.head_dim)
    _, _, kv_length = key_layer.shape

    if use_cache is True:
        present = (key_layer, value_layer)
    else:
        present = None

    # [batch_size * num_heads, q_length, kv_length]
    # we use `torch.Tensor.baddbmm`
    # instead of `torch.baddbmm` as the latter isn't supported by TorchScript v1.11
    matmul_result = alibi.baddbmm(
        batch1=query_layer,
        batch2=key_layer,
        beta=self.beta,
        alpha=self.inv_norm_factor,
    )

    # change view to [batch_size, num_heads, q_length, kv_length]
    attention_scores = matmul_result.view(batch_size, self.num_heads, q_length, kv_length)

    # cast attention scores to fp32,
    # compute scaled softmax and cast back to initial dtype
    # - [batch_size, num_heads, q_length, kv_length]
    input_dtype = attention_scores.dtype
    # `float16` has a minimum value of -65504.0,
    # whereas `bfloat16` and `float32` have a minimum value of `-3.4e+38`
    if input_dtype == torch.float16:
        attention_scores = attention_scores.to(torch.float)
    attn_weights = torch.masked_fill(
        attention_scores,
        attention_mask,
        torch.finfo(attention_scores.dtype).min
    )
    attention_probs = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(input_dtype)

    # [batch_size, num_heads, q_length, kv_length]
    attention_probs = self.attention_dropout(attention_probs)

    if head_mask is not None:
        attention_probs = attention_probs * head_mask

    # change view [batch_size x num_heads, q_length, kv_length]
    attention_probs_reshaped = attention_probs.view(
        batch_size * self.num_heads,
        q_length,
        kv_length
    )

    # matmul: [batch_size * num_heads, q_length, head_dim]
    context_layer = torch.bmm(attention_probs_reshaped, value_layer)

    # change view [batch_size, q_length, num_heads * head_dim]
    context_layer = self._merge_heads(context_layer)

    # aggregate results across tp ranks. See here: https://github.com/pytorch/pytorch/issues/76232
    if self.pretraining_tp > 1 and self.slow_but_exact:
        slices = self.hidden_size / self.pretraining_tp
        output_tensor = torch.zeros_like(context_layer)
        for i in range(self.pretraining_tp):
            output_tensor = output_tensor + F.linear(
                context_layer[:, :, int(i * slices): int((i + 1) * slices)],
                self.dense.weight[:, int(i * slices): int((i + 1) * slices)],
            )
    else:
        output_tensor = self.dense(context_layer)

    output_tensor = dropout_add(output_tensor, residual, self.hidden_dropout, self.training)

    outputs = (output_tensor, present)
    if output_attentions:
        outputs += (attention_probs,)

    return outputs
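
The cache handling in bloom_attention_forward relies on create_kv_cache and append_kv_cache from bigdl.llm.transformers.models.utils, which are not shown in this diff. A minimal sketch of the idea as it is used above, with underscore-prefixed stand-in helpers whose behaviour is assumed from the call sites (not the BigDL implementations): pre-allocate KV_CACHE_ALLOC_BLOCK_LENGTH spare token slots, expose only the filled prefix as a view, and widen that view in place each decoding step; the stride check in the cached branch detects when the spare slots run out.

# illustrative stand-ins, assumed from the call sites above (not the BigDL code)
import torch

KV_CACHE_ALLOC_BLOCK_LENGTH = 256  # mirrors the constant defined in the file above

def _create_kv_cache(batch, num_heads, head_dim, current_len, max_len, dtype, device):
    # reserve max_len token slots, but expose only the first current_len as a view
    buf_k = torch.empty(batch, num_heads, max_len, head_dim, dtype=dtype, device=device)
    buf_v = torch.empty(batch, num_heads, max_len, head_dim, dtype=dtype, device=device)
    return buf_k[:, :, :current_len, :], buf_v[:, :, :current_len, :]

def _append_kv_cache(cache_k, cache_v, new_k, new_v):
    # widen the views over the shared buffer and copy the new tokens in place,
    # so no per-token torch.cat / reallocation is needed
    new_len = cache_k.size(2) + new_k.size(2)
    size = (cache_k.size(0), cache_k.size(1), new_len, cache_k.size(3))
    out_k = cache_k.as_strided(size, cache_k.stride())
    out_v = cache_v.as_strided(size, cache_v.stride())
    out_k[:, :, -new_k.size(2):, :] = new_k
    out_v[:, :, -new_v.size(2):, :] = new_v
    return out_k, out_v

# quick check of the growth behaviour (hypothetical shapes)
k, v = _create_kv_cache(1, 2, 4, current_len=3, max_len=3 + KV_CACHE_ALLOC_BLOCK_LENGTH,
                        dtype=torch.float32, device="cpu")
step_k = torch.randn(1, 2, 1, 4)
k, v = _append_kv_cache(k, v, step_k, step_k)
# the view now covers 4 tokens while its stride still reflects the larger buffer,
# which is exactly what the `cache_k.stride()[1] <= ...` check above inspects
assert k.size(2) == 4 and k.stride()[1] == (3 + KV_CACHE_ALLOC_BLOCK_LENGTH) * 4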