Fix spculative llama3 no stop error (#10963)
* fix normal * add eos_tokens_id on sp and add list if * update * no none
This commit is contained in:
parent
02870dc385
commit
3209d6b057
2 changed files with 29 additions and 10 deletions
|
|
@ -54,13 +54,13 @@ def get_prompt(user_input: str, chat_history: list[tuple[str, str]],
|
|||
prompt_texts = [f'<|begin_of_text|>']
|
||||
|
||||
if system_prompt != '':
|
||||
prompt_texts.append(f'<|start_header_id|>system<|end_header_id|>\n{system_prompt}<|eot_id|>')
|
||||
prompt_texts.append(f'<|start_header_id|>system<|end_header_id|>\n\n{system_prompt}<|eot_id|>')
|
||||
|
||||
for history_input, history_response in chat_history:
|
||||
prompt_texts.append(f'<|start_header_id|>user<|end_header_id|>\n{history_input.strip()}<|eot_id|>')
|
||||
prompt_texts.append(f'<|start_header_id|>assistant<|end_header_id|>\n{history_response.strip()}<|eot_id|>')
|
||||
prompt_texts.append(f'<|start_header_id|>user<|end_header_id|>\n\n{history_input.strip()}<|eot_id|>')
|
||||
prompt_texts.append(f'<|start_header_id|>assistant<|end_header_id|>\n\n{history_response.strip()}<|eot_id|>')
|
||||
|
||||
prompt_texts.append(f'<|start_header_id|>user<|end_header_id|>\n{user_input.strip()}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n')
|
||||
prompt_texts.append(f'<|start_header_id|>user<|end_header_id|>\n\n{user_input.strip()}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n')
|
||||
return ''.join(prompt_texts)
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
@ -89,7 +89,13 @@ if __name__ == '__main__':
|
|||
|
||||
# Load tokenizer
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
|
||||
|
||||
|
||||
# here the terminators refer to https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct#transformers-automodelforcausallm
|
||||
terminators = [
|
||||
tokenizer.eos_token_id,
|
||||
tokenizer.convert_tokens_to_ids("<|eot_id|>"),
|
||||
]
|
||||
|
||||
# Generate predicted tokens
|
||||
with torch.inference_mode():
|
||||
prompt = get_prompt(args.prompt, [], system_prompt=DEFAULT_SYSTEM_PROMPT)
|
||||
|
|
@ -102,6 +108,7 @@ if __name__ == '__main__':
|
|||
# warmup
|
||||
output = model.generate(input_ids,
|
||||
max_new_tokens=args.n_predict,
|
||||
eos_token_id=terminators,
|
||||
attention_mask=attention_mask,
|
||||
do_sample=False)
|
||||
output_str = tokenizer.decode(output[0])
|
||||
|
|
@ -110,6 +117,7 @@ if __name__ == '__main__':
|
|||
st = time.perf_counter()
|
||||
output = model.generate(input_ids,
|
||||
max_new_tokens=args.n_predict,
|
||||
eos_token_id=terminators,
|
||||
attention_mask=attention_mask,
|
||||
do_sample=False)
|
||||
output_str = tokenizer.decode(output[0], skip_special_tokens=True)
|
||||
|
|
|
|||
|
|
@ -91,7 +91,7 @@ def generate(
|
|||
for var in ['max_new_tokens', 'max_step_draft', 'th_stop_draft', 'do_sample',
|
||||
'top_k', 'top_p', 'temperature', 'hf_adjust',
|
||||
'auto_th_stop_draft', 'auto_parameters', 'repetition_penalty',
|
||||
'attention_mask', 'min_step_draft']:
|
||||
'attention_mask', 'min_step_draft', 'eos_token_id']:
|
||||
value = kwargs.pop(var, None)
|
||||
if value is not None:
|
||||
new_speculative_kwargs[var] = value
|
||||
|
|
@ -719,6 +719,7 @@ def speculative_generate(self,
|
|||
# Step 4. (b, c, e) match (b, c, d) -> b, c
|
||||
# Final, f will be the next input, just like a
|
||||
# Step 5. Final-> b, c, f
|
||||
this_peer_finished = False
|
||||
while True:
|
||||
if step >= max_new_tokens:
|
||||
break
|
||||
|
|
@ -1093,10 +1094,20 @@ def speculative_generate(self,
|
|||
|
||||
# Stop on eos and remove content after eos
|
||||
output_ids_list = output_ids[0].tolist()
|
||||
if generation_config.eos_token_id in output_ids_list:
|
||||
idx = output_ids_list.index(generation_config.eos_token_id)
|
||||
step -= (len(output_ids_list) - idx - 1)
|
||||
break
|
||||
if generation_config.eos_token_id is not None:
|
||||
if isinstance(generation_config.eos_token_id, int):
|
||||
eos_token_ids = [generation_config.eos_token_id]
|
||||
else:
|
||||
eos_token_ids = generation_config.eos_token_id
|
||||
|
||||
for eos_token_id in eos_token_ids:
|
||||
if eos_token_id in output_ids_list:
|
||||
idx = output_ids_list.index(eos_token_id)
|
||||
step -= (len(output_ids_list) - idx - 1)
|
||||
this_peer_finished = True
|
||||
break
|
||||
if this_peer_finished:
|
||||
break
|
||||
if streamer is not None:
|
||||
streamer.end()
|
||||
step = min(step, max_new_tokens)
|
||||
|
|
|
|||
Loading…
Reference in a new issue