From 3209d6b0576595c528322cdfad5a218f52862946 Mon Sep 17 00:00:00 2001
From: "Wang, Jian4" <61138589+hzjane@users.noreply.github.com>
Date: Wed, 8 May 2024 17:09:47 +0800
Subject: [PATCH] Fix speculative llama3 no stop error (#10963)

* fix normal

* add eos_tokens_id on sp and add list if

* update

* no none
---
 .../llama3/speculative.py                    | 18 +++++++++++-----
 .../src/ipex_llm/transformers/speculative.py | 21 ++++++++++++++-----
 2 files changed, 29 insertions(+), 10 deletions(-)

diff --git a/python/llm/example/CPU/Speculative-Decoding/llama3/speculative.py b/python/llm/example/CPU/Speculative-Decoding/llama3/speculative.py
index f7eeb917..efb7215e 100644
--- a/python/llm/example/CPU/Speculative-Decoding/llama3/speculative.py
+++ b/python/llm/example/CPU/Speculative-Decoding/llama3/speculative.py
@@ -54,13 +54,13 @@ def get_prompt(user_input: str, chat_history: list[tuple[str, str]],
     prompt_texts = [f'<|begin_of_text|>']
 
     if system_prompt != '':
-        prompt_texts.append(f'<|start_header_id|>system<|end_header_id|>\n{system_prompt}<|eot_id|>')
+        prompt_texts.append(f'<|start_header_id|>system<|end_header_id|>\n\n{system_prompt}<|eot_id|>')
 
     for history_input, history_response in chat_history:
-        prompt_texts.append(f'<|start_header_id|>user<|end_header_id|>\n{history_input.strip()}<|eot_id|>')
-        prompt_texts.append(f'<|start_header_id|>assistant<|end_header_id|>\n{history_response.strip()}<|eot_id|>')
+        prompt_texts.append(f'<|start_header_id|>user<|end_header_id|>\n\n{history_input.strip()}<|eot_id|>')
+        prompt_texts.append(f'<|start_header_id|>assistant<|end_header_id|>\n\n{history_response.strip()}<|eot_id|>')
 
-    prompt_texts.append(f'<|start_header_id|>user<|end_header_id|>\n{user_input.strip()}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n')
+    prompt_texts.append(f'<|start_header_id|>user<|end_header_id|>\n\n{user_input.strip()}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n')
     return ''.join(prompt_texts)
 
 if __name__ == '__main__':
@@ -89,7 +89,13 @@
 
     # Load tokenizer
     tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
-    
+
+    # here the terminators refer to https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct#transformers-automodelforcausallm
+    terminators = [
+        tokenizer.eos_token_id,
+        tokenizer.convert_tokens_to_ids("<|eot_id|>"),
+    ]
+
     # Generate predicted tokens
     with torch.inference_mode():
         prompt = get_prompt(args.prompt, [], system_prompt=DEFAULT_SYSTEM_PROMPT)
@@ -102,6 +108,7 @@
         # warmup
         output = model.generate(input_ids,
                                 max_new_tokens=args.n_predict,
+                                eos_token_id=terminators,
                                 attention_mask=attention_mask,
                                 do_sample=False)
         output_str = tokenizer.decode(output[0])
@@ -110,6 +117,7 @@
         st = time.perf_counter()
         output = model.generate(input_ids,
                                 max_new_tokens=args.n_predict,
+                                eos_token_id=terminators,
                                 attention_mask=attention_mask,
                                 do_sample=False)
         output_str = tokenizer.decode(output[0], skip_special_tokens=True)
diff --git a/python/llm/src/ipex_llm/transformers/speculative.py b/python/llm/src/ipex_llm/transformers/speculative.py
index 2128998c..2f123659 100644
--- a/python/llm/src/ipex_llm/transformers/speculative.py
+++ b/python/llm/src/ipex_llm/transformers/speculative.py
@@ -91,7 +91,7 @@ def generate(
         for var in ['max_new_tokens', 'max_step_draft', 'th_stop_draft', 'do_sample',
                     'top_k', 'top_p', 'temperature', 'hf_adjust',
                     'auto_th_stop_draft', 'auto_parameters', 'repetition_penalty',
-                    'attention_mask', 'min_step_draft']:
+                    'attention_mask', 'min_step_draft', 'eos_token_id']:
             value = kwargs.pop(var, None)
             if value is not None:
                 new_speculative_kwargs[var] = value
@@ -719,6 +719,7 @@ def speculative_generate(self,
     # Step 4. (b, c, e) match (b, c, d) -> b, c
     # Final, f will be the next input, just like a
     # Step 5. Final-> b, c, f
+    this_peer_finished = False
     while True:
         if step >= max_new_tokens:
             break
@@ -1093,10 +1094,20 @@ def speculative_generate(self,
 
         # Stop on eos and remove content after eos
         output_ids_list = output_ids[0].tolist()
-        if generation_config.eos_token_id in output_ids_list:
-            idx = output_ids_list.index(generation_config.eos_token_id)
-            step -= (len(output_ids_list) - idx - 1)
-            break
+        if generation_config.eos_token_id is not None:
+            if isinstance(generation_config.eos_token_id, int):
+                eos_token_ids = [generation_config.eos_token_id]
+            else:
+                eos_token_ids = generation_config.eos_token_id
+
+            for eos_token_id in eos_token_ids:
+                if eos_token_id in output_ids_list:
+                    idx = output_ids_list.index(eos_token_id)
+                    step -= (len(output_ids_list) - idx - 1)
+                    this_peer_finished = True
+                    break
+            if this_peer_finished:
+                break
     if streamer is not None:
         streamer.end()
     step = min(step, max_new_tokens)
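
Below is a minimal, self-contained sketch (not part of the patch) of the list-aware EOS stop check that the speculative.py hunk above introduces: generation_config.eos_token_id may now be None, a single id, or a list of ids, since Llama 3 stops on both <|end_of_text|> and <|eot_id|>. The helper name truncate_at_eos and the literal token ids are illustrative assumptions only; the patch implements the equivalent logic inline in speculative_generate().

from typing import List, Optional, Union


def truncate_at_eos(output_ids: List[int],
                    eos_token_id: Optional[Union[int, List[int]]]) -> List[int]:
    # Mirror the patched check: eos_token_id may be None, a single id,
    # or a list of ids.
    if eos_token_id is None:
        return output_ids
    eos_token_ids = [eos_token_id] if isinstance(eos_token_id, int) else eos_token_id
    for eos in eos_token_ids:
        if eos in output_ids:
            idx = output_ids.index(eos)
            return output_ids[:idx + 1]  # keep the eos token, drop everything after it
    return output_ids


if __name__ == '__main__':
    # 128001 and 128009 stand in for Llama 3's <|end_of_text|> and <|eot_id|> ids.
    ids = [15, 22, 7, 128009, 99, 42]
    print(truncate_at_eos(ids, [128001, 128009]))  # -> [15, 22, 7, 128009]

The same normalization to a list is why the example script can pass eos_token_id=terminators (a two-element list) straight through model.generate() into the speculative path.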