fix phi-3-vision import (#11129)

This commit is contained in:
Yishuo Wang 2024-05-24 15:57:15 +08:00 committed by GitHub
parent 7f772c5a4f
commit 1dc680341b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -37,18 +37,21 @@
# SOFTWARE. # SOFTWARE.
# #
import copy
import torch
import warnings
import transformers import transformers
from typing import List
from unittest.mock import patch
from transformers.configuration_utils import PretrainedConfig from transformers.configuration_utils import PretrainedConfig
from .utils import extract_local_archive_file, \ from transformers.dynamic_module_utils import get_imports
load_state_dict, \
get_local_shard_files, load_imatrix_data
from ipex_llm.ggml.quantize import ggml_tensor_qtype, gguf_mixed_qtype from ipex_llm.ggml.quantize import ggml_tensor_qtype, gguf_mixed_qtype
from ipex_llm.utils.common import invalidInputError from ipex_llm.utils.common import invalidInputError
from ipex_llm.transformers.gguf.api import load_gguf_model from ipex_llm.transformers.gguf.api import load_gguf_model
import torch
import warnings from .utils import logger, load_state_dict
import copy from .utils import extract_local_archive_file, get_local_shard_files, load_imatrix_data
from .utils import logger
patched_training_mode = None patched_training_mode = None
@ -98,10 +101,19 @@ def _load_pre():
GPTJModel.__init__ = gptj_model_new_init GPTJModel.__init__ = gptj_model_new_init
def patch_flash_attn_import(filename: str) -> List[str]:
    """Return the remote module's imports with ``flash_attn`` filtered out.

    Work around for https://huggingface.co/microsoft/phi-1_5/discussions/72:
    some remote modeling files (e.g. phi-3-vision) declare ``flash_attn`` as a
    required import, which makes ``transformers``'s dynamic-module loader fail
    on machines where flash-attn is not installed, even though the code path
    is never taken. Used to patch ``transformers.dynamic_module_utils.get_imports``.

    :param filename: path of the remote modeling file being inspected.
    :return: the import list reported by ``get_imports`` minus every
             occurrence of ``"flash_attn"``.
    """
    imports = get_imports(filename)
    # Filter ALL occurrences: list.remove() only drops the first one, which
    # would leave residual entries if the module imports flash_attn more than
    # once and the dependency check would still fail.
    return [imp for imp in imports if imp != "flash_attn"]
class _BaseAutoModelClass: class _BaseAutoModelClass:
HF_MODEL = None HF_MODEL = None
@classmethod @classmethod
@patch("transformers.dynamic_module_utils.get_imports", patch_flash_attn_import)
def from_pretrained(cls, def from_pretrained(cls,
*args, *args,
**kwargs): **kwargs):
@ -492,6 +504,7 @@ class _BaseAutoModelClass:
return model return model
@classmethod @classmethod
@patch("transformers.dynamic_module_utils.get_imports", patch_flash_attn_import)
def load_low_bit(cls, def load_low_bit(cls,
pretrained_model_name_or_path, pretrained_model_name_or_path,
*model_args, *model_args,