From 6f999e6e9021ee2c1205121d2b6653d0cd7fea2d Mon Sep 17 00:00:00 2001
From: Yishuo Wang
Date: Mon, 29 Jul 2024 15:15:47 +0800
Subject: [PATCH] add sdp for gemma2 (#11677)

---
 .../llm/src/ipex_llm/transformers/convert.py  |   3 +
 .../ipex_llm/transformers/models/gemma2.py    | 103 ++++++++++++++----
 .../src/ipex_llm/transformers/models/utils.py |   4 +-
 3 files changed, 87 insertions(+), 23 deletions(-)

diff --git a/python/llm/src/ipex_llm/transformers/convert.py b/python/llm/src/ipex_llm/transformers/convert.py
index e7d63f38..c12a78bd 100644
--- a/python/llm/src/ipex_llm/transformers/convert.py
+++ b/python/llm/src/ipex_llm/transformers/convert.py
@@ -1512,9 +1512,12 @@ def _optimize_post(model, lightweight_bmm=False):
         module = importlib.import_module(modeling_module_name)
         from ipex_llm.transformers.models.gemma import gemma_rms_norm_forward
         from ipex_llm.transformers.models.gemma2 import gemma2_attention_forward
+        from ipex_llm.transformers.models.gemma2 import gemma2_model_forward
         from transformers.models.gemma2.modeling_gemma2 import Gemma2RMSNorm, Gemma2Attention
+        from transformers.models.gemma2.modeling_gemma2 import Gemma2Model
         convert_forward(model, Gemma2RMSNorm, gemma_rms_norm_forward)
         convert_forward(model, Gemma2Attention, gemma2_attention_forward)
+        convert_forward(model, Gemma2Model, gemma2_model_forward)
     elif model.config.model_type == "Yi":
         modeling_module_name = model.__class__.__module__
         module = importlib.import_module(modeling_module_name)
diff --git a/python/llm/src/ipex_llm/transformers/models/gemma2.py b/python/llm/src/ipex_llm/transformers/models/gemma2.py
index 719c758f..d6c3af52 100644
--- a/python/llm/src/ipex_llm/transformers/models/gemma2.py
+++ b/python/llm/src/ipex_llm/transformers/models/gemma2.py
@@ -31,14 +31,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Optional, Tuple
-
 import torch
-from ipex_llm.utils.common import invalidInputError
+
+from typing import Optional, Tuple
 from ipex_llm.transformers.models.common import merge_qkv_base
-from ipex_llm.transformers.models.utils import should_use_fuse_rope
+from ipex_llm.transformers.models.utils import should_use_fuse_rope, use_sdp, use_sdp_causal
 from transformers.cache_utils import Cache
-from transformers.models.gemma2.modeling_gemma2 import Gemma2Attention
+from transformers.models.gemma2.modeling_gemma2 import Gemma2Model, Gemma2Attention
 from transformers.models.gemma2.modeling_gemma2 import repeat_kv, apply_rotary_pos_emb
 
 
@@ -46,6 +45,46 @@ def merge_qkv(module: torch.nn.Module):
     return merge_qkv_base(module, Gemma2Attention)
 
 
+def gemma2_model_forward(
+    self,
+    input_ids: torch.LongTensor = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_values: Optional[Cache] = None,
+    inputs_embeds: Optional[torch.FloatTensor] = None,
+    use_cache: Optional[bool] = None,
+    output_attentions: Optional[bool] = None,
+    output_hidden_states: Optional[bool] = None,
+    return_dict: Optional[bool] = None,
+    cache_position: Optional[torch.LongTensor] = None,
+):
+    # ipex-llm change start: add kv_seq_len in past_key_values
+    if past_key_values is not None:
+        if cache_position is not None:
+            kv_seq_len = cache_position[-1].item() + 1
+        else:
+            if input_ids is not None:
+                kv_seq_len = input_ids.size(1)
+            else:
+                kv_seq_len = inputs_embeds.size(1)
+        past_key_values.kv_seq_len = kv_seq_len
+    # ipex-llm change end
+
+    return Gemma2Model.forward(
+        self=self,
+        input_ids=input_ids,
+        attention_mask=attention_mask,
+        position_ids=position_ids,
+        past_key_values=past_key_values,
+        inputs_embeds=inputs_embeds,
+        use_cache=use_cache,
+        output_attentions=output_attentions,
+        output_hidden_states=output_hidden_states,
+        return_dict=return_dict,
+        cache_position=cache_position
+    )
+
+
 def gemma2_attention_forward(
     self,
     hidden_states: torch.Tensor,
@@ -86,26 +125,48 @@ def gemma2_attention_forward(
         key_states, value_states = past_key_value.update(key_states, value_states,
                                                          self.layer_idx, cache_kwargs)
 
-    key_states = repeat_kv(key_states, self.num_key_value_groups)
-    value_states = repeat_kv(value_states, self.num_key_value_groups)
+    # IPEX_LLM OPT: sdp
+    kv_seq_len = q_len if past_key_value is None else past_key_value.kv_seq_len
+    if (use_sdp_causal(q_len, kv_seq_len, -1, query_states, self.training)
+            and kv_seq_len <= key_states.size(2)):
+        import xe_addons
+        attn_weights = None
+        attn_output = xe_addons.gemma2_sdp_causal(query_states,
+                                                  key_states[:, :, :kv_seq_len, :],
+                                                  value_states[:, :, :kv_seq_len, :],
+                                                  attention_mask[:, :, :q_len, :kv_seq_len],
+                                                  self.config.attn_logit_softcapping,
+                                                  self.scaling)
+    elif use_sdp(q_len, kv_seq_len, -1, query_states):
+        import xe_addons
+        attn_weights = None
+        attn_output = xe_addons.gemma2_sdp(query_states,
+                                           key_states[:, :, :kv_seq_len, :],
+                                           value_states[:, :, :kv_seq_len, :],
+                                           attention_mask[:, :, :q_len, :kv_seq_len],
+                                           self.config.attn_logit_softcapping,
+                                           self.scaling)
+    else:
+        key_states = repeat_kv(key_states, self.num_key_value_groups)
+        value_states = repeat_kv(value_states, self.num_key_value_groups)
 
-    attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scaling
+        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scaling
 
-    if self.config.attn_logit_softcapping is not None:
-        attn_weights = attn_weights / self.config.attn_logit_softcapping
-        attn_weights = torch.tanh(attn_weights)
-        attn_weights = attn_weights * self.config.attn_logit_softcapping
+        if self.config.attn_logit_softcapping is not None:
+            attn_weights = attn_weights / self.config.attn_logit_softcapping
+            attn_weights = torch.tanh(attn_weights)
+            attn_weights = attn_weights * self.config.attn_logit_softcapping
 
-    if attention_mask is not None:  # no matter the length, we just slice it
-        causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
-        attn_weights = attn_weights + causal_mask
+        if attention_mask is not None:  # no matter the length, we just slice it
+            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+            attn_weights = attn_weights + causal_mask
 
-    # upcast attention to fp32
-    attn_weights = torch.nn.functional.softmax(attn_weights,
-                                               dim=-1, dtype=torch.float32).to(query_states.dtype)
-    attn_weights = torch.nn.functional.dropout(attn_weights,
-                                               p=self.attention_dropout, training=self.training)
-    attn_output = torch.matmul(attn_weights, value_states)
+        # upcast attention to fp32
+        attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1,
+                                                   dtype=torch.float32).to(query_states.dtype)
+        attn_weights = torch.nn.functional.dropout(attn_weights, p=self.attention_dropout,
+                                                   training=self.training)
+        attn_output = torch.matmul(attn_weights, value_states)
 
     attn_output = attn_output.transpose(1, 2).contiguous()
 
diff --git a/python/llm/src/ipex_llm/transformers/models/utils.py b/python/llm/src/ipex_llm/transformers/models/utils.py
index 63c71f50..c4626bc9 100644
--- a/python/llm/src/ipex_llm/transformers/models/utils.py
+++ b/python/llm/src/ipex_llm/transformers/models/utils.py
@@ -329,7 +329,7 @@ def use_sdp(q_len, kv_len, head_dim, query_states):
     return (
         query_states.device.type == "xpu"
         and query_states.dtype in [torch.float, torch.half]  # fp32/fp16
-        and head_dim in [64, 80, 96, 128]
+        and head_dim in [-1, 64, 80, 96, 128]
        and q_len != kv_len  # next token
         and q_len <= 32  # lookup
     )
@@ -347,7 +347,7 @@ def use_sdp_causal(q_len, kv_len, head_dim, query_states, training):
     return (
         q_len == kv_len  # first token
-        and head_dim in [64, 80, 96, 128]  # for now
+        and head_dim in [-1, 64, 80, 96, 128]  # for now
         and query_states.device.type == "xpu"  # GPU
         and query_states.dtype in [torch.float, torch.half]  # fp32/fp16
         and not query_states.requires_grad and not training  # not training
     )
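Note on how the patch is exercised (illustrative, not part of the diff): the hunks above only change ipex-llm internals, and they take effect through the library's normal loading path, where from_pretrained() with optimization enabled ends up in the _optimize_post()/convert_forward() code shown in convert.py. A minimal sketch, assuming an Intel XPU device and a Gemma-2 checkpoint chosen purely for illustration:

import torch
from transformers import AutoTokenizer
from ipex_llm.transformers import AutoModelForCausalLM  # ipex-llm's drop-in transformers API

# Loading with optimize_model=True is what triggers convert_forward(...) for
# Gemma2Model / Gemma2Attention, installing gemma2_model_forward and
# gemma2_attention_forward from the patch above.
model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2-9b-it",   # illustrative checkpoint, not taken from the patch
    load_in_4bit=True,        # any ipex-llm low-bit loading mode should work here
    optimize_model=True,
    use_cache=True,
)
model = model.half().to("xpu")   # use_sdp/use_sdp_causal require fp16/fp32 tensors on an XPU device
# (depending on the ipex-llm version, importing intel_extension_for_pytorch may be
# needed before the "xpu" device is available)

tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b-it")
inputs = tokenizer("What does scaled dot-product attention compute?",
                   return_tensors="pt").to("xpu")
with torch.inference_mode():
    output = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(output[0], skip_special_tokens=True))

With that setup, prefill (q_len == kv_len) should take the xe_addons.gemma2_sdp_causal path and short decode/lookup steps (q_len != kv_len, q_len <= 32) the xe_addons.gemma2_sdp path, with the eager matmul/softmax branch as the fallback whenever the use_sdp* checks fail. The kv_seq_len that gemma2_model_forward stashes on past_key_values gives the attention forward the true sequence length even when an individual cache layer holds fewer entries (Gemma2 interleaves global and sliding-window attention layers), which appears to be what the kv_seq_len <= key_states.size(2) guard accounts for.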