fix pad_token_id issue (#10425)
parent e25d7413de
commit f3fefdc9ce
4 changed files with 5 additions and 1 deletion
@@ -60,6 +60,7 @@ if __name__ == '__main__':
         # it is important to set `use_cache=True` explicitly in the `generate` function
         # to obtain optimal performance with BigDL-LLM INT4 optimizations
 
+        model.generation_config.pad_token_id = model.generation_config.eos_token_id
         # Note that phi-2 uses GenerationConfig to enable 'use_cache'
         output = model.generate(input_ids, do_sample=False, max_new_tokens=args.n_predict, generation_config = generation_config)
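The added line (the same in all four example scripts) maps phi-2's missing pad token onto its EOS token, so `generate()` no longer has to pick a `pad_token_id` itself and warn about it on every call. Below is a minimal sketch of the same pattern with stock Hugging Face transformers; the checkpoint name, prompt, and token budget are placeholders, not taken from the changed files.

from transformers import AutoModelForCausalLM, AutoTokenizer

model_path = "microsoft/phi-2"  # placeholder checkpoint name
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True)

# phi-2 defines no pad token, so reuse the EOS token id for padding;
# otherwise generate() falls back to eos_token_id and logs a warning on each call.
model.generation_config.pad_token_id = model.generation_config.eos_token_id

input_ids = tokenizer.encode("Question: What is AI? Answer:", return_tensors="pt")
output = model.generate(input_ids, do_sample=False, max_new_tokens=32)
print(tokenizer.decode(output[0], skip_special_tokens=True))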
@@ -52,6 +52,7 @@ if __name__ == '__main__':
     with torch.inference_mode():
         prompt = PHI_2_V1_PROMPT_FORMAT.format(prompt=args.prompt)
         input_ids = tokenizer.encode(prompt, return_tensors="pt")
+        model.generation_config.pad_token_id = model.generation_config.eos_token_id
         st = time.time()
         output = model.generate(input_ids, max_new_tokens=args.n_predict, generation_config = generation_config)
         end = time.time()
@@ -59,6 +59,7 @@ if __name__ == '__main__':
         prompt = PHI2_PROMPT_FORMAT.format(prompt=args.prompt)
         input_ids = tokenizer.encode(prompt, return_tensors="pt").to('xpu')
 
+        model.generation_config.pad_token_id = model.generation_config.eos_token_id
         # ipex model needs a warmup, then inference time can be accurate
         output = model.generate(input_ids,
                                 max_new_tokens=args.n_predict,
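In the GPU examples the same assignment goes in right after the inputs are moved to 'xpu' and before the warmup and timed `generate()` calls. A rough sketch of that flow, assuming BigDL-LLM's transformers-style INT4 loader (the loader arguments are assumptions; this diff does not show how the model is created):

import torch
from transformers import AutoTokenizer
# Assumption: BigDL-LLM's drop-in AutoModelForCausalLM with 4-bit loading.
from bigdl.llm.transformers import AutoModelForCausalLM

model_path = "microsoft/phi-2"  # placeholder checkpoint name
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_path, load_in_4bit=True,
                                             trust_remote_code=True).to('xpu')

with torch.inference_mode():
    input_ids = tokenizer.encode("Question: What is AI? Answer:", return_tensors="pt").to('xpu')
    model.generation_config.pad_token_id = model.generation_config.eos_token_id
    model.generate(input_ids, max_new_tokens=32)            # warmup pass
    output = model.generate(input_ids, max_new_tokens=32)   # timed pass
print(tokenizer.decode(output[0].cpu(), skip_special_tokens=True))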
@@ -55,7 +55,8 @@ if __name__ == '__main__':
     with torch.inference_mode():
         prompt = PHI_2_V1_PROMPT_FORMAT.format(prompt=args.prompt)
         input_ids = tokenizer.encode(prompt, return_tensors="pt").to('xpu')
 
+        model.generation_config.pad_token_id = model.generation_config.eos_token_id
         # ipex model needs a warmup, then inference time can be accurate
         output = model.generate(input_ids, do_sample=False, max_new_tokens=args.n_predict, generation_config = generation_config)
         # start inference
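For completeness, transformers also accepts the pad token as a per-call override, which silences the same warning without touching the model's GenerationConfig. This commit does not use that form, but as a variation on the first sketch above it would look like:

# Alternative (not what this commit does): override pad_token_id per call,
# reusing model, tokenizer and input_ids from the sketch above.
output = model.generate(input_ids,
                        do_sample=False,
                        max_new_tokens=32,
                        pad_token_id=tokenizer.eos_token_id)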