Fix non-stop at eos token problem for lookup generation (#11896)
* Fix generation not stopping on eos_token_id for lookup
* Small fix
* Add a check for when generation_config.eos_token_id is None
* Fix based on review comments
This commit is contained in:
parent
794abe2ce8
commit
420ce7d164
1 changed file with 17 additions and 5 deletions
|
|
@@ -266,6 +266,13 @@ def lookup_generate(self,
|
||||||
past_key_values = None
|
past_key_values = None
|
||||||
input_len = input_ids.shape[1]
|
input_len = input_ids.shape[1]
|
||||||
|
|
||||||
|
eos_token_id_set = None
|
||||||
|
if generation_config.eos_token_id is not None:
|
||||||
|
if isinstance(generation_config.eos_token_id, list):
|
||||||
|
eos_token_id_set = set(generation_config.eos_token_id)
|
||||||
|
else:
|
||||||
|
eos_token_id_set = set([generation_config.eos_token_id])
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
if step >= max_new_tokens:
|
if step >= max_new_tokens:
|
||||||
break
|
break
|
||||||
|
|
@@ -381,10 +388,15 @@ def lookup_generate(self,
|
||||||
self.post_time.append(pot-mot)
|
self.post_time.append(pot-mot)
|
||||||
|
|
||||||
# Stop on eos and remove content after eos
|
# Stop on eos and remove content after eos
|
||||||
|
if eos_token_id_set is not None:
|
||||||
output_ids_list = output_ids[0].tolist()
|
output_ids_list = output_ids[0].tolist()
|
||||||
if generation_config.eos_token_id in output_ids_list:
|
first_eos_idx = -1
|
||||||
idx = output_ids_list.index(generation_config.eos_token_id)
|
for out_idx, out_id in enumerate(output_ids_list):
|
||||||
step -= (len(output_ids_list) - idx - 1)
|
if out_id in eos_token_id_set:
|
||||||
|
first_eos_idx = out_idx
|
||||||
|
break
|
||||||
|
if first_eos_idx > -1:
|
||||||
|
step -= (len(output_ids_list) - first_eos_idx - 1)
|
||||||
break
|
break
|
||||||
|
|
||||||
step = min(step, max_new_tokens)
|
step = min(step, max_new_tokens)
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue