#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Some parts of this file are adapted from
# https://github.com/huggingface/transformers/blob/v4.31.0/src/transformers/models/bloom/modeling_bloom.py
# which is licensed under Apache License 2.0:
#
# Copyright 2022 HuggingFace Inc. team and BigScience workshop.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch BLOOM model."""
|
|
|
|
from typing import Optional, Tuple
|
|
|
|
import torch
|
|
import torch.utils.checkpoint
|
|
from torch.nn import functional as F
|
|
from ipex_llm.transformers.models.utils import use_fused_layer_norm
|
|
from ipex_llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
|
|
|
|
import os
|
|
|
|
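# Extra kv-cache positions pre-allocated whenever the cache is created or grown,
# so the cache does not need to be re-allocated on every generated token;
# overridable through the environment variable of the same name.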
KV_CACHE_ALLOC_BLOCK_LENGTH = int(os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256))


def dropout_add(x: torch.Tensor, residual: torch.Tensor, prob: float, training: bool):
    """
    Dropout add function

    Args:
        x (`torch.tensor`, *required*):
            input tensor
        residual (`torch.tensor`, *required*):
            residual tensor
        prob (`float`, *required*):
            dropout probability
        training (`bool`, *required*):
            training mode
    """
    out = F.dropout(x, p=prob, training=training)
    out = residual + out
    return out


def bloom_layer_norm_forward(self, hidden_states):
    if use_fused_layer_norm(hidden_states, self.training):
        import linear_q4_0
        result = linear_q4_0.fused_layer_norm(hidden_states,
                                              [self.weight.size(0)],
                                              self.weight,
                                              self.bias,
                                              self.eps)
        # if nelement() == 0, the fused norm failed; fall back to the python implementation.
        if result.nelement() != 0:
            return result
    input_dtype = hidden_states.dtype
    result = F.layer_norm(hidden_states.to(self.weight.dtype),
                          self.normalized_shape, self.weight, self.bias, self.eps)
    return result.to(input_dtype)


def bloom_attention_forward(
    self,
    hidden_states: torch.Tensor,
    residual: torch.Tensor,
    alibi: torch.Tensor,
    attention_mask: torch.Tensor,
    layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]]=None,
    head_mask: Optional[torch.Tensor]=None,
    use_cache: bool=False,
    output_attentions: bool=False,
):
    fused_qkv = self.query_key_value(hidden_states)  # [batch_size, seq_length, 3 x hidden_size]

    # 3 x [batch_size, seq_length, num_heads, head_dim]
    (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)

    batch_size, q_length, _, _ = query_layer.shape

    query_layer = query_layer.transpose(1, 2).reshape(
        batch_size * self.num_heads,
        q_length,
        self.head_dim
    )
    key_layer = key_layer.permute(0, 2, 3, 1).reshape(
        batch_size * self.num_heads,
        self.head_dim,
        q_length
    )
    value_layer = value_layer.transpose(1, 2).reshape(
        batch_size * self.num_heads,
        q_length,
        self.head_dim
    )
    _, _, kv_length = key_layer.shape
    if layer_past is not None:
        kv_length += layer_past[0].shape[-1]
    query_layer = query_layer.view(batch_size, self.num_heads, q_length, self.head_dim)
    key_layer = key_layer.transpose(1, 2).view(batch_size, self.num_heads, q_length, self.head_dim)
    value_layer = value_layer.view(batch_size, self.num_heads, q_length, self.head_dim)
    device = hidden_states.device
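    # Decode path: previous keys/values live in a pre-allocated cache carried in
    # `layer_past`; reuse it (growing it when it is full) and append the current
    # step's keys/values to it.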
    if layer_past is not None:
        # reuse k, v, self_attention
        cache_k = layer_past[0].transpose(1, 2).view(batch_size, self.num_heads, -1, self.head_dim)
        cache_v = layer_past[1].view(batch_size, self.num_heads, -1, self.head_dim)
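        # The stride along the head dimension reflects the allocated capacity per
        # head (allocated_max_len * head_dim); if that capacity cannot hold
        # `kv_length` entries, grow the cache first.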
        if cache_k.stride()[1] < kv_length * cache_k.size(3):
            # allocate new
            new_cache_k, new_cache_v = extend_kv_cache(
                batch_size,
                self.num_heads,
                self.head_dim,
                cache_k.size(2),
                kv_length + KV_CACHE_ALLOC_BLOCK_LENGTH,
                dtype=cache_k.dtype,
                device=device
            )
            new_cache_k[:] = cache_k
            new_cache_v[:] = cache_v
            cache_k = new_cache_k
            cache_v = new_cache_v

        key_layer, value_layer = append_kv_cache(cache_k, cache_v, key_layer, value_layer)

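    # Prefill path (no past): allocate a fresh cache with KV_CACHE_ALLOC_BLOCK_LENGTH
    # slots of headroom and copy this prompt's keys/values into it, so later decode
    # steps can append in place.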
    elif use_cache:
        max_cache_length = kv_length + KV_CACHE_ALLOC_BLOCK_LENGTH
        new_key_states, new_value_states = init_kv_cache(
            batch_size,
            self.num_heads,
            self.head_dim,
            kv_length,
            max_cache_length,
            dtype=key_layer.dtype,
            device=device
        )
        new_key_states[:] = key_layer
        new_value_states[:] = value_layer
        key_layer = new_key_states
        value_layer = new_value_states

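    # Flatten back to [batch_size * num_heads, ...], with the key transposed to
    # [batch_size * num_heads, head_dim, kv_length], so the baddbmm below can
    # consume them directly.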
    query_layer = query_layer.view(batch_size*self.num_heads, -1, self.head_dim)
    key_layer = key_layer.view(batch_size*self.num_heads, -1, self.head_dim).transpose(1, 2)
    value_layer = value_layer.view(batch_size*self.num_heads, -1, self.head_dim)
    _, _, kv_length = key_layer.shape
    if use_cache is True:
        present = (key_layer, value_layer)
    else:
        present = None

    # [batch_size * num_heads, q_length, kv_length]
    # we use `torch.Tensor.baddbmm` instead of `torch.baddbmm`
    # as the latter isn't supported by TorchScript v1.11
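    # `alibi` holds the ALiBi positional bias; baddbmm fuses
    # beta * alibi + inv_norm_factor * (query @ key) into a single call
    # (key is already transposed to [.., head_dim, kv_length]).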
    matmul_result = alibi.baddbmm(
        batch1=query_layer,
        batch2=key_layer,
        beta=self.beta,
        alpha=self.inv_norm_factor,
    )

    # change view to [batch_size, num_heads, q_length, kv_length]
    attention_scores = matmul_result.view(batch_size, self.num_heads, q_length, kv_length)

    # cast attention scores to fp32, compute scaled softmax and cast back to initial dtype
    # - [batch_size, num_heads, q_length, kv_length]
    input_dtype = attention_scores.dtype
    # `float16` has a minimum value of -65504.0,
    # whereas `bfloat16` and `float32` have a minimum value of `-3.4e+38`
    if input_dtype == torch.float16:
        attention_scores = attention_scores.to(torch.float)
    attn_weights = torch.masked_fill(
        attention_scores,
        attention_mask,
        torch.finfo(attention_scores.dtype).min
    )
    attention_probs = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(input_dtype)

    # [batch_size, num_heads, q_length, kv_length]
    attention_probs = self.attention_dropout(attention_probs)

    if head_mask is not None:
        attention_probs = attention_probs * head_mask

    # change view [batch_size x num_heads, q_length, kv_length]
    attention_probs_reshaped = attention_probs.view(
        batch_size * self.num_heads,
        q_length,
        kv_length
    )

    # matmul: [batch_size * num_heads, q_length, head_dim]
    context_layer = torch.bmm(attention_probs_reshaped, value_layer)

    # change view [batch_size, q_length, num_heads * head_dim]
    context_layer = self._merge_heads(context_layer)

    # aggregate results across tp ranks. See here: https://github.com/pytorch/pytorch/issues/76232
    if self.pretraining_tp > 1 and self.slow_but_exact:
        slices = self.hidden_size / self.pretraining_tp
        output_tensor = torch.zeros_like(context_layer)
        for i in range(self.pretraining_tp):
            output_tensor = output_tensor + F.linear(
                context_layer[:, :, int(i * slices): int((i + 1) * slices)],
                self.dense.weight[:, int(i * slices): int((i + 1) * slices)],
            )
    else:
        output_tensor = self.dense(context_layer)

    output_tensor = dropout_add(output_tensor, residual, self.hidden_dropout, self.training)

    outputs = (output_tensor, present)
    if output_attentions:
        outputs += (attention_probs,)

    return outputs
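

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original module): ipex_llm normally
# installs these forwards during its own model-conversion step.  Assuming a
# HuggingFace `BloomForCausalLM` with the transformers v4.31 layout
# (`model.transformer.h[i].self_attention` / `.input_layernorm`), a manual
# patch could look roughly like this.
def _apply_bloom_optimizations_example(model):
    """Bind the optimized forwards onto each Bloom block (example only)."""
    import types
    for block in model.transformer.h:
        block.input_layernorm.forward = types.MethodType(
            bloom_layer_norm_forward, block.input_layernorm)
        block.post_attention_layernorm.forward = types.MethodType(
            bloom_layer_norm_forward, block.post_attention_layernorm)
        block.self_attention.forward = types.MethodType(
            bloom_attention_forward, block.self_attention)
    return model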