LLM: add bloom kv cache support (#9012)
* LLM: add bloom kv cache support
* fix style.
parent 156af15d1e
commit b3cad7de57
2 changed files with 223 additions and 0 deletions
@@ -173,6 +173,14 @@ def optimize(model):
                        module.SelfAttention,
                        chatglm_attention_forward
                        )
    elif "bloom" in model.config._name_or_path:
        modeling_module_name = model.__class__.__module__
        module = importlib.import_module(modeling_module_name)
        from bigdl.llm.transformers.models.bloom import bloom_attention_forward
        convert_forward(model,
                        module.BloomAttention,
                        bloom_attention_forward
                        )
    elif "falcon" in model.config._name_or_path:
        modeling_module_name = model.__class__.__module__
        module = importlib.import_module(modeling_module_name)
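For context on the dispatch above: `optimize()` checks `model.config._name_or_path`, imports the model's own modeling module, and hands its attention class to `convert_forward` together with the replacement forward. The sketch below is only an assumption about what `convert_forward` does (the real helper lives elsewhere in bigdl-llm and is not part of this diff); it illustrates the usual monkey-patching pattern of rebinding `forward` on every matching sub-module.

    import types

    def convert_forward_sketch(model, target_class, new_forward):
        # Hypothetical stand-in for bigdl-llm's convert_forward:
        # rebind `forward` on every sub-module that is an instance of
        # `target_class` (e.g. module.BloomAttention above).
        for _, sub_module in model.named_modules():
            if isinstance(sub_module, target_class):
                sub_module.forward = types.MethodType(new_forward, sub_module)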
python/llm/src/bigdl/llm/transformers/models/bloom.py  (new file, 215 lines added)
@@ -0,0 +1,215 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Some parts of this file is adapted from
# https://github.com/huggingface/transformers/blob/v4.31.0/src/transformers/models/bloom/modeling_bloom.py
# which is licensed under Apache License 2.0:
#
# Copyright 2022 HuggingFace Inc. team and BigScience workshop.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""PyTorch BLOOM model."""

from typing import Optional, Tuple

import torch
import torch.utils.checkpoint
from torch.nn import functional as F
from bigdl.llm.transformers.models.utils import create_kv_cache, append_kv_cache


KV_CACHE_ALLOC_BLOCK_LENGTH = 256

def dropout_add(x: torch.Tensor, residual: torch.Tensor, prob: float, training: bool):
    """
    Dropout add function

    Args:
        x (`torch.tensor`, *required*):
            input tensor
        residual (`torch.tensor`, *required*):
            residual tensor
        prob (`float`, *required*):
            dropout probability
        training (`bool`, *required*):
            training mode
    """
    out = F.dropout(x, p=prob, training=training)
    out = residual + out
    return out


def bloom_attention_forward(
    self,
    hidden_states: torch.Tensor,
    residual: torch.Tensor,
    alibi: torch.Tensor,
    attention_mask: torch.Tensor,
    layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]]=None,
    head_mask: Optional[torch.Tensor]=None,
    use_cache: bool=False,
    output_attentions: bool=False,
):
    fused_qkv = self.query_key_value(hidden_states)  # [batch_size, seq_length, 3 x hidden_size]

    # 3 x [batch_size, seq_length, num_heads, head_dim]
    (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)

    batch_size, q_length, _, _ = query_layer.shape

    query_layer = query_layer.transpose(1, 2).reshape(
        batch_size * self.num_heads,
        q_length,
        self.head_dim
    )
    key_layer = key_layer.permute(0, 2, 3, 1).reshape(
        batch_size * self.num_heads,
        self.head_dim,
        q_length
    )
    value_layer = value_layer.transpose(1, 2).reshape(
        batch_size * self.num_heads,
        q_length,
        self.head_dim
    )
    _, _, kv_length = key_layer.shape
    query_layer = query_layer.view(batch_size, self.num_heads, q_length, self.head_dim)
    key_layer = key_layer.transpose(1, 2).view(batch_size, self.num_heads, q_length, self.head_dim)
    value_layer = value_layer.view(batch_size, self.num_heads, q_length, self.head_dim)
    device = hidden_states.device
    if layer_past is not None:
        # reuse k, v, self_attention
        cache_k = layer_past[0].transpose(1, 2).view(batch_size, self.num_heads, -1, self.head_dim)
        cache_v = layer_past[1].view(batch_size, self.num_heads, -1, self.head_dim)
        if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
            # allocate new
            new_cache_k, new_cache_v = create_kv_cache(
                batch_size,
                self.num_heads,
                self.head_dim,
                cache_k.size(2),
                kv_length + KV_CACHE_ALLOC_BLOCK_LENGTH,
                dtype=cache_k.dtype,
                device=device
            )
            new_cache_k[:] = cache_k
            new_cache_v[:] = cache_v
            cache_k = new_cache_k
            cache_v = new_cache_v

        key_layer, value_layer = append_kv_cache(cache_k, cache_v, key_layer, value_layer)

    elif use_cache:
        max_cache_length = kv_length + KV_CACHE_ALLOC_BLOCK_LENGTH
        new_key_states, new_value_states = create_kv_cache(
            batch_size,
            self.num_heads,
            self.head_dim,
            kv_length,
            max_cache_length,
            dtype=key_layer.dtype,
            device=device
        )
        new_key_states[:] = key_layer
        new_value_states[:] = value_layer
        key_layer = new_key_states
        value_layer = new_value_states

    query_layer = query_layer.view(batch_size*self.num_heads, -1, self.head_dim)
    key_layer = key_layer.view(batch_size*self.num_heads, -1, self.head_dim).transpose(1, 2)
    value_layer = value_layer.view(batch_size*self.num_heads, -1, self.head_dim)
    _, _, kv_length = key_layer.shape
    if use_cache is True:
        present = (key_layer, value_layer)
    else:
        present = None

    # [batch_size * num_heads, q_length, kv_length]
    # we use `torch.Tensor.baddbmm`
    # instead of `torch.baddbmm` as the latter isn't supported by TorchScript v1.11
    matmul_result = alibi.baddbmm(
        batch1=query_layer,
        batch2=key_layer,
        beta=self.beta,
        alpha=self.inv_norm_factor,
    )

    # change view to [batch_size, num_heads, q_length, kv_length]
    attention_scores = matmul_result.view(batch_size, self.num_heads, q_length, kv_length)

    # cast attention scores to fp32,
    # compute scaled softmax and cast back to initial dtype
    # - [batch_size, num_heads, q_length, kv_length]
    input_dtype = attention_scores.dtype
    # `float16` has a minimum value of -65504.0,
    # whereas `bfloat16` and `float32` have a minimum value of `-3.4e+38`
    if input_dtype == torch.float16:
        attention_scores = attention_scores.to(torch.float)
    attn_weights = torch.masked_fill(
        attention_scores,
        attention_mask,
        torch.finfo(attention_scores.dtype).min
    )
    attention_probs = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(input_dtype)

    # [batch_size, num_heads, q_length, kv_length]
    attention_probs = self.attention_dropout(attention_probs)

    if head_mask is not None:
        attention_probs = attention_probs * head_mask

    # change view [batch_size x num_heads, q_length, kv_length]
    attention_probs_reshaped = attention_probs.view(
        batch_size * self.num_heads,
        q_length,
        kv_length
    )

    # matmul: [batch_size * num_heads, q_length, head_dim]
    context_layer = torch.bmm(attention_probs_reshaped, value_layer)

    # change view [batch_size, q_length, num_heads * head_dim]
    context_layer = self._merge_heads(context_layer)

    # aggregate results across tp ranks. See here: https://github.com/pytorch/pytorch/issues/76232
    if self.pretraining_tp > 1 and self.slow_but_exact:
        slices = self.hidden_size / self.pretraining_tp
        output_tensor = torch.zeros_like(context_layer)
        for i in range(self.pretraining_tp):
            output_tensor = output_tensor + F.linear(
                context_layer[:, :, int(i * slices): int((i + 1) * slices)],
                self.dense.weight[:, int(i * slices): int((i + 1) * slices)],
            )
    else:
        output_tensor = self.dense(context_layer)

    output_tensor = dropout_add(output_tensor, residual, self.hidden_dropout, self.training)

    outputs = (output_tensor, present)
    if output_attentions:
        outputs += (attention_probs,)

    return outputs
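The helpers `create_kv_cache` and `append_kv_cache` come from `bigdl.llm.transformers.models.utils` and are not shown in this diff. Judging only from the call sites above (and the stride check on `cache_k`), they appear to pre-allocate a key/value buffer with `KV_CACHE_ALLOC_BLOCK_LENGTH` spare token slots and then write new tokens into that reserved space, so decoding does not reallocate or `torch.cat` the cache on every step. A rough sketch of that idea follows; the names and shapes are assumptions, not the BigDL implementation.

    import torch

    def create_kv_cache_sketch(batch, num_heads, head_dim,
                               current_length, max_length, dtype, device):
        # Reserve room for max_length tokens, but return views over only
        # the first current_length positions; the oversized parent buffer
        # stays alive behind the views.
        k_buf = torch.empty(batch, num_heads, max_length, head_dim,
                            dtype=dtype, device=device)
        v_buf = torch.empty(batch, num_heads, max_length, head_dim,
                            dtype=dtype, device=device)
        return k_buf[:, :, :current_length, :], v_buf[:, :, :current_length, :]

    def append_kv_cache_sketch(cache_k, cache_v, new_k, new_v):
        # Grow the views by the new sequence length and copy only the new
        # positions into the space that was reserved at allocation time.
        old_len, add_len = cache_k.size(2), new_k.size(2)
        k = cache_k.as_strided((cache_k.size(0), cache_k.size(1),
                                old_len + add_len, cache_k.size(3)),
                               cache_k.stride())
        v = cache_v.as_strided((cache_v.size(0), cache_v.size(1),
                                old_len + add_len, cache_v.size(3)),
                               cache_v.stride())
        k[:, :, old_len:, :] = new_k
        v[:, :, old_len:, :] = new_v
        return k, v

Under that assumption, the `cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3)` test in the forward simply asks whether the parent buffer still has spare rows: when the stride along the head dimension is no larger than the visible `seq_len * head_dim`, the reservation is exhausted and a bigger buffer is allocated before appending.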
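End to end, this path is exercised by generating with a Bloom checkpoint through bigdl-llm's transformers-style API; once `optimize()` has patched `BloomAttention`, each decoding step reuses the pre-allocated cache. A minimal example, assuming the usual bigdl-llm `AutoModelForCausalLM` wrapper and using `bigscience/bloomz-7b1` purely as an illustrative checkpoint:

    from transformers import AutoTokenizer
    from bigdl.llm.transformers import AutoModelForCausalLM

    model_path = "bigscience/bloomz-7b1"  # example checkpoint, not prescribed by this commit
    model = AutoModelForCausalLM.from_pretrained(model_path,
                                                 load_in_4bit=True,
                                                 trust_remote_code=True)
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

    input_ids = tokenizer("What is AI?", return_tensors="pt").input_ids
    # use_cache=True routes generation through the layer_past branch above
    output = model.generate(input_ids, max_new_tokens=32, use_cache=True)
    print(tokenizer.decode(output[0], skip_special_tokens=True))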