[LLM] Support MLP optimization for Qwen1.5 (#10123)

This commit is contained in:
Yuwen Hu 2024-02-08 09:15:34 +08:00 committed by GitHub
parent 81ed65fbe7
commit a8450fc300
3 changed files with 187 additions and 2 deletions

View file

@ -889,6 +889,24 @@ def _optimize_post(model, lightweight_bmm=False):
convert_forward(model, convert_forward(model,
module.QWenMLP, module.QWenMLP,
qwen_mlp_forward) qwen_mlp_forward)
elif model.config.model_type == "qwen2":
# for Qwen1.5-7B
modeling_module_name = model.__class__.__module__
module = importlib.import_module(modeling_module_name)
from bigdl.llm.transformers.models.qwen2 import qwen2_attention_forward
# TODO: add these optimization back
# RMSNorm and rotray embedding are disabled for now
# as they lead to obvious performance drop for Qwen 1.5
# convert_forward(model,
# module.Qwen2Attention,
# qwen2_attention_forward
# )
# convert_forward(model,
# module.Qwen2RMSNorm,
# llama_rms_norm_forward)
convert_forward(model,
module.Qwen2MLP,
llama_mlp_forward)
elif model.config.model_type == "aquila": elif model.config.model_type == "aquila":
modeling_module_name = model.__class__.__module__ modeling_module_name = model.__class__.__module__
module = importlib.import_module(modeling_module_name) module = importlib.import_module(modeling_module_name)

View file

@ -0,0 +1,167 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Some parts of this file is adapted from
# https://github.com/huggingface/transformers/blob/v4.37.0/src/transformers/models/qwen2/modeling_qwen2.py
# which is licensed under Apache License 2.0:
#
# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import math
from typing import TYPE_CHECKING, Optional, Tuple, Union, Callable, List
import warnings
import torch
import torch.nn as nn
from bigdl.llm.transformers.models.llama import repeat_kv
from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb, \
apply_rotary_pos_emb_no_cache_xpu
from bigdl.llm.utils.common import invalidInputError
def should_use_fuse_rope(self, query_states, position_ids):
use_fuse_rope = query_states.device.type == "xpu"
use_fuse_rope = use_fuse_rope and not (self.training and query_states.requires_grad)
use_fuse_rope = use_fuse_rope and position_ids is not None
return use_fuse_rope
def qwen2_attention_forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
if "padding_mask" in kwargs:
warnings.warn(
"Passing `padding_mask` is deprecated and will be removed in v4.37. "
"Please make sure use `attention_mask` instead.`"
)
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = \
key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = \
value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
if self.layer_idx is None:
invalidInputError(
False,
"The cache structure has changed since version v4.36. "
f"If you are using {self.__class__.__name__} "
"for auto-regressive decoding with k/v caching, "
"please make sure to initialize the attention class with a layer index."
)
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
if use_fuse_rope:
query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
key_states,
position_ids,
"qwen2")
else:
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
cos, sin, position_ids, "qwen2")
if past_key_value is not None:
if use_fuse_rope:
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
key_states, value_states = past_key_value.update(key_states,
value_states,
self.layer_idx,
cache_kwargs)
# repeat k/v heads if n_kv_heads < n_heads
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
invalidInputError(
False,
f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, "
f"but is {attn_weights.size()}"
)
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
invalidInputError(
False,
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, "
f"but is {attention_mask.size()}"
)
attn_weights = attn_weights + attention_mask
# upcast attention to fp32
attn_weights = \
nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_weights = nn.functional.dropout(attn_weights,
p=self.attention_dropout,
training=self.training)
attn_output = torch.matmul(attn_weights, value_states)
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
invalidInputError(
False,
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
attn_output = self.o_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value

View file

@ -143,7 +143,7 @@ def rotate_every_two(x):
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, model_family): def apply_rotary_pos_emb(q, k, cos, sin, position_ids, model_family):
if model_family in ["llama", "baichuan", "internlm", "aquila", "gpt_neox", "mistral", if model_family in ["llama", "baichuan", "internlm", "aquila", "gpt_neox", "mistral",
"mixtral"]: "mixtral", "qwen2"]:
# The first two dimensions of cos and sin are always 1, so we can `squeeze` them. # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
sin = sin.squeeze(1).squeeze(0) # [seq_len, dim] sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
@ -171,7 +171,7 @@ def apply_rotary_pos_emb_no_cache_xpu(q, k, position_ids, model_family):
q_embed = torch.empty(q.shape, dtype=q.dtype, device=q.device) q_embed = torch.empty(q.shape, dtype=q.dtype, device=q.device)
k_embed = torch.empty(k.shape, dtype=k.dtype, device=k.device) k_embed = torch.empty(k.shape, dtype=k.dtype, device=k.device)
if model_family in ["llama", "baichuan", "internlm", "aquila", "gpt_neox", "mistral", if model_family in ["llama", "baichuan", "internlm", "aquila", "gpt_neox", "mistral",
"mixtral"]: "mixtral", "qwen2"]:
linear_q4_0.apply_rotary_embedding_half_q_and_k(q, k, position_ids, q_embed, k_embed) linear_q4_0.apply_rotary_embedding_half_q_and_k(q, k, position_ids, q_embed, k_embed)
return q_embed, k_embed return q_embed, k_embed
else: else: