Fix spculative llama3 no stop error (#10963)

* fix normal

* add eos_tokens_id on sp and add list if

* update

* no none
This commit is contained in:
Wang, Jian4 2024-05-08 17:09:47 +08:00 committed by GitHub
parent 02870dc385
commit 3209d6b057
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 29 additions and 10 deletions

View file

@ -54,13 +54,13 @@ def get_prompt(user_input: str, chat_history: list[tuple[str, str]],
prompt_texts = [f'<|begin_of_text|>'] prompt_texts = [f'<|begin_of_text|>']
if system_prompt != '': if system_prompt != '':
prompt_texts.append(f'<|start_header_id|>system<|end_header_id|>\n{system_prompt}<|eot_id|>') prompt_texts.append(f'<|start_header_id|>system<|end_header_id|>\n\n{system_prompt}<|eot_id|>')
for history_input, history_response in chat_history: for history_input, history_response in chat_history:
prompt_texts.append(f'<|start_header_id|>user<|end_header_id|>\n{history_input.strip()}<|eot_id|>') prompt_texts.append(f'<|start_header_id|>user<|end_header_id|>\n\n{history_input.strip()}<|eot_id|>')
prompt_texts.append(f'<|start_header_id|>assistant<|end_header_id|>\n{history_response.strip()}<|eot_id|>') prompt_texts.append(f'<|start_header_id|>assistant<|end_header_id|>\n\n{history_response.strip()}<|eot_id|>')
prompt_texts.append(f'<|start_header_id|>user<|end_header_id|>\n{user_input.strip()}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n') prompt_texts.append(f'<|start_header_id|>user<|end_header_id|>\n\n{user_input.strip()}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n')
return ''.join(prompt_texts) return ''.join(prompt_texts)
if __name__ == '__main__': if __name__ == '__main__':
@ -89,7 +89,13 @@ if __name__ == '__main__':
# Load tokenizer # Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
# here the terminators refer to https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct#transformers-automodelforcausallm
terminators = [
tokenizer.eos_token_id,
tokenizer.convert_tokens_to_ids("<|eot_id|>"),
]
# Generate predicted tokens # Generate predicted tokens
with torch.inference_mode(): with torch.inference_mode():
prompt = get_prompt(args.prompt, [], system_prompt=DEFAULT_SYSTEM_PROMPT) prompt = get_prompt(args.prompt, [], system_prompt=DEFAULT_SYSTEM_PROMPT)
@ -102,6 +108,7 @@ if __name__ == '__main__':
# warmup # warmup
output = model.generate(input_ids, output = model.generate(input_ids,
max_new_tokens=args.n_predict, max_new_tokens=args.n_predict,
eos_token_id=terminators,
attention_mask=attention_mask, attention_mask=attention_mask,
do_sample=False) do_sample=False)
output_str = tokenizer.decode(output[0]) output_str = tokenizer.decode(output[0])
@ -110,6 +117,7 @@ if __name__ == '__main__':
st = time.perf_counter() st = time.perf_counter()
output = model.generate(input_ids, output = model.generate(input_ids,
max_new_tokens=args.n_predict, max_new_tokens=args.n_predict,
eos_token_id=terminators,
attention_mask=attention_mask, attention_mask=attention_mask,
do_sample=False) do_sample=False)
output_str = tokenizer.decode(output[0], skip_special_tokens=True) output_str = tokenizer.decode(output[0], skip_special_tokens=True)

View file

@ -91,7 +91,7 @@ def generate(
for var in ['max_new_tokens', 'max_step_draft', 'th_stop_draft', 'do_sample', for var in ['max_new_tokens', 'max_step_draft', 'th_stop_draft', 'do_sample',
'top_k', 'top_p', 'temperature', 'hf_adjust', 'top_k', 'top_p', 'temperature', 'hf_adjust',
'auto_th_stop_draft', 'auto_parameters', 'repetition_penalty', 'auto_th_stop_draft', 'auto_parameters', 'repetition_penalty',
'attention_mask', 'min_step_draft']: 'attention_mask', 'min_step_draft', 'eos_token_id']:
value = kwargs.pop(var, None) value = kwargs.pop(var, None)
if value is not None: if value is not None:
new_speculative_kwargs[var] = value new_speculative_kwargs[var] = value
@ -719,6 +719,7 @@ def speculative_generate(self,
# Step 4. (b, c, e) match (b, c, d) -> b, c # Step 4. (b, c, e) match (b, c, d) -> b, c
# Final, f will be the next input, just like a # Final, f will be the next input, just like a
# Step 5. Final-> b, c, f # Step 5. Final-> b, c, f
this_peer_finished = False
while True: while True:
if step >= max_new_tokens: if step >= max_new_tokens:
break break
@ -1093,10 +1094,20 @@ def speculative_generate(self,
# Stop on eos and remove content after eos # Stop on eos and remove content after eos
output_ids_list = output_ids[0].tolist() output_ids_list = output_ids[0].tolist()
if generation_config.eos_token_id in output_ids_list: if generation_config.eos_token_id is not None:
idx = output_ids_list.index(generation_config.eos_token_id) if isinstance(generation_config.eos_token_id, int):
step -= (len(output_ids_list) - idx - 1) eos_token_ids = [generation_config.eos_token_id]
break else:
eos_token_ids = generation_config.eos_token_id
for eos_token_id in eos_token_ids:
if eos_token_id in output_ids_list:
idx = output_ids_list.index(eos_token_id)
step -= (len(output_ids_list) - idx - 1)
this_peer_finished = True
break
if this_peer_finished:
break
if streamer is not None: if streamer is not None:
streamer.end() streamer.end()
step = min(step, max_new_tokens) step = min(step, max_new_tokens)