LLM: Fix ChatGLM3 Speculative Example (#10236)
commit 85a99e13e8 (parent 0c6aef0f47)
1 changed file with 6 additions and 2 deletions
@@ -57,20 +57,23 @@ if __name__ == '__main__':
                                       load_in_low_bit="bf16",
                                       speculative=True,
                                       trust_remote_code=True,
+                                      torchscript=True,
                                       use_cache=True)
-    model = model.to('cpu')

     tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

     with torch.inference_mode():
         prompt = CHATGLM_V3_PROMPT_FORMAT.format(prompt=args.prompt)
-        input_ids = tokenizer(prompt, return_tensors='pt').input_ids.to(model.device)
+        inputs = tokenizer(prompt, return_tensors='pt')
+        input_ids = inputs.input_ids.to(model.device)
+        attention_mask = inputs.attention_mask.to(model.device)
         actual_in_len = input_ids.shape[1]
         print("actual input_ids length:" + str(actual_in_len))
         # warmup
         output = model.generate(input_ids,
                                 max_new_tokens=args.n_predict,
                                 do_sample=False,
+                                attention_mask=attention_mask,
                                 th_stop_draft=0.6)
         output_str = tokenizer.decode(output[0])

@@ -79,6 +82,7 @@ if __name__ == '__main__':
         output = model.generate(input_ids,
                                 max_new_tokens=args.n_predict,
                                 do_sample=False,
+                                attention_mask=attention_mask,
                                 th_stop_draft=0.6)
         output_str = tokenizer.decode(output[0], skip_special_tokens=True)
         end = time.perf_counter()
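For reference, below is a minimal sketch of the example as it reads after this fix. The argument names are taken verbatim from the diff (load_in_low_bit="bf16", speculative=True, torchscript=True, th_stop_draft=0.6, and the attention_mask now forwarded to generate()); the import path, checkpoint name, and prompt template are assumptions, since the diff does not show the top of the file.

# Sketch only: the imports, model path, and prompt template below are
# assumptions, not shown in this diff.
import torch
from transformers import AutoTokenizer
from bigdl.llm.transformers import AutoModel  # assumed import for the speculative API

CHATGLM_V3_PROMPT_FORMAT = "<|user|>\n{prompt}\n<|assistant|>"  # assumed template
model_path = "THUDM/chatglm3-6b"                                # assumed checkpoint

model = AutoModel.from_pretrained(model_path,
                                  load_in_low_bit="bf16",
                                  speculative=True,
                                  trust_remote_code=True,
                                  torchscript=True,   # added by this commit
                                  use_cache=True)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

with torch.inference_mode():
    prompt = CHATGLM_V3_PROMPT_FORMAT.format(prompt="What is AI?")
    # Keep the full tokenizer output so the attention mask is not discarded;
    # both tensors move to the model's device.
    inputs = tokenizer(prompt, return_tensors='pt')
    input_ids = inputs.input_ids.to(model.device)
    attention_mask = inputs.attention_mask.to(model.device)
    output = model.generate(input_ids,
                            max_new_tokens=32,
                            do_sample=False,
                            attention_mask=attention_mask,  # added by this commit
                            th_stop_draft=0.6)
    print(tokenizer.decode(output[0], skip_special_tokens=True))

The substance of the fix is the attention-mask plumbing: the tokenizer output is kept whole instead of being reduced to .input_ids, so the mask can be passed to both generate() calls; torchscript=True is set at load time, and the explicit model.to('cpu') is dropped.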