Fix speculative llama3 no stop error (#10963)

* fix normal
* add eos_token_id on speculative generate and handle list input
* update
* no none

parent 02870dc385
commit 3209d6b057

2 changed files with 29 additions and 10 deletions
@@ -54,13 +54,13 @@ def get_prompt(user_input: str, chat_history: list[tuple[str, str]],
     prompt_texts = [f'<|begin_of_text|>']
 
     if system_prompt != '':
-        prompt_texts.append(f'<|start_header_id|>system<|end_header_id|>\n{system_prompt}<|eot_id|>')
+        prompt_texts.append(f'<|start_header_id|>system<|end_header_id|>\n\n{system_prompt}<|eot_id|>')
 
     for history_input, history_response in chat_history:
-        prompt_texts.append(f'<|start_header_id|>user<|end_header_id|>\n{history_input.strip()}<|eot_id|>')
-        prompt_texts.append(f'<|start_header_id|>assistant<|end_header_id|>\n{history_response.strip()}<|eot_id|>')
+        prompt_texts.append(f'<|start_header_id|>user<|end_header_id|>\n\n{history_input.strip()}<|eot_id|>')
+        prompt_texts.append(f'<|start_header_id|>assistant<|end_header_id|>\n\n{history_response.strip()}<|eot_id|>')
 
-    prompt_texts.append(f'<|start_header_id|>user<|end_header_id|>\n{user_input.strip()}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n')
+    prompt_texts.append(f'<|start_header_id|>user<|end_header_id|>\n\n{user_input.strip()}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n')
     return ''.join(prompt_texts)
 
 if __name__ == '__main__':
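The only functional change in get_prompt is that every end-of-header marker is now followed by a blank line, which matches the Llama 3 chat template. As a quick sanity check, a toy run of the corrected builder (hypothetical system prompt and user input, not taken from the patch) would look like this:

# toy reproduction of the corrected template; the strings are illustrative only
prompt_texts = ['<|begin_of_text|>']
prompt_texts.append('<|start_header_id|>system<|end_header_id|>\n\nYou are a helpful assistant.<|eot_id|>')
prompt_texts.append('<|start_header_id|>user<|end_header_id|>\n\nWhat is AI?<|eot_id|>'
                    '<|start_header_id|>assistant<|end_header_id|>\n')
print(''.join(prompt_texts))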
@@ -89,7 +89,13 @@ if __name__ == '__main__':
 
     # Load tokenizer
     tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
 
+    # here the terminators refer to https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct#transformers-automodelforcausallm
+    terminators = [
+        tokenizer.eos_token_id,
+        tokenizer.convert_tokens_to_ids("<|eot_id|>"),
+    ]
+
     # Generate predicted tokens
     with torch.inference_mode():
         prompt = get_prompt(args.prompt, [], system_prompt=DEFAULT_SYSTEM_PROMPT)
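Llama 3 instruct checkpoints end each assistant turn with the <|eot_id|> token, while the tokenizer's default eos token is <|end_of_text|>; unless both ids are handed to generate(), the model can keep producing text past the end of its turn, which is the behaviour the linked model card describes. A minimal standalone sketch of the same setup, assuming the Hugging Face checkpoint name below:

from transformers import AutoModelForCausalLM, AutoTokenizer

# assumed checkpoint name, taken from the model card linked in the comment above
model_path = "meta-llama/Meta-Llama-3-8B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(model_path)

# stop on either the default eos or the end-of-turn marker
terminators = [
    tokenizer.eos_token_id,
    tokenizer.convert_tokens_to_ids("<|eot_id|>"),
]

inputs = tokenizer("<|begin_of_text|>Hello", return_tensors="pt")
# transformers accepts a single id or a list of ids for eos_token_id
output = model.generate(**inputs, max_new_tokens=32, eos_token_id=terminators, do_sample=False)
print(tokenizer.decode(output[0], skip_special_tokens=True))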
@@ -102,6 +108,7 @@ if __name__ == '__main__':
         # warmup
         output = model.generate(input_ids,
                                 max_new_tokens=args.n_predict,
+                                eos_token_id=terminators,
                                 attention_mask=attention_mask,
                                 do_sample=False)
         output_str = tokenizer.decode(output[0])
@@ -110,6 +117,7 @@ if __name__ == '__main__':
         st = time.perf_counter()
         output = model.generate(input_ids,
                                 max_new_tokens=args.n_predict,
+                                eos_token_id=terminators,
                                 attention_mask=attention_mask,
                                 do_sample=False)
         output_str = tokenizer.decode(output[0], skip_special_tokens=True)
@@ -91,7 +91,7 @@ def generate(
     for var in ['max_new_tokens', 'max_step_draft', 'th_stop_draft', 'do_sample',
                 'top_k', 'top_p', 'temperature', 'hf_adjust',
                 'auto_th_stop_draft', 'auto_parameters', 'repetition_penalty',
-                'attention_mask', 'min_step_draft']:
+                'attention_mask', 'min_step_draft', 'eos_token_id']:
         value = kwargs.pop(var, None)
         if value is not None:
             new_speculative_kwargs[var] = value
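The speculative generate() wrapper only forwards an allow-listed set of keyword arguments to speculative_generate(), so without 'eos_token_id' in that list the terminators were silently dropped. A small sketch of the filtering behaviour (the helper name and the trimmed list are illustrative, not the library's API):

def extract_speculative_kwargs(kwargs):
    # pop only the allow-listed generation arguments, now including eos_token_id;
    # everything else stays in kwargs for the ordinary generate() path
    allowed = ['max_new_tokens', 'attention_mask', 'eos_token_id']  # trimmed for illustration
    speculative_kwargs = {}
    for var in allowed:
        value = kwargs.pop(var, None)
        if value is not None:
            speculative_kwargs[var] = value
    return speculative_kwargs

# example ids standing in for <|end_of_text|> and <|eot_id|>
user_kwargs = {'max_new_tokens': 32, 'eos_token_id': [128001, 128009], 'num_beams': 1}
print(extract_speculative_kwargs(user_kwargs))  # eos_token_id now survives the filter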
@@ -719,6 +719,7 @@ def speculative_generate(self,
     # Step 4. (b, c, e) match (b, c, d) -> b, c
     # Final, f will be the next input, just like a
     # Step 5. Final-> b, c, f
+    this_peer_finished = False
     while True:
         if step >= max_new_tokens:
             break
@@ -1093,10 +1094,20 @@ def speculative_generate(self,
 
             # Stop on eos and remove content after eos
             output_ids_list = output_ids[0].tolist()
-            if generation_config.eos_token_id in output_ids_list:
-                idx = output_ids_list.index(generation_config.eos_token_id)
-                step -= (len(output_ids_list) - idx - 1)
-                break
+            if generation_config.eos_token_id is not None:
+                if isinstance(generation_config.eos_token_id, int):
+                    eos_token_ids = [generation_config.eos_token_id]
+                else:
+                    eos_token_ids = generation_config.eos_token_id
+
+                for eos_token_id in eos_token_ids:
+                    if eos_token_id in output_ids_list:
+                        idx = output_ids_list.index(eos_token_id)
+                        step -= (len(output_ids_list) - idx - 1)
+                        this_peer_finished = True
+                        break
+                if this_peer_finished:
+                    break
     if streamer is not None:
         streamer.end()
     step = min(step, max_new_tokens)
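The crux of the fix: generation_config.eos_token_id can now be a single int or a list of ids (as with the Llama 3 terminators), and the loop stops and trims as soon as any of them shows up in the output. The previous membership test compared the whole list object against individual token ids, so with a list it never matched and generation ran on until max_new_tokens. A self-contained restatement of the new stop-and-trim step, with made-up token ids:

def trim_after_eos(output_ids, eos_token_id):
    # illustrative restatement of the patched logic: normalise eos_token_id to a
    # list, then keep tokens up to and including the first eos id that is found
    if eos_token_id is None:
        return output_ids, False
    eos_token_ids = [eos_token_id] if isinstance(eos_token_id, int) else eos_token_id
    for eos in eos_token_ids:
        if eos in output_ids:
            idx = output_ids.index(eos)
            return output_ids[:idx + 1], True
    return output_ids, False

# made-up ids: 9 stands in for <|eot_id|>, 2 for <|end_of_text|>
print(trim_after_eos([5, 7, 9, 3, 4], [2, 9]))  # -> ([5, 7, 9], True): tail after eos dropped
print(trim_after_eos([5, 7, 3], 2))             # -> ([5, 7, 3], False): no eos, keep everything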