diff --git a/python/llm/dev/test/pep8-report.txt b/python/llm/dev/test/pep8-report.txt
deleted file mode 100644
index e69de29b..00000000
diff --git a/python/llm/src/ipex_llm/transformers/convert.py b/python/llm/src/ipex_llm/transformers/convert.py
index 91e9d2f4..e7d63f38 100644
--- a/python/llm/src/ipex_llm/transformers/convert.py
+++ b/python/llm/src/ipex_llm/transformers/convert.py
@@ -736,6 +736,9 @@ def _optimize_pre(model, qtype=None):
     if model.config.model_type == "internlmxcomposer2":
         from ipex_llm.transformers.models.internlm import pre_process_attn_and_mlp
         model.apply(pre_process_attn_and_mlp)
+    if model.config.model_type == "gemma2":
+        from ipex_llm.transformers.models.gemma2 import merge_qkv
+        model.apply(merge_qkv)
 
     return model
 
diff --git a/python/llm/src/ipex_llm/transformers/models/common.py b/python/llm/src/ipex_llm/transformers/models/common.py
new file mode 100644
index 00000000..d32e2ce4
--- /dev/null
+++ b/python/llm/src/ipex_llm/transformers/models/common.py
@@ -0,0 +1,43 @@
+#
+# Copyright 2016 The BigDL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import torch
+from typing import List
+
+
+def merge_linear(linears: List[torch.nn.Linear]) -> torch.nn.Linear:
+    new_weight = torch.cat(list(linear.weight.data for linear in linears), dim=0)
+    if linears[0].bias is not None:
+        new_linear = torch.nn.Linear(0, 0, bias=True)
+        new_bias = torch.cat(list(linear.bias.data for linear in linears), dim=0)
+        new_linear.bias = torch.nn.Parameter(new_bias, requires_grad=False)
+    else:
+        new_linear = torch.nn.Linear(0, 0, bias=False)
+    new_linear.weight = torch.nn.Parameter(new_weight, requires_grad=False)
+    new_linear.in_features = new_weight.size(1)
+    new_linear.out_features = new_weight.size(0)
+    return new_linear
+
+
+def merge_qkv_base(module: torch.nn.Module, attention_class):
+    if isinstance(module, attention_class):
+        qkv_proj = merge_linear([
+            module.q_proj,
+            module.k_proj,
+            module.v_proj,
+        ])
+        module.qkv_proj = qkv_proj
+        del module.q_proj, module.k_proj, module.v_proj
diff --git a/python/llm/src/ipex_llm/transformers/models/gemma2.py b/python/llm/src/ipex_llm/transformers/models/gemma2.py
index 27eea21f..719c758f 100644
--- a/python/llm/src/ipex_llm/transformers/models/gemma2.py
+++ b/python/llm/src/ipex_llm/transformers/models/gemma2.py
@@ -35,11 +35,17 @@ from typing import Optional, Tuple
 import torch
 
 from ipex_llm.utils.common import invalidInputError
+from ipex_llm.transformers.models.common import merge_qkv_base
 from ipex_llm.transformers.models.utils import should_use_fuse_rope
 from transformers.cache_utils import Cache
+from transformers.models.gemma2.modeling_gemma2 import Gemma2Attention
 from transformers.models.gemma2.modeling_gemma2 import repeat_kv, apply_rotary_pos_emb
 
 
+def merge_qkv(module: torch.nn.Module):
+    return merge_qkv_base(module, Gemma2Attention)
+
+
 def gemma2_attention_forward(
     self,
     hidden_states: torch.Tensor,
@@ -52,16 +58,12 @@ def gemma2_attention_forward(
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
     bsz, q_len, _ = hidden_states.size()
 
-    query_states = self.q_proj(hidden_states)
-    key_states = self.k_proj(hidden_states)
-    value_states = self.v_proj(hidden_states)
-
-    query_states = query_states.view(bsz, q_len, self.num_heads,
-                                     self.head_dim).transpose(1, 2)
-    key_states = key_states.view(bsz, q_len, self.num_key_value_heads,
-                                 self.head_dim).transpose(1, 2)
-    value_states = value_states.view(bsz, q_len, self.num_key_value_heads,
-                                     self.head_dim).transpose(1, 2)
+    qkv = self.qkv_proj(hidden_states)
+    qkv = qkv.view(bsz, q_len, self.num_heads + 2 * self.num_key_value_heads, self.head_dim)
+    qkv = qkv.transpose(1, 2)
+    query_states, key_states, value_states = qkv.split([self.num_heads,
+                                                        self.num_key_value_heads,
+                                                        self.num_key_value_heads], dim=1)
 
     # IPEX-LLM OPT: fuse rope
     if should_use_fuse_rope(hidden_states, position_ids, self.training):
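Not part of the diff: below is a minimal standalone sketch of the merge-then-split idea the PR applies to Gemma-2 attention. The toy dimensions and variable names are made up for illustration; the real code fuses the projections of Gemma2Attention via merge_qkv_base/merge_linear in common.py above, and splits the fused output in gemma2_attention_forward.

import torch

# Toy shapes: hidden size 8, 2 query heads, 1 key/value head, head_dim 4
# (Gemma-2 uses grouped-query attention, so k/v have fewer heads than q).
hidden, num_heads, num_kv_heads, head_dim = 8, 2, 1, 4
q_proj = torch.nn.Linear(hidden, num_heads * head_dim, bias=False)
k_proj = torch.nn.Linear(hidden, num_kv_heads * head_dim, bias=False)
v_proj = torch.nn.Linear(hidden, num_kv_heads * head_dim, bias=False)

# Fuse the three projections by stacking their weights along the output
# dimension, as merge_linear does in common.py above.
qkv_proj = torch.nn.Linear(hidden, (num_heads + 2 * num_kv_heads) * head_dim, bias=False)
qkv_proj.weight = torch.nn.Parameter(
    torch.cat([q_proj.weight.data, k_proj.weight.data, v_proj.weight.data], dim=0),
    requires_grad=False,
)

x = torch.randn(1, 3, hidden)  # (batch, seq_len, hidden)

# One matmul instead of three, then reshape and split back into q / k / v,
# mirroring the new code path in gemma2_attention_forward.
qkv = qkv_proj(x).view(1, 3, num_heads + 2 * num_kv_heads, head_dim).transpose(1, 2)
q, k, v = qkv.split([num_heads, num_kv_heads, num_kv_heads], dim=1)

# The fused projection reproduces the separate q/k/v projections.
assert torch.allclose(q, q_proj(x).view(1, 3, num_heads, head_dim).transpose(1, 2))
assert torch.allclose(k, k_proj(x).view(1, 3, num_kv_heads, head_dim).transpose(1, 2))
assert torch.allclose(v, v_proj(x).view(1, 3, num_kv_heads, head_dim).transpose(1, 2))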