From cacc891962c7fce909d92867b0343eec32314865 Mon Sep 17 00:00:00 2001 From: Yishuo Wang Date: Wed, 23 Oct 2024 18:10:47 +0800 Subject: [PATCH] Fix PR validation (#12253) --- python/llm/src/ipex_llm/transformers/model.py | 5 ++--- python/llm/src/ipex_llm/transformers/patches.py | 1 + 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/python/llm/src/ipex_llm/transformers/model.py b/python/llm/src/ipex_llm/transformers/model.py index 432a1d0e..f81ee840 100644 --- a/python/llm/src/ipex_llm/transformers/model.py +++ b/python/llm/src/ipex_llm/transformers/model.py @@ -44,7 +44,6 @@ import transformers from typing import List from unittest.mock import patch from transformers.configuration_utils import PretrainedConfig -from transformers.dynamic_module_utils import get_imports from ipex_llm.ggml.quantize import ggml_tensor_qtype, gguf_mixed_qtype from ipex_llm.utils.common import invalidInputError @@ -115,7 +114,7 @@ class _BaseAutoModelClass: @classmethod @patch("transformers.dynamic_module_utils.get_imports", patch_flash_attn_import) - @patch("transformers.utils.is_torch_sdpa_available", patch_sdpa_available) + @patch("transformers.utils.is_torch_sdpa_available", patch_sdpa_available, create=True) def from_pretrained(cls, *args, **kwargs): @@ -543,7 +542,7 @@ class _BaseAutoModelClass: @classmethod @patch("transformers.dynamic_module_utils.get_imports", patch_flash_attn_import) - @patch("transformers.utils.is_torch_sdpa_available", patch_sdpa_available) + @patch("transformers.utils.is_torch_sdpa_available", patch_sdpa_available, create=True) def load_low_bit(cls, pretrained_model_name_or_path, *model_args, diff --git a/python/llm/src/ipex_llm/transformers/patches.py b/python/llm/src/ipex_llm/transformers/patches.py index e7339104..f115ffa5 100644 --- a/python/llm/src/ipex_llm/transformers/patches.py +++ b/python/llm/src/ipex_llm/transformers/patches.py @@ -16,6 +16,7 @@ # from typing import List +from transformers.dynamic_module_utils import get_imports def patch_flash_attn_import(filename: str) -> List[str]: