Fix PR validation (#12253)
This commit is contained in:
parent
b685cf4349
commit
cacc891962
2 changed files with 3 additions and 3 deletions
|
|
@ -44,7 +44,6 @@ import transformers
|
||||||
from typing import List
|
from typing import List
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
from transformers.configuration_utils import PretrainedConfig
|
from transformers.configuration_utils import PretrainedConfig
|
||||||
from transformers.dynamic_module_utils import get_imports
|
|
||||||
|
|
||||||
from ipex_llm.ggml.quantize import ggml_tensor_qtype, gguf_mixed_qtype
|
from ipex_llm.ggml.quantize import ggml_tensor_qtype, gguf_mixed_qtype
|
||||||
from ipex_llm.utils.common import invalidInputError
|
from ipex_llm.utils.common import invalidInputError
|
||||||
|
|
@ -115,7 +114,7 @@ class _BaseAutoModelClass:
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@patch("transformers.dynamic_module_utils.get_imports", patch_flash_attn_import)
|
@patch("transformers.dynamic_module_utils.get_imports", patch_flash_attn_import)
|
||||||
@patch("transformers.utils.is_torch_sdpa_available", patch_sdpa_available)
|
@patch("transformers.utils.is_torch_sdpa_available", patch_sdpa_available, create=True)
|
||||||
def from_pretrained(cls,
|
def from_pretrained(cls,
|
||||||
*args,
|
*args,
|
||||||
**kwargs):
|
**kwargs):
|
||||||
|
|
@ -543,7 +542,7 @@ class _BaseAutoModelClass:
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@patch("transformers.dynamic_module_utils.get_imports", patch_flash_attn_import)
|
@patch("transformers.dynamic_module_utils.get_imports", patch_flash_attn_import)
|
||||||
@patch("transformers.utils.is_torch_sdpa_available", patch_sdpa_available)
|
@patch("transformers.utils.is_torch_sdpa_available", patch_sdpa_available, create=True)
|
||||||
def load_low_bit(cls,
|
def load_low_bit(cls,
|
||||||
pretrained_model_name_or_path,
|
pretrained_model_name_or_path,
|
||||||
*model_args,
|
*model_args,
|
||||||
|
|
|
||||||
|
|
@ -16,6 +16,7 @@
|
||||||
#
|
#
|
||||||
|
|
||||||
from typing import List
|
from typing import List
|
||||||
|
from transformers.dynamic_module_utils import get_imports
|
||||||
|
|
||||||
|
|
||||||
def patch_flash_attn_import(filename: str) -> List[str]:
|
def patch_flash_attn_import(filename: str) -> List[str]:
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue