Small fix for NPU Python cpp simple generate regarding eos tokens (#12501)

This commit is contained in:
Yuwen Hu 2024-12-04 18:54:06 +08:00 committed by GitHub
parent d8b14a6305
commit 84f1c4ad57
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@@ -368,6 +368,10 @@ def simple_generate(
eos = 0xffffffff
else:
eos = new_generate_kwargs["eos_token_id"]
+if not isinstance(eos, list):
+eos = [eos]
output_tokens = []
from .npu_llm_cpp import run_decode, run_prefill, reset
@@ -379,7 +383,7 @@ def simple_generate(
time_t2 = time.perf_counter()
output_tokens.append(torch.tensor([token]))
for i in range(new_tokens - 1):
-if token == eos:
+if token in eos:
break
token = run_decode(self.model_ptr, token, self.vocab_size)
if streamer is not None: