From f3a2b20e6b43938bad9ad19a13e4f87114c484c3 Mon Sep 17 00:00:00 2001
From: Yishuo Wang
Date: Thu, 24 Oct 2024 13:44:24 +0800
Subject: [PATCH] Optimize gpt2 (#12259)

---
 .../llm/src/ipex_llm/transformers/convert.py |  5 ++
 .../src/ipex_llm/transformers/models/gpt2.py | 93 +++++++++++++++++++
 2 files changed, 98 insertions(+)
 create mode 100644 python/llm/src/ipex_llm/transformers/models/gpt2.py

diff --git a/python/llm/src/ipex_llm/transformers/convert.py b/python/llm/src/ipex_llm/transformers/convert.py
index fce3b7f7..00b5e001 100644
--- a/python/llm/src/ipex_llm/transformers/convert.py
+++ b/python/llm/src/ipex_llm/transformers/convert.py
@@ -1549,6 +1549,11 @@ def _optimize_post(model, lightweight_bmm=False):
                              module.BaichuanModel,
                              "get_alibi_mask",
                              baichuan_13b_get_alibi_mask)
+    elif model.config.model_type == "gpt2":
+        from ipex_llm.transformers.models.gpt2 import gpt2_attention_attn
+        modeling_module_name = model.__class__.__module__
+        module = importlib.import_module(modeling_module_name)
+        module.GPT2Attention._attn = gpt2_attention_attn
     elif model.config.model_type == "gpt_neox":
         from ipex_llm.transformers.models.gptneox import gptneox_attention_forward
         convert_forward(model,
diff --git a/python/llm/src/ipex_llm/transformers/models/gpt2.py b/python/llm/src/ipex_llm/transformers/models/gpt2.py
new file mode 100644
index 00000000..1a968bd0
--- /dev/null
+++ b/python/llm/src/ipex_llm/transformers/models/gpt2.py
@@ -0,0 +1,93 @@
+#
+# Copyright 2016 The BigDL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import torch
+from ipex_llm.transformers.models.utils import use_sdp_non_causal
+
+
+def gpt2_attention_attn(
+    self,
+    query,
+    key,
+    value,
+    attention_mask=None,
+    head_mask=None
+):
+    # ipex-llm changes start
+    if (
+        self.scale_attn_weights
+        and not self.scale_attn_by_inverse_layer_idx
+        and head_mask is None
+        and query.size(-2) == key.size(-2)
+        and use_sdp_non_causal(query.size(-1), query.device, query.dtype)
+    ):
+        seq_len = query.size(-2)
+        if not self.is_cross_attention:
+            causal_mask = self.bias[:, :, :seq_len, :seq_len]
+            mask_value = torch.finfo(query.dtype).min
+            mask_value = torch.full([], mask_value, dtype=query.dtype, device=query.device)
+            if attention_mask is None:
+                attention_mask = torch.where(causal_mask, torch.zeros_like(mask_value), mask_value)
+            else:
+                attention_mask = attention_mask.expand(-1, -1, seq_len, seq_len)
+                attention_mask = torch.where(causal_mask, attention_mask, mask_value)
+        elif attention_mask is not None:
+            attention_mask = attention_mask.expand(-1, -1, seq_len, seq_len)
+
+        import xe_addons
+        attn_weights = None
+        attn_output = xe_addons.sdp_non_causal(query, key.contiguous(),
+                                               value.contiguous(), attention_mask)
+        return attn_output, attn_weights
+    # ipex-llm changes end
+
+    attn_weights = torch.matmul(query, key.transpose(-1, -2))
+
+    if self.scale_attn_weights:
+        attn_weights = attn_weights / torch.full(
+            [], value.size(-1) ** 0.5, dtype=attn_weights.dtype, device=attn_weights.device
+        )
+
+    # Layer-wise attention scaling
+    if self.scale_attn_by_inverse_layer_idx:
+        attn_weights = attn_weights / float(self.layer_idx + 1)
+
+    if not self.is_cross_attention:
+        # if only "normal" attention layer implements causal mask
+        query_length, key_length = query.size(-2), key.size(-2)
+        causal_mask = self.bias[:, :, key_length - query_length:key_length, :key_length]
+        mask_value = torch.finfo(attn_weights.dtype).min
+        mask_value = torch.full([], mask_value, dtype=attn_weights.dtype,
+                                device=attn_weights.device)
+        attn_weights = torch.where(causal_mask, attn_weights.to(attn_weights.dtype), mask_value)
+
+    if attention_mask is not None:
+        # Apply the attention mask
+        attn_weights = attn_weights + attention_mask
+
+    attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
+
+    # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op otherwise
+    attn_weights = attn_weights.type(value.dtype)
+    attn_weights = self.attn_dropout(attn_weights)
+
+    # Mask heads if we want to
+    if head_mask is not None:
+        attn_weights = attn_weights * head_mask
+
+    attn_output = torch.matmul(attn_weights, value)
+
+    return attn_output, attn_weights
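
Note (not part of the patch): the convert.py hunk swaps GPT2Attention._attn for the new
gpt2_attention_attn at model-optimization time. The sketch below redoes that monkey-patch by
hand, assuming ipex-llm is installed and the installed transformers version still implements
GPT2Attention._attn; on a CPU device use_sdp_non_causal() is expected to return False, so the
replacement simply falls through to the reference path.

import importlib
from transformers import GPT2LMHeadModel
from ipex_llm.transformers.models.gpt2 import gpt2_attention_attn

model = GPT2LMHeadModel.from_pretrained("gpt2")

# Resolve the modeling module the same way _optimize_post does, then replace _attn.
modeling_module_name = model.__class__.__module__  # "transformers.models.gpt2.modeling_gpt2"
module = importlib.import_module(modeling_module_name)
module.GPT2Attention._attn = gpt2_attention_attn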
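
In normal use none of this is done by hand: loading GPT-2 through ipex-llm runs _optimize_post
inside from_pretrained (optimize_model defaults to True), which installs the gpt2 branch added
above. A minimal end-to-end sketch follows, with the model id, prompt, and generation settings
purely illustrative; the xe_addons.sdp_non_causal fast path is only taken on an Intel GPU, i.e.
after moving the model to "xpu" with intel_extension_for_pytorch available.

import torch
from transformers import AutoTokenizer
from ipex_llm.transformers import AutoModelForCausalLM

# load_in_4bit quantizes the weights; optimize_model=True (the default) applies _optimize_post.
model = AutoModelForCausalLM.from_pretrained("gpt2", load_in_4bit=True)
model = model.to("xpu")
tokenizer = AutoTokenizer.from_pretrained("gpt2")

inputs = tokenizer("Once upon a time", return_tensors="pt").to("xpu")
with torch.inference_mode():
    output = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(output[0], skip_special_tokens=True))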