From 420ce7d164cdc9c23ffdaf050f5bf54b52f20baf Mon Sep 17 00:00:00 2001 From: Yuwen Hu <54161268+Oscilloscope98@users.noreply.github.com> Date: Thu, 22 Aug 2024 18:55:59 +0800 Subject: [PATCH] Fix non-stop at eos token problem for lookup generation (#11896) * Fix non-stop by eos_token_id problem for lookup * Small fix * Add judgement when generation_config.eos_token_id is None * Fix based on comments --- .../llm/src/ipex_llm/transformers/lookup.py | 22 ++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) diff --git a/python/llm/src/ipex_llm/transformers/lookup.py b/python/llm/src/ipex_llm/transformers/lookup.py index 00ec1a62..e86c05b1 100644 --- a/python/llm/src/ipex_llm/transformers/lookup.py +++ b/python/llm/src/ipex_llm/transformers/lookup.py @@ -266,6 +266,13 @@ def lookup_generate(self, past_key_values = None input_len = input_ids.shape[1] + eos_token_id_set = None + if generation_config.eos_token_id is not None: + if isinstance(generation_config.eos_token_id, list): + eos_token_id_set = set(generation_config.eos_token_id) + else: + eos_token_id_set = set([generation_config.eos_token_id]) + while True: if step >= max_new_tokens: break @@ -381,11 +388,16 @@ def lookup_generate(self, self.post_time.append(pot-mot) # Stop on eos and remove content after eos - output_ids_list = output_ids[0].tolist() - if generation_config.eos_token_id in output_ids_list: - idx = output_ids_list.index(generation_config.eos_token_id) - step -= (len(output_ids_list) - idx - 1) - break + if eos_token_id_set is not None: + output_ids_list = output_ids[0].tolist() + first_eos_idx = -1 + for out_idx, out_id in enumerate(output_ids_list): + if out_id in eos_token_id_set: + first_eos_idx = out_idx + break + if first_eos_idx > -1: + step -= (len(output_ids_list) - first_eos_idx - 1) + break step = min(step, max_new_tokens) e2e_toc = time.time()