fix phi-3-vision import (#11129)
This commit is contained in:
parent
7f772c5a4f
commit
1dc680341b
1 changed files with 20 additions and 7 deletions
|
|
@ -37,18 +37,21 @@
|
||||||
# SOFTWARE.
|
# SOFTWARE.
|
||||||
#
|
#
|
||||||
|
|
||||||
|
import copy
|
||||||
|
import torch
|
||||||
|
import warnings
|
||||||
import transformers
|
import transformers
|
||||||
|
from typing import List
|
||||||
|
from unittest.mock import patch
|
||||||
from transformers.configuration_utils import PretrainedConfig
|
from transformers.configuration_utils import PretrainedConfig
|
||||||
from .utils import extract_local_archive_file, \
|
from transformers.dynamic_module_utils import get_imports
|
||||||
load_state_dict, \
|
|
||||||
get_local_shard_files, load_imatrix_data
|
|
||||||
from ipex_llm.ggml.quantize import ggml_tensor_qtype, gguf_mixed_qtype
|
from ipex_llm.ggml.quantize import ggml_tensor_qtype, gguf_mixed_qtype
|
||||||
from ipex_llm.utils.common import invalidInputError
|
from ipex_llm.utils.common import invalidInputError
|
||||||
from ipex_llm.transformers.gguf.api import load_gguf_model
|
from ipex_llm.transformers.gguf.api import load_gguf_model
|
||||||
import torch
|
|
||||||
import warnings
|
from .utils import logger, load_state_dict
|
||||||
import copy
|
from .utils import extract_local_archive_file, get_local_shard_files, load_imatrix_data
|
||||||
from .utils import logger
|
|
||||||
|
|
||||||
patched_training_mode = None
|
patched_training_mode = None
|
||||||
|
|
||||||
|
|
@ -98,10 +101,19 @@ def _load_pre():
|
||||||
GPTJModel.__init__ = gptj_model_new_init
|
GPTJModel.__init__ = gptj_model_new_init
|
||||||
|
|
||||||
|
|
||||||
|
def patch_flash_attn_import(filename: str) -> List[str]:
|
||||||
|
"""Work around for https://huggingface.co/microsoft/phi-1_5/discussions/72."""
|
||||||
|
imports = get_imports(filename)
|
||||||
|
if "flash_attn" in imports:
|
||||||
|
imports.remove("flash_attn")
|
||||||
|
return imports
|
||||||
|
|
||||||
|
|
||||||
class _BaseAutoModelClass:
|
class _BaseAutoModelClass:
|
||||||
HF_MODEL = None
|
HF_MODEL = None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@patch("transformers.dynamic_module_utils.get_imports", patch_flash_attn_import)
|
||||||
def from_pretrained(cls,
|
def from_pretrained(cls,
|
||||||
*args,
|
*args,
|
||||||
**kwargs):
|
**kwargs):
|
||||||
|
|
@ -492,6 +504,7 @@ class _BaseAutoModelClass:
|
||||||
return model
|
return model
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@patch("transformers.dynamic_module_utils.get_imports", patch_flash_attn_import)
|
||||||
def load_low_bit(cls,
|
def load_low_bit(cls,
|
||||||
pretrained_model_name_or_path,
|
pretrained_model_name_or_path,
|
||||||
*model_args,
|
*model_args,
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue