LLM: quick fix benchmark (#9509)

commit 139e98aa18
parent c2aeb4d1e8
Author: Ruonan Wang (committed by GitHub)
Date:   2023-11-22 10:19:57 +08:00


@@ -420,7 +420,7 @@ def run_optimize_model_gpu(repo_id,
         tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
         model = model.to('xpu')
     elif repo_id in LLAMA_IDS:
-        model = AutoModelForCausalLM.from_pretrained(model_path, load_in_4bit=True, trust_remote_code=True,
+        model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True,
                                                      use_cache=True, low_cpu_mem_usage=True)
         model = optimize_model(model, low_bit=low_bit)
         tokenizer = LlamaTokenizer.from_pretrained(model_path, trust_remote_code=True)
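The fix drops load_in_4bit=True from the LLaMA load path, so the checkpoint is loaded unquantized and optimize_model(model, low_bit=low_bit) alone determines the low-bit format being benchmarked, rather than quantizing the weights twice. Below is a minimal sketch of the corrected flow, assuming BigDL-LLM's optimize_model entry point; model_path and the low_bit value are placeholders standing in for the arguments run_optimize_model_gpu receives.

# A minimal sketch of the corrected LLaMA load path (not the full benchmark
# script); model_path and low_bit are hypothetical placeholders for the
# arguments run_optimize_model_gpu receives.
from transformers import AutoModelForCausalLM, LlamaTokenizer
from bigdl.llm import optimize_model

model_path = "/path/to/llama-2-7b"  # hypothetical local checkpoint
low_bit = "sym_int4"                # low-bit format under benchmark

# Load in full precision: load_in_4bit=True is dropped so quantization
# happens exactly once, inside optimize_model below.
model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True,
                                             use_cache=True, low_cpu_mem_usage=True)
model = optimize_model(model, low_bit=low_bit)  # quantize to the requested format
model = model.to('xpu')                         # run on an Intel GPU
tokenizer = LlamaTokenizer.from_pretrained(model_path, trust_remote_code=True)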