add optimize model option (#9530)
This commit is contained in:
		
							parent
							
								
									6bec0faea5
								
							
						
					
					
						commit
						45820cf3b9
					
				
					 2 changed files with 5 additions and 4 deletions
				
			
		| 
						 | 
				
			
			@ -46,9 +46,10 @@ class BigDLLM(BaseLM):
 | 
			
		|||
        tokenizer=None,
 | 
			
		||||
        batch_size=1,
 | 
			
		||||
        load_in_8bit: Optional[bool] = False,
 | 
			
		||||
        trust_remote_code: Optional[bool] = False,
 | 
			
		||||
        trust_remote_code: Optional[bool] = True,
 | 
			
		||||
        load_in_low_bit=None,
 | 
			
		||||
        dtype: Optional[Union[str, torch.dtype]] = "auto",
 | 
			
		||||
        **kwargs
 | 
			
		||||
    ):
 | 
			
		||||
        super().__init__()
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -58,8 +59,8 @@ class BigDLLM(BaseLM):
 | 
			
		|||
            import intel_extension_for_pytorch as ipex
 | 
			
		||||
        model = AutoModelForCausalLM.from_pretrained(pretrained,
 | 
			
		||||
                                          load_in_low_bit=load_in_low_bit,
 | 
			
		||||
                                          optimize_model=True,
 | 
			
		||||
                                          trust_remote_code=True,
 | 
			
		||||
                                          optimize_model=kwargs.get('optimize_model', True),
 | 
			
		||||
                                          trust_remote_code=trust_remote_code,
 | 
			
		||||
                                          use_cache=True,
 | 
			
		||||
                                          torch_dtype=_get_dtype(dtype))
 | 
			
		||||
        print(model) # print model to check precision
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -90,7 +90,7 @@ def main():
 | 
			
		|||
        prec_arg = parse_precision(prec, args.model)
 | 
			
		||||
        model_args = f"pretrained={args.pretrained},{prec_arg}"
 | 
			
		||||
        if len(args.model_args) > 0:
 | 
			
		||||
            model_args += args.model_args
 | 
			
		||||
            model_args = f"{model_args},{args.model_args}"
 | 
			
		||||
        for task in args.tasks:
 | 
			
		||||
            task_names=task_map.get(task, task).split(',')
 | 
			
		||||
            num_fewshot = task_to_n_few_shots.get(task, args.num_fewshot)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in a new issue