optimize bge large performance (#10324)

This commit is contained in:
Yishuo Wang 2024-03-05 17:06:03 +08:00 committed by GitHub
parent 178eea5009
commit 0011ff9f64
3 changed files with 170 additions and 4 deletions

View file

@ -580,6 +580,13 @@ def _optimize_pre(model):
del module.q_proj
del module.k_proj
model.apply(merge_qk_proj_func)
# for bge-large
if model.config.model_type == 'bert' and (
not model.config.is_decoder and
model.config.position_embedding_type == "absolute"
):
from bigdl.llm.transformers.models.bert import merge_qkv
model.apply(merge_qkv)
return model
@ -1228,4 +1235,19 @@ def _optimize_post(model, lightweight_bmm=False):
# module.YuanMLP,
# yuan_mlp_forward
# )
elif model.config.model_type == 'bert' and (
not model.config.is_decoder and
model.config.position_embedding_type == "absolute"
):
modeling_module_name = model.__class__.__module__
module = importlib.import_module(modeling_module_name)
from bigdl.llm.transformers.models.bert import self_attention_forward
from bigdl.llm.transformers.models.bert import encoder_forward
convert_forward(model,
module.BertSelfAttention,
self_attention_forward)
convert_forward(model,
module.BertEncoder,
encoder_forward)
return model

View file

@ -0,0 +1,147 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Some parts of this file is adapted from
# https://github.com/huggingface/transformers/blob/v4.38.0/src/transformers/models/bert/modeling_bert.py
# which is licensed under Apache License 2.0:
#
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import torch
from typing import Optional, Tuple
from transformers.models.bert.modeling_bert import BertSelfAttention, BertEncoder
from bigdl.llm.utils.common import invalidInputError
def merge_qkv(module: torch.nn.Module):
if isinstance(module, BertSelfAttention):
q_w = module.query.weight.data
k_w = module.key.weight.data
v_w = module.value.weight.data
q_b = module.query.bias.data
k_b = module.key.bias.data
v_b = module.value.bias.data
new_w = torch.cat([q_w, k_w, v_w], dim=0)
new_b = torch.cat([q_b, k_b, v_b], dim=-1)
qkv = torch.nn.Linear(0, 0, bias=True)
qkv.weight = torch.nn.Parameter(new_w, requires_grad=False)
qkv.bias = torch.nn.Parameter(new_b, requires_grad=False)
qkv.in_features = module.query.in_features
qkv.out_features = module.query.out_features * 3
module.qkv = qkv
del module.query
del module.key
del module.value
def self_attention_forward(
self: torch.nn.Module,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
output_attentions: Optional[bool] = False,
):
invalidInputError(encoder_hidden_states is None,
"cross attention is not supported")
invalidInputError(not self.is_decoder,
"bert decoder is not supported")
invalidInputError(self.position_embedding_type == "absolute",
"relative query/key is not supported")
qkv_output = self.qkv(hidden_states)
(query_layer, key_layer, value_layer) = torch.chunk(qkv_output, 3, -1)
query_layer = self.transpose_for_scores(query_layer)
key_layer = self.transpose_for_scores(key_layer)
value_layer = self.transpose_for_scores(value_layer)
if past_key_value is not None:
key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
# Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = torch.matmul(query_layer / math.sqrt(self.attention_head_size),
key_layer.transpose(-1, -2))
if attention_mask is not None:
# Apply the attention mask is (precomputed for all layers in BertModel forward() function)
attention_scores = attention_scores + attention_mask
# Normalize the attention scores to probabilities.
attention_probs = torch.nn.functional.softmax(attention_scores, dim=-1)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
attention_probs = self.dropout(attention_probs)
# Mask heads if we want to
if head_mask is not None:
attention_probs = attention_probs * head_mask
context_layer = torch.matmul(attention_probs, value_layer)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(new_context_layer_shape)
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
return outputs
def encoder_forward(
self: torch.nn.Module,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = False,
output_hidden_states: Optional[bool] = False,
return_dict: Optional[bool] = True,
):
if not attention_mask.any():
attention_mask = None
return BertEncoder.forward(
self=self,
hidden_states=hidden_states,
attention_mask=attention_mask,
head_mask=head_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)

View file

@ -353,10 +353,7 @@ def use_fused_layer_norm(x: torch.Tensor, training: bool):
not training
and not x.requires_grad
and device in ["arc", "flex", "pvc", "mtl"] # fused layer norm cannot run on UHD
and (
device == "mtl" # fused layer norm conflicts with XMX, so disable it when using XMX
or x.numel() // x.size(-1) == 1
)
and x.numel() // x.size(-1) == 1 # fused layer norm is slower in first token
)