ipex-llm/python/llm/src/ipex_llm/transformers/models/bert.py
Wang, Jian4 9df70d95eb
Refactor bigdl.llm to ipex_llm (#24)
* Rename bigdl/llm to ipex_llm

* rm python/llm/src/bigdl

* from bigdl.llm to from ipex_llm
2024-03-22 15:41:21 +08:00

147 lines
5.9 KiB
Python

#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Some parts of this file is adapted from
# https://github.com/huggingface/transformers/blob/v4.38.0/src/transformers/models/bert/modeling_bert.py
# which is licensed under Apache License 2.0:
#
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import torch
from typing import Optional, Tuple
from transformers.models.bert.modeling_bert import BertSelfAttention, BertEncoder
from ipex_llm.utils.common import invalidInputError
def merge_qkv(module: torch.nn.Module):
if isinstance(module, BertSelfAttention):
q_w = module.query.weight.data
k_w = module.key.weight.data
v_w = module.value.weight.data
q_b = module.query.bias.data
k_b = module.key.bias.data
v_b = module.value.bias.data
new_w = torch.cat([q_w, k_w, v_w], dim=0)
new_b = torch.cat([q_b, k_b, v_b], dim=-1)
qkv = torch.nn.Linear(0, 0, bias=True)
qkv.weight = torch.nn.Parameter(new_w, requires_grad=False)
qkv.bias = torch.nn.Parameter(new_b, requires_grad=False)
qkv.in_features = module.query.in_features
qkv.out_features = module.query.out_features * 3
module.qkv = qkv
del module.query
del module.key
del module.value
def self_attention_forward(
self: torch.nn.Module,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
output_attentions: Optional[bool] = False,
):
invalidInputError(encoder_hidden_states is None,
"cross attention is not supported")
invalidInputError(not self.is_decoder,
"bert decoder is not supported")
invalidInputError(self.position_embedding_type == "absolute",
"relative query/key is not supported")
qkv_output = self.qkv(hidden_states)
(query_layer, key_layer, value_layer) = torch.chunk(qkv_output, 3, -1)
query_layer = self.transpose_for_scores(query_layer)
key_layer = self.transpose_for_scores(key_layer)
value_layer = self.transpose_for_scores(value_layer)
if past_key_value is not None:
key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
# Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = torch.matmul(query_layer / math.sqrt(self.attention_head_size),
key_layer.transpose(-1, -2))
if attention_mask is not None:
# Apply the attention mask is (precomputed for all layers in BertModel forward() function)
attention_scores = attention_scores + attention_mask
# Normalize the attention scores to probabilities.
attention_probs = torch.nn.functional.softmax(attention_scores, dim=-1)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
attention_probs = self.dropout(attention_probs)
# Mask heads if we want to
if head_mask is not None:
attention_probs = attention_probs * head_mask
context_layer = torch.matmul(attention_probs, value_layer)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(new_context_layer_shape)
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
return outputs
def encoder_forward(
self: torch.nn.Module,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = False,
output_hidden_states: Optional[bool] = False,
return_dict: Optional[bool] = True,
):
if not attention_mask.any():
attention_mask = None
return BertEncoder.forward(
self=self,
hidden_states=hidden_states,
attention_mask=attention_mask,
head_mask=head_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)