Fix spculative llama3 no stop error (#10963)

* fix normal

* add eos_tokens_id on sp and add list if

* update

* no none
This commit is contained in:
Wang, Jian4 2024-05-08 17:09:47 +08:00 committed by GitHub
parent 02870dc385
commit 3209d6b057
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 29 additions and 10 deletions

View file

@ -54,13 +54,13 @@ def get_prompt(user_input: str, chat_history: list[tuple[str, str]],
prompt_texts = [f'<|begin_of_text|>']
if system_prompt != '':
prompt_texts.append(f'<|start_header_id|>system<|end_header_id|>\n{system_prompt}<|eot_id|>')
prompt_texts.append(f'<|start_header_id|>system<|end_header_id|>\n\n{system_prompt}<|eot_id|>')
for history_input, history_response in chat_history:
prompt_texts.append(f'<|start_header_id|>user<|end_header_id|>\n{history_input.strip()}<|eot_id|>')
prompt_texts.append(f'<|start_header_id|>assistant<|end_header_id|>\n{history_response.strip()}<|eot_id|>')
prompt_texts.append(f'<|start_header_id|>user<|end_header_id|>\n\n{history_input.strip()}<|eot_id|>')
prompt_texts.append(f'<|start_header_id|>assistant<|end_header_id|>\n\n{history_response.strip()}<|eot_id|>')
prompt_texts.append(f'<|start_header_id|>user<|end_header_id|>\n{user_input.strip()}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n')
prompt_texts.append(f'<|start_header_id|>user<|end_header_id|>\n\n{user_input.strip()}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n')
return ''.join(prompt_texts)
if __name__ == '__main__':
@ -89,7 +89,13 @@ if __name__ == '__main__':
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
# here the terminators refer to https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct#transformers-automodelforcausallm
terminators = [
tokenizer.eos_token_id,
tokenizer.convert_tokens_to_ids("<|eot_id|>"),
]
# Generate predicted tokens
with torch.inference_mode():
prompt = get_prompt(args.prompt, [], system_prompt=DEFAULT_SYSTEM_PROMPT)
@ -102,6 +108,7 @@ if __name__ == '__main__':
# warmup
output = model.generate(input_ids,
max_new_tokens=args.n_predict,
eos_token_id=terminators,
attention_mask=attention_mask,
do_sample=False)
output_str = tokenizer.decode(output[0])
@ -110,6 +117,7 @@ if __name__ == '__main__':
st = time.perf_counter()
output = model.generate(input_ids,
max_new_tokens=args.n_predict,
eos_token_id=terminators,
attention_mask=attention_mask,
do_sample=False)
output_str = tokenizer.decode(output[0], skip_special_tokens=True)

View file

@ -91,7 +91,7 @@ def generate(
for var in ['max_new_tokens', 'max_step_draft', 'th_stop_draft', 'do_sample',
'top_k', 'top_p', 'temperature', 'hf_adjust',
'auto_th_stop_draft', 'auto_parameters', 'repetition_penalty',
'attention_mask', 'min_step_draft']:
'attention_mask', 'min_step_draft', 'eos_token_id']:
value = kwargs.pop(var, None)
if value is not None:
new_speculative_kwargs[var] = value
@ -719,6 +719,7 @@ def speculative_generate(self,
# Step 4. (b, c, e) match (b, c, d) -> b, c
# Final, f will be the next input, just like a
# Step 5. Final-> b, c, f
this_peer_finished = False
while True:
if step >= max_new_tokens:
break
@ -1093,10 +1094,20 @@ def speculative_generate(self,
# Stop on eos and remove content after eos
output_ids_list = output_ids[0].tolist()
if generation_config.eos_token_id in output_ids_list:
idx = output_ids_list.index(generation_config.eos_token_id)
step -= (len(output_ids_list) - idx - 1)
break
if generation_config.eos_token_id is not None:
if isinstance(generation_config.eos_token_id, int):
eos_token_ids = [generation_config.eos_token_id]
else:
eos_token_ids = generation_config.eos_token_id
for eos_token_id in eos_token_ids:
if eos_token_id in output_ids_list:
idx = output_ids_list.index(eos_token_id)
step -= (len(output_ids_list) - idx - 1)
this_peer_finished = True
break
if this_peer_finished:
break
if streamer is not None:
streamer.end()
step = min(step, max_new_tokens)