diff --git a/python/llm/example/CPU/Speculative-Decoding/chatglm3/speculative.py b/python/llm/example/CPU/Speculative-Decoding/chatglm3/speculative.py
index ba8e95db..b60b87b7 100644
--- a/python/llm/example/CPU/Speculative-Decoding/chatglm3/speculative.py
+++ b/python/llm/example/CPU/Speculative-Decoding/chatglm3/speculative.py
@@ -57,20 +57,23 @@ if __name__ == '__main__':
                                       load_in_low_bit="bf16",
                                       speculative=True,
                                       trust_remote_code=True,
+                                      torchscript=True,
                                       use_cache=True)
-    model = model.to('cpu')
 
     tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
 
     with torch.inference_mode():
         prompt = CHATGLM_V3_PROMPT_FORMAT.format(prompt=args.prompt)
-        input_ids = tokenizer(prompt, return_tensors='pt').input_ids.to(model.device)
+        inputs = tokenizer(prompt, return_tensors='pt')
+        input_ids = inputs.input_ids.to(model.device)
+        attention_mask = inputs.attention_mask.to(model.device)
         actual_in_len = input_ids.shape[1]
         print("actual input_ids length:" + str(actual_in_len))
 
         # warmup
         output = model.generate(input_ids,
                                 max_new_tokens=args.n_predict,
                                 do_sample=False,
+                                attention_mask=attention_mask,
                                 th_stop_draft=0.6)
         output_str = tokenizer.decode(output[0])
 
@@ -79,6 +82,7 @@ if __name__ == '__main__':
         output = model.generate(input_ids,
                                 max_new_tokens=args.n_predict,
                                 do_sample=False,
+                                attention_mask=attention_mask,
                                 th_stop_draft=0.6)
         output_str = tokenizer.decode(output[0], skip_special_tokens=True)
         end = time.perf_counter()