fix phi-3-vision import (#11129)
This commit is contained in:
		
							parent
							
								
									7f772c5a4f
								
							
						
					
					
						commit
						1dc680341b
					
				
					 1 changed files with 20 additions and 7 deletions
				
			
		| 
						 | 
					@ -37,18 +37,21 @@
 | 
				
			||||||
# SOFTWARE.
 | 
					# SOFTWARE.
 | 
				
			||||||
#
 | 
					#
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import copy
 | 
				
			||||||
 | 
					import torch
 | 
				
			||||||
 | 
					import warnings
 | 
				
			||||||
import transformers
 | 
					import transformers
 | 
				
			||||||
 | 
					from typing import List
 | 
				
			||||||
 | 
					from unittest.mock import patch
 | 
				
			||||||
from transformers.configuration_utils import PretrainedConfig
 | 
					from transformers.configuration_utils import PretrainedConfig
 | 
				
			||||||
from .utils import extract_local_archive_file, \
 | 
					from transformers.dynamic_module_utils import get_imports
 | 
				
			||||||
    load_state_dict, \
 | 
					
 | 
				
			||||||
    get_local_shard_files, load_imatrix_data
 | 
					 | 
				
			||||||
from ipex_llm.ggml.quantize import ggml_tensor_qtype, gguf_mixed_qtype
 | 
					from ipex_llm.ggml.quantize import ggml_tensor_qtype, gguf_mixed_qtype
 | 
				
			||||||
from ipex_llm.utils.common import invalidInputError
 | 
					from ipex_llm.utils.common import invalidInputError
 | 
				
			||||||
from ipex_llm.transformers.gguf.api import load_gguf_model
 | 
					from ipex_llm.transformers.gguf.api import load_gguf_model
 | 
				
			||||||
import torch
 | 
					
 | 
				
			||||||
import warnings
 | 
					from .utils import logger, load_state_dict
 | 
				
			||||||
import copy
 | 
					from .utils import extract_local_archive_file, get_local_shard_files, load_imatrix_data
 | 
				
			||||||
from .utils import logger
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
patched_training_mode = None
 | 
					patched_training_mode = None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -98,10 +101,19 @@ def _load_pre():
 | 
				
			||||||
    GPTJModel.__init__ = gptj_model_new_init
 | 
					    GPTJModel.__init__ = gptj_model_new_init
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def patch_flash_attn_import(filename: str) -> List[str]:
 | 
				
			||||||
 | 
					    """Work around for https://huggingface.co/microsoft/phi-1_5/discussions/72."""
 | 
				
			||||||
 | 
					    imports = get_imports(filename)
 | 
				
			||||||
 | 
					    if "flash_attn" in imports:
 | 
				
			||||||
 | 
					        imports.remove("flash_attn")
 | 
				
			||||||
 | 
					    return imports
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class _BaseAutoModelClass:
 | 
					class _BaseAutoModelClass:
 | 
				
			||||||
    HF_MODEL = None
 | 
					    HF_MODEL = None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @classmethod
 | 
					    @classmethod
 | 
				
			||||||
 | 
					    @patch("transformers.dynamic_module_utils.get_imports", patch_flash_attn_import)
 | 
				
			||||||
    def from_pretrained(cls,
 | 
					    def from_pretrained(cls,
 | 
				
			||||||
                        *args,
 | 
					                        *args,
 | 
				
			||||||
                        **kwargs):
 | 
					                        **kwargs):
 | 
				
			||||||
| 
						 | 
					@ -492,6 +504,7 @@ class _BaseAutoModelClass:
 | 
				
			||||||
        return model
 | 
					        return model
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @classmethod
 | 
					    @classmethod
 | 
				
			||||||
 | 
					    @patch("transformers.dynamic_module_utils.get_imports", patch_flash_attn_import)
 | 
				
			||||||
    def load_low_bit(cls,
 | 
					    def load_low_bit(cls,
 | 
				
			||||||
                     pretrained_model_name_or_path,
 | 
					                     pretrained_model_name_or_path,
 | 
				
			||||||
                     *model_args,
 | 
					                     *model_args,
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in a new issue