diff --git a/python/llm/src/ipex_llm/transformers/model.py b/python/llm/src/ipex_llm/transformers/model.py
index 00e7a2f3..72153f4c 100644
--- a/python/llm/src/ipex_llm/transformers/model.py
+++ b/python/llm/src/ipex_llm/transformers/model.py
@@ -37,18 +37,21 @@
 # SOFTWARE.
 #
 
+import copy
+import torch
+import warnings
 import transformers
+from typing import List
+from unittest.mock import patch
 from transformers.configuration_utils import PretrainedConfig
-from .utils import extract_local_archive_file, \
-    load_state_dict, \
-    get_local_shard_files, load_imatrix_data
+from transformers.dynamic_module_utils import get_imports
+
 from ipex_llm.ggml.quantize import ggml_tensor_qtype, gguf_mixed_qtype
 from ipex_llm.utils.common import invalidInputError
 from ipex_llm.transformers.gguf.api import load_gguf_model
-import torch
-import warnings
-import copy
-from .utils import logger
+
+from .utils import logger, load_state_dict
+from .utils import extract_local_archive_file, get_local_shard_files, load_imatrix_data
 
 
 patched_training_mode = None
@@ -98,10 +101,19 @@ def _load_pre():
     GPTJModel.__init__ = gptj_model_new_init
 
 
+def patch_flash_attn_import(filename: str) -> List[str]:
+    """Work around for https://huggingface.co/microsoft/phi-1_5/discussions/72."""
+    imports = get_imports(filename)
+    if "flash_attn" in imports:
+        imports.remove("flash_attn")
+    return imports
+
+
 class _BaseAutoModelClass:
     HF_MODEL = None
 
     @classmethod
+    @patch("transformers.dynamic_module_utils.get_imports", patch_flash_attn_import)
     def from_pretrained(cls,
                         *args,
                         **kwargs):
@@ -492,6 +504,7 @@ class _BaseAutoModelClass:
         return model
 
     @classmethod
+    @patch("transformers.dynamic_module_utils.get_imports", patch_flash_attn_import)
     def load_low_bit(cls,
                      pretrained_model_name_or_path,
                      *model_args,
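
The decorator added above relies on `unittest.mock.patch` replacing `transformers.dynamic_module_utils.get_imports` only for the duration of `from_pretrained` / `load_low_bit`, so remote model code that declares `flash_attn` (e.g. microsoft/phi-1_5, per the linked discussion) no longer fails the import check on machines without flash-attn. A minimal standalone sketch of the same technique, outside ipex_llm; the model id and `trust_remote_code` flag are illustrative, not part of this patch:

```python
from typing import List
from unittest.mock import patch

from transformers import AutoModelForCausalLM
from transformers.dynamic_module_utils import get_imports


def skip_flash_attn_import(filename: str) -> List[str]:
    # Same idea as patch_flash_attn_import in the diff: drop "flash_attn"
    # from the imports transformers verifies for trust_remote_code models.
    imports = get_imports(filename)
    if "flash_attn" in imports:
        imports.remove("flash_attn")
    return imports


# While the patch is active, transformers' import check never sees
# "flash_attn", so loading succeeds even if flash-attn is not installed.
with patch("transformers.dynamic_module_utils.get_imports", skip_flash_attn_import):
    model = AutoModelForCausalLM.from_pretrained(
        "microsoft/phi-1_5",
        trust_remote_code=True,
    )
```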