Fix dtype mismatch error (#10609)
* fix llama
* fix
* fix code style
* add torch type in model.py

---------

Co-authored-by: arda <arda@arda-arc19.sh.intel.com>
parent f37a1f2a81
commit b4147a97bb
1 changed file with 9 additions and 0 deletions
model.py
@@ -295,6 +295,15 @@ class _BaseAutoModelClass:
                     )
                 else:
                     kwargs["torch_dtype"] = torch.float16
+            elif load_in_low_bit == "bf16":
+                if torch_dtype is not None and torch_dtype != torch.bfloat16:
+                    invalidInputError(
+                        False,
+                        f"Please use torch_dtype=torch.bfloat16"
+                        f" when setting load_in_low_bit='bf16'."
+                    )
+                else:
+                    kwargs["torch_dtype"] = torch.bfloat16
             else:
                 kwargs["torch_dtype"] = torch_dtype or "auto"
             # Avoid tensor parallel F.Linear Operations
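The added branch makes bf16 loading symmetric with the existing fp16 branch: an explicit torch_dtype that disagrees with load_in_low_bit='bf16' now fails fast through invalidInputError at load time, instead of surfacing later as a dtype mismatch inside the model's linear operations, while an unset torch_dtype is defaulted to torch.bfloat16. A minimal usage sketch under stated assumptions: the ipex_llm.transformers.AutoModelForCausalLM entry point (older releases shipped this class as bigdl.llm.transformers) and a placeholder Hugging Face model id.

import torch

# Assumed import path for illustration; on releases before the ipex-llm
# rename the same class lived under bigdl.llm.transformers.
from ipex_llm.transformers import AutoModelForCausalLM

MODEL_ID = "meta-llama/Llama-2-7b-hf"  # placeholder model id

# Consistent call: load_in_low_bit="bf16" with a matching (or omitted)
# torch_dtype; per the patch, kwargs["torch_dtype"] is set to torch.bfloat16.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    load_in_low_bit="bf16",
    torch_dtype=torch.bfloat16,
)

# Inconsistent call: after this patch it is rejected up front via
# invalidInputError rather than failing later inside the model.
try:
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        load_in_low_bit="bf16",
        torch_dtype=torch.float16,
    )
except Exception as e:  # exact exception type raised by invalidInputError not verified here
    print(e)  # "Please use torch_dtype=torch.bfloat16 when setting load_in_low_bit='bf16'."

Defaulting torch_dtype instead of hard-requiring it keeps existing call sites that never pass the kwarg working unchanged.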