Fix PR validation (#12253)

This commit is contained in:
Yishuo Wang 2024-10-23 18:10:47 +08:00 committed by GitHub
parent b685cf4349
commit cacc891962
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 3 additions and 3 deletions

View file

@ -44,7 +44,6 @@ import transformers
from typing import List from typing import List
from unittest.mock import patch from unittest.mock import patch
from transformers.configuration_utils import PretrainedConfig from transformers.configuration_utils import PretrainedConfig
from transformers.dynamic_module_utils import get_imports
from ipex_llm.ggml.quantize import ggml_tensor_qtype, gguf_mixed_qtype from ipex_llm.ggml.quantize import ggml_tensor_qtype, gguf_mixed_qtype
from ipex_llm.utils.common import invalidInputError from ipex_llm.utils.common import invalidInputError
@ -115,7 +114,7 @@ class _BaseAutoModelClass:
@classmethod @classmethod
@patch("transformers.dynamic_module_utils.get_imports", patch_flash_attn_import) @patch("transformers.dynamic_module_utils.get_imports", patch_flash_attn_import)
@patch("transformers.utils.is_torch_sdpa_available", patch_sdpa_available) @patch("transformers.utils.is_torch_sdpa_available", patch_sdpa_available, create=True)
def from_pretrained(cls, def from_pretrained(cls,
*args, *args,
**kwargs): **kwargs):
@ -543,7 +542,7 @@ class _BaseAutoModelClass:
@classmethod @classmethod
@patch("transformers.dynamic_module_utils.get_imports", patch_flash_attn_import) @patch("transformers.dynamic_module_utils.get_imports", patch_flash_attn_import)
@patch("transformers.utils.is_torch_sdpa_available", patch_sdpa_available) @patch("transformers.utils.is_torch_sdpa_available", patch_sdpa_available, create=True)
def load_low_bit(cls, def load_low_bit(cls,
pretrained_model_name_or_path, pretrained_model_name_or_path,
*model_args, *model_args,

View file

@ -16,6 +16,7 @@
# #
from typing import List from typing import List
from transformers.dynamic_module_utils import get_imports
def patch_flash_attn_import(filename: str) -> List[str]: def patch_flash_attn_import(filename: str) -> List[str]: