diff --git a/python/llm/dev/test/pep8-report.txt b/python/llm/dev/test/pep8-report.txt
new file mode 100644
index 00000000..e69de29b
diff --git a/python/llm/src/ipex_llm/transformers/convert.py b/python/llm/src/ipex_llm/transformers/convert.py
index 936fab56..91e9d2f4 100644
--- a/python/llm/src/ipex_llm/transformers/convert.py
+++ b/python/llm/src/ipex_llm/transformers/convert.py
@@ -1508,9 +1508,10 @@ def _optimize_post(model, lightweight_bmm=False):
         modeling_module_name = model.__class__.__module__
         module = importlib.import_module(modeling_module_name)
         from ipex_llm.transformers.models.gemma import gemma_rms_norm_forward
-        convert_forward(model,
-                        module.GemmaRMSNorm,
-                        gemma_rms_norm_forward)
+        from ipex_llm.transformers.models.gemma2 import gemma2_attention_forward
+        from transformers.models.gemma2.modeling_gemma2 import Gemma2RMSNorm, Gemma2Attention
+        convert_forward(model, Gemma2RMSNorm, gemma_rms_norm_forward)
+        convert_forward(model, Gemma2Attention, gemma2_attention_forward)
     elif model.config.model_type == "Yi":
         modeling_module_name = model.__class__.__module__
         module = importlib.import_module(modeling_module_name)
diff --git a/python/llm/src/ipex_llm/transformers/models/gemma2.py b/python/llm/src/ipex_llm/transformers/models/gemma2.py
new file mode 100644
index 00000000..27eea21f
--- /dev/null
+++ b/python/llm/src/ipex_llm/transformers/models/gemma2.py
@@ -0,0 +1,116 @@
+#
+# Copyright 2016 The BigDL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Some parts of this file is adapted from
+# https://github.com/huggingface/transformers/blob/main/src/transformers/models/gemma2/modeling_gemma2.py
+# coding=utf-8
+# Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved.
+#
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional, Tuple
+
+import torch
+from ipex_llm.utils.common import invalidInputError
+from ipex_llm.transformers.models.utils import should_use_fuse_rope
+from transformers.cache_utils import Cache
+from transformers.models.gemma2.modeling_gemma2 import repeat_kv, apply_rotary_pos_emb
+
+
+def gemma2_attention_forward(
+    self,
+    hidden_states: torch.Tensor,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_value: Optional[Cache] = None,
+    output_attentions: bool = False,
+    use_cache: bool = False,
+    cache_position: Optional[torch.LongTensor] = None,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+    bsz, q_len, _ = hidden_states.size()
+
+    query_states = self.q_proj(hidden_states)
+    key_states = self.k_proj(hidden_states)
+    value_states = self.v_proj(hidden_states)
+
+    query_states = query_states.view(bsz, q_len, self.num_heads,
+                                     self.head_dim).transpose(1, 2)
+    key_states = key_states.view(bsz, q_len, self.num_key_value_heads,
+                                 self.head_dim).transpose(1, 2)
+    value_states = value_states.view(bsz, q_len, self.num_key_value_heads,
+                                     self.head_dim).transpose(1, 2)
+
+    # IPEX-LLM OPT: fuse rope
+    if should_use_fuse_rope(hidden_states, position_ids, self.training):
+        import xe_addons
+        xe_addons.rotary_half_inplaced(self.rotary_emb.inv_freq, position_ids,
+                                       query_states, key_states)
+        cos, sin = None, None
+    else:
+        cos, sin = self.rotary_emb(value_states, position_ids)
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+    if past_key_value is not None:
+        # sin and cos are specific to RoPE models; cache_position needed for the static cache
+        cache_kwargs = {
+            "sin": sin,
+            "cos": cos,
+            "sliding_window": self.sliding_window,
+            "cache_position": cache_position,
+        }
+        key_states, value_states = past_key_value.update(key_states, value_states,
+                                                         self.layer_idx, cache_kwargs)
+
+    key_states = repeat_kv(key_states, self.num_key_value_groups)
+    value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+    attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scaling
+
+    if self.config.attn_logit_softcapping is not None:
+        attn_weights = attn_weights / self.config.attn_logit_softcapping
+        attn_weights = torch.tanh(attn_weights)
+        attn_weights = attn_weights * self.config.attn_logit_softcapping
+
+    if attention_mask is not None:  # no matter the length, we just slice it
+        causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+        attn_weights = attn_weights + causal_mask
+
+    # upcast attention to fp32
+    attn_weights = torch.nn.functional.softmax(attn_weights,
+                                               dim=-1, dtype=torch.float32).to(query_states.dtype)
+    attn_weights = torch.nn.functional.dropout(attn_weights,
+                                               p=self.attention_dropout, training=self.training)
+    attn_output = torch.matmul(attn_weights, value_states)
+
+    attn_output = attn_output.transpose(1, 2).contiguous()
+
+    attn_output = attn_output.view(bsz, q_len, -1)
+    attn_output = self.o_proj(attn_output)
+
+    if not output_attentions:
+        attn_weights = None
+
+    return attn_output, attn_weights, past_key_value
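
For readers skimming the diff: the `attn_logit_softcapping` branch above implements Gemma 2's tanh soft-capping, which smoothly bounds the raw attention logits to (-cap, +cap) via cap * tanh(scores / cap) before the mask and softmax are applied. A minimal standalone sketch of the same transform; the cap value of 50.0 matches Gemma 2's published default, but in the patched forward it is always read from `self.config.attn_logit_softcapping`:

```python
import torch

def softcap(scores: torch.Tensor, cap: float = 50.0) -> torch.Tensor:
    # Same divide -> tanh -> multiply sequence as in gemma2_attention_forward;
    # extreme logits saturate near +/-cap, small ones pass through almost unchanged.
    return cap * torch.tanh(scores / cap)

scores = torch.tensor([-200.0, -5.0, 0.0, 5.0, 200.0])
print(softcap(scores))  # tensor([-49.9983, -4.9958, 0.0000, 4.9958, 49.9983])
```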
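A hedged end-to-end sketch of how the new path gets exercised: loading a Gemma 2 checkpoint through ipex-llm's transformers wrapper runs `_optimize_post`, whose `convert_forward` calls (patched above) swap `gemma_rms_norm_forward` and `gemma2_attention_forward` into the model. The checkpoint id and device below are placeholders, not part of this patch:

```python
from transformers import AutoTokenizer
from ipex_llm.transformers import AutoModelForCausalLM

model_path = "google/gemma-2-9b-it"  # placeholder checkpoint id

# optimize_model=True applies _optimize_post and hence the Gemma2 conversions
model = AutoModelForCausalLM.from_pretrained(model_path,
                                             load_in_4bit=True,
                                             optimize_model=True,
                                             use_cache=True)
model = model.to("xpu")  # Intel GPU; the fused-RoPE branch relies on xe_addons

tokenizer = AutoTokenizer.from_pretrained(model_path)
inputs = tokenizer("What is AI?", return_tensors="pt").to("xpu")
output = model.generate(inputs.input_ids, max_new_tokens=32)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```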