add basic gemma2 optimization (#11672)
This commit is contained in:
parent
418640e466
commit
3e8819734b
3 changed files with 120 additions and 3 deletions
0
python/llm/dev/test/pep8-report.txt
Normal file
0
python/llm/dev/test/pep8-report.txt
Normal file
|
|
@ -1508,9 +1508,10 @@ def _optimize_post(model, lightweight_bmm=False):
|
|||
modeling_module_name = model.__class__.__module__
|
||||
module = importlib.import_module(modeling_module_name)
|
||||
from ipex_llm.transformers.models.gemma import gemma_rms_norm_forward
|
||||
convert_forward(model,
|
||||
module.GemmaRMSNorm,
|
||||
gemma_rms_norm_forward)
|
||||
from ipex_llm.transformers.models.gemma2 import gemma2_attention_forward
|
||||
from transformers.models.gemma2.modeling_gemma2 import Gemma2RMSNorm, Gemma2Attention
|
||||
convert_forward(model, Gemma2RMSNorm, gemma_rms_norm_forward)
|
||||
convert_forward(model, Gemma2Attention, gemma2_attention_forward)
|
||||
elif model.config.model_type == "Yi":
|
||||
modeling_module_name = model.__class__.__module__
|
||||
module = importlib.import_module(modeling_module_name)
|
||||
|
|
|
|||
116
python/llm/src/ipex_llm/transformers/models/gemma2.py
Normal file
116
python/llm/src/ipex_llm/transformers/models/gemma2.py
Normal file
|
|
@ -0,0 +1,116 @@
|
|||
#
|
||||
# Copyright 2016 The BigDL Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
# Some parts of this file is adapted from
|
||||
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/gemma2/modeling_gemma2.py
|
||||
# coding=utf-8
|
||||
# Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
from ipex_llm.utils.common import invalidInputError
|
||||
from ipex_llm.transformers.models.utils import should_use_fuse_rope
|
||||
from transformers.cache_utils import Cache
|
||||
from transformers.models.gemma2.modeling_gemma2 import repeat_kv, apply_rotary_pos_emb
|
||||
|
||||
|
||||
def gemma2_attention_forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Cache] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
query_states = self.q_proj(hidden_states)
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
|
||||
query_states = query_states.view(bsz, q_len, self.num_heads,
|
||||
self.head_dim).transpose(1, 2)
|
||||
key_states = key_states.view(bsz, q_len, self.num_key_value_heads,
|
||||
self.head_dim).transpose(1, 2)
|
||||
value_states = value_states.view(bsz, q_len, self.num_key_value_heads,
|
||||
self.head_dim).transpose(1, 2)
|
||||
|
||||
# IPEX-LLM OPT: fuse rope
|
||||
if should_use_fuse_rope(hidden_states, position_ids, self.training):
|
||||
import xe_addons
|
||||
xe_addons.rotary_half_inplaced(self.rotary_emb.inv_freq, position_ids,
|
||||
query_states, key_states)
|
||||
cos, sin = None, None
|
||||
else:
|
||||
cos, sin = self.rotary_emb(value_states, position_ids)
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
||||
|
||||
if past_key_value is not None:
|
||||
# sin and cos are specific to RoPE models; cache_position needed for the static cache
|
||||
cache_kwargs = {
|
||||
"sin": sin,
|
||||
"cos": cos,
|
||||
"sliding_window": self.sliding_window,
|
||||
"cache_position": cache_position,
|
||||
}
|
||||
key_states, value_states = past_key_value.update(key_states, value_states,
|
||||
self.layer_idx, cache_kwargs)
|
||||
|
||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||
|
||||
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scaling
|
||||
|
||||
if self.config.attn_logit_softcapping is not None:
|
||||
attn_weights = attn_weights / self.config.attn_logit_softcapping
|
||||
attn_weights = torch.tanh(attn_weights)
|
||||
attn_weights = attn_weights * self.config.attn_logit_softcapping
|
||||
|
||||
if attention_mask is not None: # no matter the length, we just slice it
|
||||
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
||||
attn_weights = attn_weights + causal_mask
|
||||
|
||||
# upcast attention to fp32
|
||||
attn_weights = torch.nn.functional.softmax(attn_weights,
|
||||
dim=-1, dtype=torch.float32).to(query_states.dtype)
|
||||
attn_weights = torch.nn.functional.dropout(attn_weights,
|
||||
p=self.attention_dropout, training=self.training)
|
||||
attn_output = torch.matmul(attn_weights, value_states)
|
||||
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
|
||||
attn_output = attn_output.view(bsz, q_len, -1)
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
if not output_attentions:
|
||||
attn_weights = None
|
||||
|
||||
return attn_output, attn_weights, past_key_value
|
||||
Loading…
Reference in a new issue