diff --git a/python/llm/src/bigdl/llm/transformers/convert.py b/python/llm/src/bigdl/llm/transformers/convert.py
index d99f46ec..6bb1938b 100644
--- a/python/llm/src/bigdl/llm/transformers/convert.py
+++ b/python/llm/src/bigdl/llm/transformers/convert.py
@@ -580,6 +580,13 @@ def _optimize_pre(model):
             del module.q_proj
             del module.k_proj
         model.apply(merge_qk_proj_func)
+    # for bge-large
+    if model.config.model_type == 'bert' and (
+        not model.config.is_decoder and
+        model.config.position_embedding_type == "absolute"
+    ):
+        from bigdl.llm.transformers.models.bert import merge_qkv
+        model.apply(merge_qkv)

     return model

@@ -1228,4 +1235,19 @@ def _optimize_post(model, lightweight_bmm=False):
         #     module.YuanMLP,
         #     yuan_mlp_forward
         # )
+    elif model.config.model_type == 'bert' and (
+        not model.config.is_decoder and
+        model.config.position_embedding_type == "absolute"
+    ):
+        modeling_module_name = model.__class__.__module__
+        module = importlib.import_module(modeling_module_name)
+        from bigdl.llm.transformers.models.bert import self_attention_forward
+        from bigdl.llm.transformers.models.bert import encoder_forward
+        convert_forward(model,
+                        module.BertSelfAttention,
+                        self_attention_forward)
+        convert_forward(model,
+                        module.BertEncoder,
+                        encoder_forward)
+
     return model
diff --git a/python/llm/src/bigdl/llm/transformers/models/bert.py b/python/llm/src/bigdl/llm/transformers/models/bert.py
new file mode 100644
index 00000000..27abd988
--- /dev/null
+++ b/python/llm/src/bigdl/llm/transformers/models/bert.py
@@ -0,0 +1,147 @@
+#
+# Copyright 2016 The BigDL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Some parts of this file is adapted from
+# https://github.com/huggingface/transformers/blob/v4.38.0/src/transformers/models/bert/modeling_bert.py
+# which is licensed under Apache License 2.0:
+#
+# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
+# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+import torch
+from typing import Optional, Tuple
+from transformers.models.bert.modeling_bert import BertSelfAttention, BertEncoder
+from bigdl.llm.utils.common import invalidInputError
+
+
+def merge_qkv(module: torch.nn.Module):
+    if isinstance(module, BertSelfAttention):
+        q_w = module.query.weight.data
+        k_w = module.key.weight.data
+        v_w = module.value.weight.data
+        q_b = module.query.bias.data
+        k_b = module.key.bias.data
+        v_b = module.value.bias.data
+        new_w = torch.cat([q_w, k_w, v_w], dim=0)
+        new_b = torch.cat([q_b, k_b, v_b], dim=-1)
+        qkv = torch.nn.Linear(0, 0, bias=True)
+        qkv.weight = torch.nn.Parameter(new_w, requires_grad=False)
+        qkv.bias = torch.nn.Parameter(new_b, requires_grad=False)
+        qkv.in_features = module.query.in_features
+        qkv.out_features = module.query.out_features * 3
+        module.qkv = qkv
+        del module.query
+        del module.key
+        del module.value
+
+
+def self_attention_forward(
+    self: torch.nn.Module,
+    hidden_states: torch.Tensor,
+    attention_mask: Optional[torch.FloatTensor] = None,
+    head_mask: Optional[torch.FloatTensor] = None,
+    encoder_hidden_states: Optional[torch.FloatTensor] = None,
+    encoder_attention_mask: Optional[torch.FloatTensor] = None,
+    past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+    output_attentions: Optional[bool] = False,
+):
+    invalidInputError(encoder_hidden_states is None,
+                      "cross attention is not supported")
+    invalidInputError(not self.is_decoder,
+                      "bert decoder is not supported")
+    invalidInputError(self.position_embedding_type == "absolute",
+                      "relative query/key is not supported")
+
+    qkv_output = self.qkv(hidden_states)
+    (query_layer, key_layer, value_layer) = torch.chunk(qkv_output, 3, -1)
+    query_layer = self.transpose_for_scores(query_layer)
+    key_layer = self.transpose_for_scores(key_layer)
+    value_layer = self.transpose_for_scores(value_layer)
+
+    if past_key_value is not None:
+        key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
+        value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
+
+    # Take the dot product between "query" and "key" to get the raw attention scores.
+    attention_scores = torch.matmul(query_layer / math.sqrt(self.attention_head_size),
+                                    key_layer.transpose(-1, -2))
+
+    if attention_mask is not None:
+        # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
+        attention_scores = attention_scores + attention_mask
+
+    # Normalize the attention scores to probabilities.
+    attention_probs = torch.nn.functional.softmax(attention_scores, dim=-1)
+
+    # This is actually dropping out entire tokens to attend to, which might
+    # seem a bit unusual, but is taken from the original Transformer paper.
+    attention_probs = self.dropout(attention_probs)
+
+    # Mask heads if we want to
+    if head_mask is not None:
+        attention_probs = attention_probs * head_mask
+
+    context_layer = torch.matmul(attention_probs, value_layer)
+
+    context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+    new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
+    context_layer = context_layer.view(new_context_layer_shape)
+
+    outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
+
+    return outputs
+
+
+def encoder_forward(
+    self: torch.nn.Module,
+    hidden_states: torch.Tensor,
+    attention_mask: Optional[torch.FloatTensor] = None,
+    head_mask: Optional[torch.FloatTensor] = None,
+    encoder_hidden_states: Optional[torch.FloatTensor] = None,
+    encoder_attention_mask: Optional[torch.FloatTensor] = None,
+    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+    use_cache: Optional[bool] = None,
+    output_attentions: Optional[bool] = False,
+    output_hidden_states: Optional[bool] = False,
+    return_dict: Optional[bool] = True,
+):
+    if not attention_mask.any():
+        attention_mask = None
+    return BertEncoder.forward(
+        self=self,
+        hidden_states=hidden_states,
+        attention_mask=attention_mask,
+        head_mask=head_mask,
+        encoder_hidden_states=encoder_hidden_states,
+        encoder_attention_mask=encoder_attention_mask,
+        past_key_values=past_key_values,
+        use_cache=use_cache,
+        output_attentions=output_attentions,
+        output_hidden_states=output_hidden_states,
+        return_dict=return_dict,
+    )
diff --git a/python/llm/src/bigdl/llm/transformers/models/utils.py b/python/llm/src/bigdl/llm/transformers/models/utils.py
index bc79942e..77117bf5 100644
--- a/python/llm/src/bigdl/llm/transformers/models/utils.py
+++ b/python/llm/src/bigdl/llm/transformers/models/utils.py
@@ -353,10 +353,7 @@ def use_fused_layer_norm(x: torch.Tensor, training: bool):
         not training
         and not x.requires_grad
         and device in ["arc", "flex", "pvc", "mtl"]  # fused layer norm cannot run on UHD
-        and (
-            device == "mtl"  # fused layer norm conflicts with XMX, so disable it when using XMX
-            or x.numel() // x.size(-1) == 1
-        )
+        and x.numel() // x.size(-1) == 1  # fused layer norm is slower in first token
     )
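
Note on the qkv merge in the patch above: the sketch below is a minimal, standalone illustration in plain PyTorch (toy sizes chosen for the example, not part of the patch) of why concatenating the query/key/value weights along dim 0 and chunking the fused output along the last dimension reproduces the three separate projections, which is what merge_qkv and self_attention_forward rely on.

    # Minimal sketch with assumed toy sizes: verify that a fused QKV linear built
    # the same way merge_qkv builds it matches separate Q/K/V linears.
    import torch

    hidden_size = 16
    q = torch.nn.Linear(hidden_size, hidden_size)
    k = torch.nn.Linear(hidden_size, hidden_size)
    v = torch.nn.Linear(hidden_size, hidden_size)

    # Same construction as merge_qkv: weights concatenated on dim 0 (the
    # out_features dimension of nn.Linear), biases concatenated on the last dim.
    qkv = torch.nn.Linear(hidden_size, hidden_size * 3)
    qkv.weight = torch.nn.Parameter(
        torch.cat([q.weight.data, k.weight.data, v.weight.data], dim=0),
        requires_grad=False)
    qkv.bias = torch.nn.Parameter(
        torch.cat([q.bias.data, k.bias.data, v.bias.data], dim=-1),
        requires_grad=False)

    x = torch.randn(2, 5, hidden_size)  # (batch, seq_len, hidden_size)
    q_out, k_out, v_out = torch.chunk(qkv(x), 3, dim=-1)

    # One fused projection agrees with the three separate projections.
    assert torch.allclose(q_out, q(x), atol=1e-6)
    assert torch.allclose(k_out, k(x), atol=1e-6)
    assert torch.allclose(v_out, v(x), atol=1e-6)

The point of the merge is to replace three separate projections per attention layer with a single larger one, mirroring the existing merge_qk_proj path visible in the _optimize_pre context of the first hunk.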