[LLM] Add min_step_draft in speculative decoding (#10142)

* Fix gptj kvcache & position id

* Add min_draft_tokens in speculative decoding

* fix style

* update
This commit is contained in:
Yina Chen 2024-02-19 14:31:41 +08:00 committed by GitHub
parent 14ba2c5135
commit 23c91cdce6

View file

@@ -58,7 +58,7 @@ def generate(
for var in ['max_new_tokens', 'max_step_draft', 'th_stop_draft', 'do_sample',
'top_k', 'top_p', 'temperature', 'hf_adjust',
'auto_th_stop_draft', 'auto_parameters', 'repetition_penalty',
'attention_mask']:
'attention_mask', 'min_step_draft']:
value = kwargs.pop(var, None)
if value is not None:
new_speculative_kwargs[var] = value
@@ -331,11 +331,15 @@ def speculative_generate(self,
auto_th_stop_draft=True,
auto_parameters=[1, 0.5, 0.9, 1e-2, 0.9],
hf_adjust=False,
min_step_draft=3,
generation_config: Optional[GenerationConfig] = None,
attention_mask=None,
**sampling_kwargs):
invalidInputError(draft_model is not None,
"Draft model should be provided.")
# min_step_draft >= 1. Since the max_step_draft may adjust,
# min_step_draft can > max_step_draft
min_step_draft = min_step_draft if min_step_draft >= 1 else 1
if generation_config is None:
generation_config = self.generation_config
@@ -568,7 +572,8 @@ def speculative_generate(self,
# check if draft prob is less then th_stop_draft
# Draft number + step >= max output token number
th_random = 1 if random_probs is None else random_probs[step_draft]
if (draft_output_probs.item() < th_stop_draft and th_random > 0.3) or \
if (draft_output_probs.item() < th_stop_draft and th_random > 0.3 and
step_draft + 1 >= min_step_draft) or \
step + step_draft + 2 >= max_new_tokens:
break
if self.device.type == 'xpu':