diff --git a/python/llm/src/ipex_llm/transformers/npu_models/convert.py b/python/llm/src/ipex_llm/transformers/npu_models/convert.py index b5a2feda..9ac0c9a6 100644 --- a/python/llm/src/ipex_llm/transformers/npu_models/convert.py +++ b/python/llm/src/ipex_llm/transformers/npu_models/convert.py @@ -368,6 +368,10 @@ def simple_generate( eos = 0xffffffff else: eos = new_generate_kwargs["eos_token_id"] + + if not isinstance(eos, list): + eos = [eos] + output_tokens = [] from .npu_llm_cpp import run_decode, run_prefill, reset @@ -379,7 +383,7 @@ def simple_generate( time_t2 = time.perf_counter() output_tokens.append(torch.tensor([token])) for i in range(new_tokens - 1): - if token == eos: + if token in eos: break token = run_decode(self.model_ptr, token, self.vocab_size) if streamer is not None: