add more gemma2 optimization (#11673)
This commit is contained in:
parent
3e8819734b
commit
7f88ce23cd
4 changed files with 58 additions and 10 deletions
|
|
@ -736,6 +736,9 @@ def _optimize_pre(model, qtype=None):
|
|||
if model.config.model_type == "internlmxcomposer2":
|
||||
from ipex_llm.transformers.models.internlm import pre_process_attn_and_mlp
|
||||
model.apply(pre_process_attn_and_mlp)
|
||||
if model.config.model_type == "gemma2":
|
||||
from ipex_llm.transformers.models.gemma2 import merge_qkv
|
||||
model.apply(merge_qkv)
|
||||
|
||||
return model
|
||||
|
||||
|
|
|
|||
43
python/llm/src/ipex_llm/transformers/models/common.py
Normal file
43
python/llm/src/ipex_llm/transformers/models/common.py
Normal file
|
|
@ -0,0 +1,43 @@
|
|||
#
|
||||
# Copyright 2016 The BigDL Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import torch
|
||||
from typing import List
|
||||
|
||||
|
||||
def merge_linear(linears: List[torch.nn.Linear]) -> torch.nn.Linear:
|
||||
new_weight = torch.cat(list(linear.weight.data for linear in linears), dim=0)
|
||||
if linears[0].bias is not None:
|
||||
new_linear = torch.nn.Linear(0, 0, bias=True)
|
||||
new_bias = torch.cat(list(linear.bias.data for linear in linears), dim=0)
|
||||
new_linear.bias = torch.nn.Parameter(new_bias, requires_grad=False)
|
||||
else:
|
||||
new_linear = torch.nn.Linear(0, 0, bias=False)
|
||||
new_linear.weight = torch.nn.Parameter(new_weight, requires_grad=False)
|
||||
new_linear.in_features = new_weight.size(1)
|
||||
new_linear.out_features = new_weight.size(0)
|
||||
return new_linear
|
||||
|
||||
|
||||
def merge_qkv_base(module: torch.nn.Module, attention_class):
|
||||
if isinstance(module, attention_class):
|
||||
qkv_proj = merge_linear([
|
||||
module.q_proj,
|
||||
module.k_proj,
|
||||
module.v_proj,
|
||||
])
|
||||
module.qkv_proj = qkv_proj
|
||||
del module.q_proj, module.k_proj, module.v_proj
|
||||
|
|
@ -35,11 +35,17 @@ from typing import Optional, Tuple
|
|||
|
||||
import torch
|
||||
from ipex_llm.utils.common import invalidInputError
|
||||
from ipex_llm.transformers.models.common import merge_qkv_base
|
||||
from ipex_llm.transformers.models.utils import should_use_fuse_rope
|
||||
from transformers.cache_utils import Cache
|
||||
from transformers.models.gemma2.modeling_gemma2 import Gemma2Attention
|
||||
from transformers.models.gemma2.modeling_gemma2 import repeat_kv, apply_rotary_pos_emb
|
||||
|
||||
|
||||
def merge_qkv(module: torch.nn.Module):
|
||||
return merge_qkv_base(module, Gemma2Attention)
|
||||
|
||||
|
||||
def gemma2_attention_forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
|
|
@ -52,16 +58,12 @@ def gemma2_attention_forward(
|
|||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
query_states = self.q_proj(hidden_states)
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
|
||||
query_states = query_states.view(bsz, q_len, self.num_heads,
|
||||
self.head_dim).transpose(1, 2)
|
||||
key_states = key_states.view(bsz, q_len, self.num_key_value_heads,
|
||||
self.head_dim).transpose(1, 2)
|
||||
value_states = value_states.view(bsz, q_len, self.num_key_value_heads,
|
||||
self.head_dim).transpose(1, 2)
|
||||
qkv = self.qkv_proj(hidden_states)
|
||||
qkv = qkv.view(bsz, q_len, self.num_heads + 2 * self.num_key_value_heads, self.head_dim)
|
||||
qkv = qkv.transpose(1, 2)
|
||||
query_states, key_states, value_states = qkv.split([self.num_heads,
|
||||
self.num_key_value_heads,
|
||||
self.num_key_value_heads], dim=1)
|
||||
|
||||
# IPEX-LLM OPT: fuse rope
|
||||
if should_use_fuse_rope(hidden_states, position_ids, self.training):
|
||||
|
|
|
|||
Loading…
Reference in a new issue