Fix lookup generation not stopping at the EOS token (#11896)

* Fix lookup generation not stopping on eos_token_id

* Small fix

* Handle the case where generation_config.eos_token_id is None

* Fix based on comments
This commit is contained in:
Yuwen Hu 2024-08-22 18:55:59 +08:00 committed by GitHub
parent 794abe2ce8
commit 420ce7d164
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -266,6 +266,13 @@ def lookup_generate(self,
past_key_values = None
input_len = input_ids.shape[1]
eos_token_id_set = None
if generation_config.eos_token_id is not None:
if isinstance(generation_config.eos_token_id, list):
eos_token_id_set = set(generation_config.eos_token_id)
else:
eos_token_id_set = set([generation_config.eos_token_id])
while True:
if step >= max_new_tokens:
break
@ -381,11 +388,16 @@ def lookup_generate(self,
self.post_time.append(pot-mot)
# Stop on eos and remove content after eos
output_ids_list = output_ids[0].tolist()
if generation_config.eos_token_id in output_ids_list:
idx = output_ids_list.index(generation_config.eos_token_id)
step -= (len(output_ids_list) - idx - 1)
break
if eos_token_id_set is not None:
output_ids_list = output_ids[0].tolist()
first_eos_idx = -1
for out_idx, out_id in enumerate(output_ids_list):
if out_id in eos_token_id_set:
first_eos_idx = out_idx
break
if first_eos_idx > -1:
step -= (len(output_ids_list) - first_eos_idx - 1)
break
step = min(step, max_new_tokens)
e2e_toc = time.time()