LLM: Fix ChatGLM3 Speculative Example (#10236)

Fix the ChatGLM3 speculative decoding example: pass the tokenizer's attention_mask through to model.generate and adjust the model loading options.
Xiangyu Tian 2024-02-26 10:57:28 +08:00 committed by GitHub
parent 0c6aef0f47
commit 85a99e13e8


@@ -57,20 +57,23 @@ if __name__ == '__main__':
                                       load_in_low_bit="bf16",
                                       speculative=True,
                                       trust_remote_code=True,
+                                      torchscript=True,
                                       use_cache=True)
-    model = model.to('cpu')
 
     tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
 
     with torch.inference_mode():
         prompt = CHATGLM_V3_PROMPT_FORMAT.format(prompt=args.prompt)
-        input_ids = tokenizer(prompt, return_tensors='pt').input_ids.to(model.device)
+        inputs = tokenizer(prompt, return_tensors='pt')
+        input_ids = inputs.input_ids.to(model.device)
+        attention_mask = inputs.attention_mask.to(model.device)
         actual_in_len = input_ids.shape[1]
         print("actual input_ids length:" + str(actual_in_len))
 
         # warmup
         output = model.generate(input_ids,
                                 max_new_tokens=args.n_predict,
                                 do_sample=False,
+                                attention_mask=attention_mask,
                                 th_stop_draft=0.6)
         output_str = tokenizer.decode(output[0])
@@ -79,6 +82,7 @@ if __name__ == '__main__':
         output = model.generate(input_ids,
                                 max_new_tokens=args.n_predict,
                                 do_sample=False,
+                                attention_mask=attention_mask,
                                 th_stop_draft=0.6)
         output_str = tokenizer.decode(output[0], skip_special_tokens=True)
         end = time.perf_counter()
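
For reference, here is a minimal standalone sketch of the patched portion of the example with both hunks above applied. The import path, the CHATGLM_V3_PROMPT_FORMAT template, and the placeholder model path, prompt, and token count are assumptions added for illustration; other loading keyword arguments from the original file (not shown in the diff) are omitted, and the timing printout is condensed.

import time
import torch
from transformers import AutoTokenizer
# Assumption: at the time of this commit the package was bigdl-llm; newer
# releases expose the same transformers-style API under ipex_llm.transformers.
from bigdl.llm.transformers import AutoModel

# Assumed ChatGLM3 chat template, as defined near the top of the example file.
CHATGLM_V3_PROMPT_FORMAT = "<|user|>\n{prompt}\n<|assistant|>"

model_path = "THUDM/chatglm3-6b"   # placeholder model path
user_prompt = "What is AI?"        # placeholder prompt
n_predict = 128                    # placeholder number of new tokens

# Load ChatGLM3 in bf16 with self-speculative decoding enabled
# (other loading kwargs from the original example are omitted here).
model = AutoModel.from_pretrained(model_path,
                                  load_in_low_bit="bf16",
                                  speculative=True,
                                  trust_remote_code=True,
                                  torchscript=True,   # as added in the hunk above
                                  use_cache=True)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

with torch.inference_mode():
    prompt = CHATGLM_V3_PROMPT_FORMAT.format(prompt=user_prompt)
    # Core of the fix: keep the full tokenizer output so the attention_mask
    # can be forwarded to generate() together with the input_ids.
    inputs = tokenizer(prompt, return_tensors='pt')
    input_ids = inputs.input_ids.to(model.device)
    attention_mask = inputs.attention_mask.to(model.device)
    print("actual input_ids length:" + str(input_ids.shape[1]))

    # Warmup pass.
    output = model.generate(input_ids,
                            max_new_tokens=n_predict,
                            do_sample=False,
                            attention_mask=attention_mask,
                            th_stop_draft=0.6)
    tokenizer.decode(output[0])

    # Timed pass.
    start = time.perf_counter()
    output = model.generate(input_ids,
                            max_new_tokens=n_predict,
                            do_sample=False,
                            attention_mask=attention_mask,
                            th_stop_draft=0.6)
    end = time.perf_counter()
    print(tokenizer.decode(output[0], skip_special_tokens=True))
    print("Inference time: %.2f s" % (end - start))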