LLM: Fix ChatGLM3 Speculative Example (#10236)
Fix ChatGLM3 Speculative Example.
parent 0c6aef0f47
commit 85a99e13e8

1 changed file with 6 additions and 2 deletions
@@ -57,20 +57,23 @@ if __name__ == '__main__':
                                                  load_in_low_bit="bf16",
                                                  speculative=True,
                                                  trust_remote_code=True,
                                                  torchscript=True,
                                                  use_cache=True)
     model = model.to('cpu')
 
     tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
 
     with torch.inference_mode():
         prompt = CHATGLM_V3_PROMPT_FORMAT.format(prompt=args.prompt)
-        input_ids = tokenizer(prompt, return_tensors='pt').input_ids.to(model.device)
+        inputs = tokenizer(prompt, return_tensors='pt')
+        input_ids = inputs.input_ids.to(model.device)
+        attention_mask = inputs.attention_mask.to(model.device)
         actual_in_len = input_ids.shape[1]
         print("actual input_ids length:" + str(actual_in_len))
         # warmup
         output = model.generate(input_ids,
                                 max_new_tokens=args.n_predict,
                                 do_sample=False,
+                                attention_mask=attention_mask,
                                 th_stop_draft=0.6)
         output_str = tokenizer.decode(output[0])
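The change in this hunk keeps the full tokenizer output so the attention mask can be forwarded to generate() alongside the input IDs. Below is a minimal sketch of that pattern using plain transformers; "gpt2" is only a stand-in checkpoint for the ChatGLM3 + speculative-decoding setup of the example, and the ipex-llm-specific arguments (e.g. th_stop_draft) are omitted here:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Placeholder model; the real example loads ChatGLM3 with ipex-llm bf16 + speculative=True.
model_path = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(model_path)

prompt = "What is AI?"
with torch.inference_mode():
    # Keep the whole BatchEncoding so both input_ids and attention_mask are available.
    inputs = tokenizer(prompt, return_tensors="pt")
    input_ids = inputs.input_ids.to(model.device)
    attention_mask = inputs.attention_mask.to(model.device)
    output = model.generate(input_ids,
                            attention_mask=attention_mask,
                            max_new_tokens=32,
                            do_sample=False,
                            pad_token_id=tokenizer.eos_token_id)  # silences gpt2's pad-token warning
    print(tokenizer.decode(output[0], skip_special_tokens=True))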
@@ -79,6 +82,7 @@ if __name__ == '__main__':
         output = model.generate(input_ids,
                                 max_new_tokens=args.n_predict,
                                 do_sample=False,
+                                attention_mask=attention_mask,
                                 th_stop_draft=0.6)
         output_str = tokenizer.decode(output[0], skip_special_tokens=True)
         end = time.perf_counter()
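The second hunk passes the same attention_mask to the timed (non-warmup) generation. A rough sketch of that timing-and-decode step, reusing model, tokenizer, input_ids, and attention_mask from the sketch above:

import time

# Time the generation after warmup, then decode without special tokens.
start = time.perf_counter()
output = model.generate(input_ids,
                        attention_mask=attention_mask,
                        max_new_tokens=32,
                        do_sample=False,
                        pad_token_id=tokenizer.eos_token_id)
end = time.perf_counter()
output_str = tokenizer.decode(output[0], skip_special_tokens=True)
print(output_str)
print(f"generation time: {end - start:.3f}s")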