Fix lookup generation not stopping at the EOS token (#11896)
* Fix lookup generation not stopping on eos_token_id * Small fix * Add a check for when generation_config.eos_token_id is None * Fix based on review comments
This commit is contained in:
parent
794abe2ce8
commit
420ce7d164
1 changed files with 17 additions and 5 deletions
|
|
@@ -266,6 +266,13 @@ def lookup_generate(self,
|
|||
past_key_values = None
|
||||
input_len = input_ids.shape[1]
|
||||
|
||||
eos_token_id_set = None
|
||||
if generation_config.eos_token_id is not None:
|
||||
if isinstance(generation_config.eos_token_id, list):
|
||||
eos_token_id_set = set(generation_config.eos_token_id)
|
||||
else:
|
||||
eos_token_id_set = set([generation_config.eos_token_id])
|
||||
|
||||
while True:
|
||||
if step >= max_new_tokens:
|
||||
break
|
||||
|
|
@@ -381,11 +388,16 @@ def lookup_generate(self,
|
|||
self.post_time.append(pot-mot)
|
||||
|
||||
# Stop on eos and remove content after eos
|
||||
output_ids_list = output_ids[0].tolist()
|
||||
if generation_config.eos_token_id in output_ids_list:
|
||||
idx = output_ids_list.index(generation_config.eos_token_id)
|
||||
step -= (len(output_ids_list) - idx - 1)
|
||||
break
|
||||
if eos_token_id_set is not None:
|
||||
output_ids_list = output_ids[0].tolist()
|
||||
first_eos_idx = -1
|
||||
for out_idx, out_id in enumerate(output_ids_list):
|
||||
if out_id in eos_token_id_set:
|
||||
first_eos_idx = out_idx
|
||||
break
|
||||
if first_eos_idx > -1:
|
||||
step -= (len(output_ids_list) - first_eos_idx - 1)
|
||||
break
|
||||
|
||||
step = min(step, max_new_tokens)
|
||||
e2e_toc = time.time()
|
||||
|
|
|
|||
Loading…
Reference in a new issue