Fix PR validation (#12253)
This commit is contained in:
		
							parent
							
								
									b685cf4349
								
							
						
					
					
						commit
						cacc891962
					
				
					 2 changed files with 3 additions and 3 deletions
				
			
		| 
						 | 
				
			
			@ -44,7 +44,6 @@ import transformers
 | 
			
		|||
from typing import List
 | 
			
		||||
from unittest.mock import patch
 | 
			
		||||
from transformers.configuration_utils import PretrainedConfig
 | 
			
		||||
from transformers.dynamic_module_utils import get_imports
 | 
			
		||||
 | 
			
		||||
from ipex_llm.ggml.quantize import ggml_tensor_qtype, gguf_mixed_qtype
 | 
			
		||||
from ipex_llm.utils.common import invalidInputError
 | 
			
		||||
| 
						 | 
				
			
			@ -115,7 +114,7 @@ class _BaseAutoModelClass:
 | 
			
		|||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    @patch("transformers.dynamic_module_utils.get_imports", patch_flash_attn_import)
 | 
			
		||||
    @patch("transformers.utils.is_torch_sdpa_available", patch_sdpa_available)
 | 
			
		||||
    @patch("transformers.utils.is_torch_sdpa_available", patch_sdpa_available, create=True)
 | 
			
		||||
    def from_pretrained(cls,
 | 
			
		||||
                        *args,
 | 
			
		||||
                        **kwargs):
 | 
			
		||||
| 
						 | 
				
			
			@ -543,7 +542,7 @@ class _BaseAutoModelClass:
 | 
			
		|||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    @patch("transformers.dynamic_module_utils.get_imports", patch_flash_attn_import)
 | 
			
		||||
    @patch("transformers.utils.is_torch_sdpa_available", patch_sdpa_available)
 | 
			
		||||
    @patch("transformers.utils.is_torch_sdpa_available", patch_sdpa_available, create=True)
 | 
			
		||||
    def load_low_bit(cls,
 | 
			
		||||
                     pretrained_model_name_or_path,
 | 
			
		||||
                     *model_args,
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -16,6 +16,7 @@
 | 
			
		|||
#
 | 
			
		||||
 | 
			
		||||
from typing import List
 | 
			
		||||
from transformers.dynamic_module_utils import get_imports
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def patch_flash_attn_import(filename: str) -> List[str]:
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in a new issue