From 23c91cdce6bdd025be105bd1cc70f985f44dbff8 Mon Sep 17 00:00:00 2001 From: Yina Chen <33650826+cyita@users.noreply.github.com> Date: Mon, 19 Feb 2024 14:31:41 +0800 Subject: [PATCH] [LLM] Add min_step_draft in speculative decoding (#10142) * Fix gptj kvcache & position id * Add min_draft_tokens in speculative decoding * fix style * update --- python/llm/src/bigdl/llm/transformers/speculative.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/python/llm/src/bigdl/llm/transformers/speculative.py b/python/llm/src/bigdl/llm/transformers/speculative.py index 103a6237..9dc028e9 100644 --- a/python/llm/src/bigdl/llm/transformers/speculative.py +++ b/python/llm/src/bigdl/llm/transformers/speculative.py @@ -58,7 +58,7 @@ def generate( for var in ['max_new_tokens', 'max_step_draft', 'th_stop_draft', 'do_sample', 'top_k', 'top_p', 'temperature', 'hf_adjust', 'auto_th_stop_draft', 'auto_parameters', 'repetition_penalty', - 'attention_mask']: + 'attention_mask', 'min_step_draft']: value = kwargs.pop(var, None) if value is not None: new_speculative_kwargs[var] = value @@ -331,11 +331,15 @@ def speculative_generate(self, auto_th_stop_draft=True, auto_parameters=[1, 0.5, 0.9, 1e-2, 0.9], hf_adjust=False, + min_step_draft=3, generation_config: Optional[GenerationConfig] = None, attention_mask=None, **sampling_kwargs): invalidInputError(draft_model is not None, "Draft model should be provided.") + # min_step_draft >= 1. 
Since the max_step_draft may be adjusted, +    # min_step_draft can be > max_step_draft +    min_step_draft = min_step_draft if min_step_draft >= 1 else 1 if generation_config is None: generation_config = self.generation_config @@ -568,7 +572,8 @@ def speculative_generate(self, # check if draft prob is less then th_stop_draft # Draft number + step >= max output token number th_random = 1 if random_probs is None else random_probs[step_draft] - if (draft_output_probs.item() < th_stop_draft and th_random > 0.3) or \ + if (draft_output_probs.item() < th_stop_draft and th_random > 0.3 and + step_draft + 1 >= min_step_draft) or \ step + step_draft + 2 >= max_new_tokens: break if self.device.type == 'xpu':