From 84f1c4ad57dd877a8f0f79bf10b35596ec511299 Mon Sep 17 00:00:00 2001 From: Yuwen Hu <54161268+Oscilloscope98@users.noreply.github.com> Date: Wed, 4 Dec 2024 18:54:06 +0800 Subject: [PATCH] Small fix for NPU Python cpp simple generate regarding eos tokens (#12501) --- python/llm/src/ipex_llm/transformers/npu_models/convert.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/python/llm/src/ipex_llm/transformers/npu_models/convert.py b/python/llm/src/ipex_llm/transformers/npu_models/convert.py index b5a2feda..9ac0c9a6 100644 --- a/python/llm/src/ipex_llm/transformers/npu_models/convert.py +++ b/python/llm/src/ipex_llm/transformers/npu_models/convert.py @@ -368,6 +368,10 @@ def simple_generate( eos = 0xffffffff else: eos = new_generate_kwargs["eos_token_id"] + + if not isinstance(eos, list): + eos = [eos] + output_tokens = [] from .npu_llm_cpp import run_decode, run_prefill, reset @@ -379,7 +383,7 @@ def simple_generate( time_t2 = time.perf_counter() output_tokens.append(torch.tensor([token])) for i in range(new_tokens - 1): - if token == eos: + if token in eos: break token = run_decode(self.model_ptr, token, self.vocab_size) if streamer is not None: