# ipex-llm/python/llm/src/ipex_llm/transformers/models/bloom.py
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Some parts of this file is adapted from
# https://github.com/huggingface/transformers/blob/v4.31.0/src/transformers/models/bloom/modeling_bloom.py
# which is licensed under Apache License 2.0:
#
# Copyright 2022 HuggingFace Inc. team and BigScience workshop.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch BLOOM model."""
from typing import Optional, Tuple
import torch
import torch.utils.checkpoint
from torch.nn import functional as F
from ipex_llm.transformers.models.utils import use_fused_layer_norm
from ipex_llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
import os
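# Number of extra cache slots (tokens) to pre-allocate whenever the KV cache is
# created or grown; can be overridden through the KV_CACHE_ALLOC_BLOCK_LENGTH
# environment variable (defaults to 256).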
KV_CACHE_ALLOC_BLOCK_LENGTH = int(os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256))


def dropout_add(x: torch.Tensor, residual: torch.Tensor, prob: float, training: bool):
    """
    Dropout add function

    Args:
        x (`torch.tensor`, *required*):
            input tensor
        residual (`torch.tensor`, *required*):
            residual tensor
        prob (`float`, *required*):
            dropout probability
        training (`bool`, *required*):
            training mode
    """
    out = F.dropout(x, p=prob, training=training)
    out = residual + out
    return out


def bloom_layer_norm_forward(self, hidden_states):
    if use_fused_layer_norm(hidden_states, self.training):
        import linear_q4_0
        result = linear_q4_0.fused_layer_norm(hidden_states,
                                              [self.weight.size(0)],
                                              self.weight,
                                              self.bias,
                                              self.eps)
        # if nelement() == 0, the fused norm failed; fall back to the python implementation
        if result.nelement() != 0:
            return result
    input_dtype = hidden_states.dtype
    result = F.layer_norm(hidden_states.to(self.weight.dtype),
                          self.normalized_shape, self.weight, self.bias, self.eps)
    return result.to(input_dtype)


def bloom_attention_forward(
    self,
    hidden_states: torch.Tensor,
    residual: torch.Tensor,
    alibi: torch.Tensor,
    attention_mask: torch.Tensor,
    layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]]=None,
    head_mask: Optional[torch.Tensor]=None,
    use_cache: bool=False,
    output_attentions: bool=False,
):
    fused_qkv = self.query_key_value(hidden_states)  # [batch_size, seq_length, 3 x hidden_size]

    # 3 x [batch_size, seq_length, num_heads, head_dim]
    (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)

    batch_size, q_length, _, _ = query_layer.shape

    query_layer = query_layer.transpose(1, 2).reshape(
        batch_size * self.num_heads,
        q_length,
        self.head_dim
    )
    key_layer = key_layer.permute(0, 2, 3, 1).reshape(
        batch_size * self.num_heads,
        self.head_dim,
        q_length
    )
    value_layer = value_layer.transpose(1, 2).reshape(
        batch_size * self.num_heads,
        q_length,
        self.head_dim
    )

    _, _, kv_length = key_layer.shape
    if layer_past is not None:
        kv_length += layer_past[0].shape[-1]

    query_layer = query_layer.view(batch_size, self.num_heads, q_length, self.head_dim)
    key_layer = key_layer.transpose(1, 2).view(batch_size, self.num_heads, q_length, self.head_dim)
    value_layer = value_layer.view(batch_size, self.num_heads, q_length, self.head_dim)

    device = hidden_states.device
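    # KV cache handling: on decode steps (layer_past is not None) the new key/value
    # states are appended to the pre-allocated cache, which is grown by
    # KV_CACHE_ALLOC_BLOCK_LENGTH slots whenever the current allocation is too small.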
    if layer_past is not None:
        # reuse k, v, self_attention
        cache_k = layer_past[0].transpose(1, 2).view(batch_size, self.num_heads, -1, self.head_dim)
        cache_v = layer_past[1].view(batch_size, self.num_heads, -1, self.head_dim)
        if cache_k.stride()[1] < kv_length * cache_k.size(3):
            # allocate new
            new_cache_k, new_cache_v = extend_kv_cache(
                batch_size,
                self.num_heads,
                self.head_dim,
                cache_k.size(2),
                kv_length + KV_CACHE_ALLOC_BLOCK_LENGTH,
                dtype=cache_k.dtype,
                device=device
            )
            new_cache_k[:] = cache_k
            new_cache_v[:] = cache_v
            cache_k = new_cache_k
            cache_v = new_cache_v

        key_layer, value_layer = append_kv_cache(cache_k, cache_v, key_layer, value_layer)
    elif use_cache:
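        # First token (prefill): allocate the cache up front with room for
        # KV_CACHE_ALLOC_BLOCK_LENGTH extra tokens, then copy the current
        # key/value states into it so later decode steps can append in place.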
        max_cache_length = kv_length + KV_CACHE_ALLOC_BLOCK_LENGTH
        new_key_states, new_value_states = init_kv_cache(
            batch_size,
            self.num_heads,
            self.head_dim,
            kv_length,
            max_cache_length,
            dtype=key_layer.dtype,
            device=device
        )
        new_key_states[:] = key_layer
        new_value_states[:] = value_layer
        key_layer = new_key_states
        value_layer = new_value_states

    query_layer = query_layer.view(batch_size*self.num_heads, -1, self.head_dim)
    key_layer = key_layer.view(batch_size*self.num_heads, -1, self.head_dim).transpose(1, 2)
    value_layer = value_layer.view(batch_size*self.num_heads, -1, self.head_dim)

    _, _, kv_length = key_layer.shape

    if use_cache is True:
        present = (key_layer, value_layer)
    else:
        present = None
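    # `alibi` holds the per-head ALiBi position bias, shaped
    # [batch_size * num_heads, 1, kv_length], so `baddbmm` below fuses the bias add
    # with the scaled Q @ K^T product in a single call.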
    # [batch_size * num_heads, q_length, kv_length]
    # we use `torch.Tensor.baddbmm` instead of `torch.baddbmm`
    # as the latter isn't supported by TorchScript v1.11
    matmul_result = alibi.baddbmm(
        batch1=query_layer,
        batch2=key_layer,
        beta=self.beta,
        alpha=self.inv_norm_factor,
    )

    # change view to [batch_size, num_heads, q_length, kv_length]
    attention_scores = matmul_result.view(batch_size, self.num_heads, q_length, kv_length)

    # cast attention scores to fp32, compute scaled softmax and cast back to initial dtype
    # - [batch_size, num_heads, q_length, kv_length]
    input_dtype = attention_scores.dtype
    # `float16` has a minimum value of -65504.0,
    # whereas `bfloat16` and `float32` have a minimum value of `-3.4e+38`
    if input_dtype == torch.float16:
        attention_scores = attention_scores.to(torch.float)
    attn_weights = torch.masked_fill(
        attention_scores,
        attention_mask,
        torch.finfo(attention_scores.dtype).min
    )
    attention_probs = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(input_dtype)

    # [batch_size, num_heads, q_length, kv_length]
    attention_probs = self.attention_dropout(attention_probs)

    if head_mask is not None:
        attention_probs = attention_probs * head_mask

    # change view [batch_size x num_heads, q_length, kv_length]
    attention_probs_reshaped = attention_probs.view(
        batch_size * self.num_heads,
        q_length,
        kv_length
    )

    # matmul: [batch_size * num_heads, q_length, head_dim]
    context_layer = torch.bmm(attention_probs_reshaped, value_layer)

    # change view [batch_size, q_length, num_heads * head_dim]
    context_layer = self._merge_heads(context_layer)

    # aggregate results across tp ranks. See here: https://github.com/pytorch/pytorch/issues/76232
    if self.pretraining_tp > 1 and self.slow_but_exact:
        slices = self.hidden_size / self.pretraining_tp
        output_tensor = torch.zeros_like(context_layer)
        for i in range(self.pretraining_tp):
            output_tensor = output_tensor + F.linear(
                context_layer[:, :, int(i * slices): int((i + 1) * slices)],
                self.dense.weight[:, int(i * slices): int((i + 1) * slices)],
            )
    else:
        output_tensor = self.dense(context_layer)

    output_tensor = dropout_add(output_tensor, residual, self.hidden_dropout, self.training)

    outputs = (output_tensor, present)
    if output_attentions:
        outputs += (attention_probs,)

    return outputs
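

# Illustrative sketch only, not ipex-llm's own wiring (ipex-llm swaps module forwards
# in during its model optimization pass): assuming a Hugging Face BLOOM checkpoint and
# transformers 4.31, the attention forward above could be patched in like this:
#
#   from transformers.models.bloom.modeling_bloom import BloomAttention
#   BloomAttention.forward = bloom_attention_forward
#
# after which generation with `use_cache=True` runs through the pre-allocated
# KV cache path implemented above.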