From 578aef245d4a2ff2fe7262fb2f116afc622b2590 Mon Sep 17 00:00:00 2001
From: Yishuo Wang
Date: Wed, 23 Oct 2024 15:33:45 +0800
Subject: [PATCH] Fix models auto choose SdpaAttention with ipex 2.3 (#12252)

---
 python/llm/src/ipex_llm/transformers/model.py | 11 ++-----
 .../llm/src/ipex_llm/transformers/patches.py  | 30 +++++++++++++++++++
 2 files changed, 33 insertions(+), 8 deletions(-)
 create mode 100644 python/llm/src/ipex_llm/transformers/patches.py

diff --git a/python/llm/src/ipex_llm/transformers/model.py b/python/llm/src/ipex_llm/transformers/model.py
index 97513e5c..432a1d0e 100644
--- a/python/llm/src/ipex_llm/transformers/model.py
+++ b/python/llm/src/ipex_llm/transformers/model.py
@@ -52,6 +52,7 @@
 from ipex_llm.transformers.gguf.api import load_gguf_model
 from .utils import logger, load_state_dict
 from .utils import extract_local_archive_file, get_local_shard_files, load_imatrix_data
+from .patches import patch_flash_attn_import, patch_sdpa_available
 
 patched_training_mode = None
 
@@ -109,19 +110,12 @@ def _load_pre():
     GPTJModel.__init__ = gptj_model_new_init
 
 
-def patch_flash_attn_import(filename: str) -> List[str]:
-    """Work around for https://huggingface.co/microsoft/phi-1_5/discussions/72."""
-    imports = get_imports(filename)
-    if "flash_attn" in imports:
-        imports.remove("flash_attn")
-    return imports
-
-
 class _BaseAutoModelClass:
     HF_MODEL = None
 
     @classmethod
     @patch("transformers.dynamic_module_utils.get_imports", patch_flash_attn_import)
+    @patch("transformers.utils.is_torch_sdpa_available", patch_sdpa_available)
     def from_pretrained(cls,
                         *args,
                         **kwargs):
@@ -549,6 +543,7 @@ class _BaseAutoModelClass:
 
     @classmethod
     @patch("transformers.dynamic_module_utils.get_imports", patch_flash_attn_import)
+    @patch("transformers.utils.is_torch_sdpa_available", patch_sdpa_available)
     def load_low_bit(cls,
                      pretrained_model_name_or_path,
                      *model_args,
diff --git a/python/llm/src/ipex_llm/transformers/patches.py b/python/llm/src/ipex_llm/transformers/patches.py
new file mode 100644
index 00000000..e7339104
--- /dev/null
+++ b/python/llm/src/ipex_llm/transformers/patches.py
@@ -0,0 +1,30 @@
+#
+# Copyright 2016 The BigDL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from typing import List
+from transformers.dynamic_module_utils import get_imports
+
+
+def patch_flash_attn_import(filename: str) -> List[str]:
+    """Work around for https://huggingface.co/microsoft/phi-1_5/discussions/72."""
+    imports = get_imports(filename)
+    if "flash_attn" in imports:
+        imports.remove("flash_attn")
+    return imports
+
+
+def patch_sdpa_available() -> bool:
+    return False
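
Note on the mechanism (illustrative sketch below, not part of the patch): transformers
decides the attention implementation at load time by calling is_torch_sdpa_available();
with the torch build that ipex 2.3 pairs with, that check passes, so models silently
auto-select SdpaAttention. The two @patch decorators above make the check return False
for the duration of from_pretrained / load_low_bit, so loading falls back to the eager
attention path, presumably because ipex-llm's optimized kernels hook that path. The
companion patch_flash_attn_import workaround strips "flash_attn" from the imports that
transformers resolves for remote code, so models such as phi-1_5 load without the
flash_attn package. The sketch assumes transformers >= 4.36, where
transformers.utils.is_torch_sdpa_available exists; fake_sdpa_available,
load_with_eager_attention, and the commented model id are hypothetical names used only
for illustration, and whether a given call site sees the patched function depends on
how it imported the check.

    from unittest.mock import patch

    from transformers import AutoModelForCausalLM


    def fake_sdpa_available() -> bool:
        # Same idea as patch_sdpa_available in this patch:
        # report SDPA as unsupported so "sdpa" is never auto-selected.
        return False


    # patch() replaces the named attribute only while the decorated function
    # runs, then restores the original, so only this load is affected.
    @patch("transformers.utils.is_torch_sdpa_available", fake_sdpa_available)
    def load_with_eager_attention(model_id: str):
        return AutoModelForCausalLM.from_pretrained(model_id)


    # Hypothetical usage:
    # model = load_with_eager_attention("microsoft/phi-1_5")

Using unittest.mock.patch as a decorator keeps the override scoped: the real
is_torch_sdpa_available is restored as soon as the decorated call returns, so other
code in the process still sees the true capability.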