LLM: fix torch_dtype setting when applying fp16 optimization through optimize_model (#10556)
This commit is contained in:

parent ea4bc450c4
commit fc8c7904f0

1 changed file with 10 additions and 0 deletions
@@ -237,9 +237,19 @@ def optimize_model(model, low_bit='sym_int4', optimize_llm=True, modules_to_not_
         warnings.warn("replace_embedding is deprecated and will be removed in a future version,"
                       " please use cpu_embedding instead.", FutureWarning)
         cpu_embedding = True
+    if low_bit == "fp16":
+        torch_dtype = kwargs.get("torch_dtype", None)
+        if torch_dtype is not None and torch_dtype != torch.float16:
+            invalidInputError(False,
+                              "Please use torch_dtype=torch.float16 when setting low_bit='fp16'.")
+        else:
+            torch_dtype = torch.float16
+    else:
+        torch_dtype = kwargs.get("torch_dtype", "auto")
     qtype = ggml_tensor_qtype[low_bit]
     model = ggml_convert_low_bit(model,
                                  qtype=qtype,
+                                 torch_dtype=torch_dtype,
                                  optimize_model=optimize_llm,
                                  modules_to_not_convert=modules_to_not_convert,
                                  cpu_embedding=cpu_embedding,
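For reference, a minimal usage sketch of the behaviour this hunk enforces. The import path and model id below are assumptions for illustration (older bigdl-llm releases export optimize_model from bigdl.llm, newer ipex-llm releases from ipex_llm); the dtype check itself comes straight from the diff above.

import torch
from transformers import AutoModelForCausalLM

# Import path is an assumption; use the one matching the installed package,
# e.g. "from bigdl.llm import optimize_model" on older releases.
from ipex_llm import optimize_model

# Illustrative model id; any Hugging Face causal LM is handled the same way.
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf",
                                             torch_dtype=torch.float16)

# OK: torch_dtype is omitted, so the fixed code defaults it to torch.float16.
model = optimize_model(model, low_bit="fp16")

# Also OK: an explicit torch.float16 passes the new check.
# model = optimize_model(model, low_bit="fp16", torch_dtype=torch.float16)

# Rejected after this fix: a conflicting torch_dtype now raises
# "Please use torch_dtype=torch.float16 when setting low_bit='fp16'."
# model = optimize_model(model, low_bit="fp16", torch_dtype=torch.bfloat16)

For any other low_bit value, torch_dtype falls back to the kwargs value or "auto" and is forwarded to ggml_convert_low_bit, as the last added line of the hunk shows.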