Quick fix benchmark script (#11938)
This commit is contained in:
		
							parent
							
								
									b4b6ddf73c
								
							
						
					
					
						commit
						7f7f6c89f5
					
				
					 1 changed files with 2 additions and 2 deletions
				
			
		| 
						 | 
					@ -615,9 +615,9 @@ def transformers_int4_npu_win(repo_id,
 | 
				
			||||||
    # which convert the relevant layers in the model into INT4 format
 | 
					    # which convert the relevant layers in the model into INT4 format
 | 
				
			||||||
    st = time.perf_counter()
 | 
					    st = time.perf_counter()
 | 
				
			||||||
    if repo_id in CHATGLM_IDS:
 | 
					    if repo_id in CHATGLM_IDS:
 | 
				
			||||||
        model = AutoModel.from_pretrained(model_path, load_in_low_bit=low_bit, trust_remote_code=True, torch_dtype=torch.float16,
 | 
					        model = AutoModel.from_pretrained(model_path, load_in_low_bit=low_bit, trust_remote_code=True,
 | 
				
			||||||
                                          optimize_model=optimize_model, max_output_len=max_output_len, max_prompt_len=int(in_out_len[0]), transpose_value_cache=True,
 | 
					                                          optimize_model=optimize_model, max_output_len=max_output_len, max_prompt_len=int(in_out_len[0]), transpose_value_cache=True,
 | 
				
			||||||
                                          torch_dtype='auto', attn_implementation="eager").eval()
 | 
					                                          torch_dtype=torch.float16, attn_implementation="eager").eval()
 | 
				
			||||||
        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
 | 
					        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
 | 
				
			||||||
    elif repo_id in LLAMA_IDS:
 | 
					    elif repo_id in LLAMA_IDS:
 | 
				
			||||||
        model = AutoModelForCausalLM.from_pretrained(model_path, load_in_low_bit=low_bit, trust_remote_code=True, torch_dtype=torch.float16,
 | 
					        model = AutoModelForCausalLM.from_pretrained(model_path, load_in_low_bit=low_bit, trust_remote_code=True, torch_dtype=torch.float16,
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in a new issue