Fix qwen2 1.5B NPU load error (#12049)
parent abc370728c
commit dc4af02b2a

1 changed file with 4 additions and 3 deletions
@@ -202,9 +202,6 @@ class _BaseAutoModelClass:
    @classmethod
    @patch("transformers.dynamic_module_utils.get_imports", patch_flash_attn_import)
    def load_low_bit(cls, pretrained_model_name_or_path: str, *model_args, **kwargs):
        if kwargs.pop("torch_dtype", None) not in [None, "auto", torch.float]:
            warnings.warn("`torch_dtype` will be ignored, `torch.float` will be used")

        # ignore following arguments
        ignore_argument(kwargs, "model_hub")
        ignore_argument(kwargs, "lightweight_bmm")
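For context, a minimal usage sketch of the load_low_bit entry point touched above. The import path ipex_llm.transformers.npu_model and the local save directory are assumptions for illustration; as the hunk notes, any torch_dtype kwarg is ignored (torch.float is used) and the model_hub / lightweight_bmm kwargs are dropped.

# Sketch only: module path and directory name are assumptions, not part of this commit.
from ipex_llm.transformers.npu_model import AutoModelForCausalLM

# load_low_bit restores a model previously saved in low-bit form; per the hunk
# above, a torch_dtype kwarg is ignored and torch.float is used instead.
model = AutoModelForCausalLM.load_low_bit("./qwen2-1.5b-npu-low-bit")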
@@ -402,6 +399,10 @@ class _BaseAutoModelClass:
        if dtype_orig is not None:
            torch.set_default_dtype(dtype_orig)

        # set tie_word_embeddings to False to avoid possible lm_head error
        if hasattr(model.config, "tie_word_embeddings"):
            model.config.tie_word_embeddings = False

        (
            model,
            missing_keys,
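The functional change is the tie_word_embeddings override. Below is a small sketch of the weight tying it works around, using a toy Qwen2 config; the tiny sizes are made up for the example, and the actual Qwen2 1.5B checkpoint ships with tie_word_embeddings=True, which is the case the commit guards against.

from transformers import Qwen2Config, Qwen2ForCausalLM

# Toy dimensions purely for illustration.
cfg = Qwen2Config(
    vocab_size=128,
    hidden_size=32,
    intermediate_size=64,
    num_hidden_layers=2,
    num_attention_heads=4,
    num_key_value_heads=4,
    tie_word_embeddings=True,
)
model = Qwen2ForCausalLM(cfg)

# With tying enabled, lm_head.weight aliases the input embedding weight.
assert model.lm_head.weight.data_ptr() == model.get_input_embeddings().weight.data_ptr()

# The commit sets this flag to False on model.config during loading to avoid
# a possible lm_head error on the NPU path.
model.config.tie_word_embeddings = False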