[LLM] Add min_step_draft in speculative decoding (#10142)
* Fix gptj kvcache & position id
* Add min_draft_tokens in speculative decoding
* fix style
* update
parent 14ba2c5135
commit 23c91cdce6

1 changed file with 7 additions and 2 deletions
@@ -58,7 +58,7 @@ def generate(
     for var in ['max_new_tokens', 'max_step_draft', 'th_stop_draft', 'do_sample',
                 'top_k', 'top_p', 'temperature', 'hf_adjust',
                 'auto_th_stop_draft', 'auto_parameters', 'repetition_penalty',
-                'attention_mask']:
+                'attention_mask', 'min_step_draft']:
         value = kwargs.pop(var, None)
         if value is not None:
             new_speculative_kwargs[var] = value
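The hunk above routes the new knob through the high-level generate() wrapper: any recognized speculative-decoding argument found in **kwargs, now including 'min_step_draft', is moved into the dict that is later handed to speculative_generate(). A minimal standalone sketch of that forwarding pattern (the helper name collect_speculative_kwargs is made up for illustration and is not part of the repository):

def collect_speculative_kwargs(**kwargs):
    # Pop every recognized speculative-decoding knob out of **kwargs so the
    # remaining kwargs can follow the ordinary generation path.
    new_speculative_kwargs = {}
    for var in ['max_new_tokens', 'max_step_draft', 'th_stop_draft', 'do_sample',
                'top_k', 'top_p', 'temperature', 'hf_adjust',
                'auto_th_stop_draft', 'auto_parameters', 'repetition_penalty',
                'attention_mask', 'min_step_draft']:
        value = kwargs.pop(var, None)
        if value is not None:
            new_speculative_kwargs[var] = value
    return new_speculative_kwargs

print(collect_speculative_kwargs(max_new_tokens=128, min_step_draft=3))
# -> {'max_new_tokens': 128, 'min_step_draft': 3}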
@@ -331,11 +331,15 @@ def speculative_generate(self,
                          auto_th_stop_draft=True,
                          auto_parameters=[1, 0.5, 0.9, 1e-2, 0.9],
                          hf_adjust=False,
+                         min_step_draft=3,
                          generation_config: Optional[GenerationConfig] = None,
                          attention_mask=None,
                          **sampling_kwargs):
     invalidInputError(draft_model is not None,
                       "Draft model should be provided.")
+    # min_step_draft >= 1. Since the max_step_draft may adjust,
+    # min_step_draft can > max_step_draft
+    min_step_draft = min_step_draft if min_step_draft >= 1 else 1

     if generation_config is None:
         generation_config = self.generation_config
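The new parameter defaults to 3 and is clamped to at least 1 on entry; per the added comment, it may legitimately exceed max_step_draft, since max_step_draft can be adjusted at runtime. A tiny sketch of the clamp, which is equivalent to max(min_step_draft, 1) (the helper name is hypothetical):

def clamp_min_step_draft(min_step_draft):
    # Mirror of the added line: guarantee at least one draft step per round.
    return min_step_draft if min_step_draft >= 1 else 1

assert clamp_min_step_draft(0) == 1   # invalid values fall back to 1
assert clamp_min_step_draft(3) == 3   # the default passes through unchanged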
@@ -568,7 +572,8 @@ def speculative_generate(self,
                     # check if draft prob is less then th_stop_draft
                     # Draft number + step >= max output token number
                     th_random = 1 if random_probs is None else random_probs[step_draft]
-                    if (draft_output_probs.item() < th_stop_draft and th_random > 0.3) or \
+                    if (draft_output_probs.item() < th_stop_draft and th_random > 0.3 and
+                            step_draft + 1 >= min_step_draft) or \
                             step + step_draft + 2 >= max_new_tokens:
                         break
                 if self.device.type == 'xpu':
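The behavioral change is in this early-stop test: a low-confidence draft token (probability below th_stop_draft, subject to the th_random > 0.3 check) can now end the drafting loop only after at least min_step_draft tokens have been drafted, while the max_new_tokens budget check still breaks unconditionally. A standalone sketch of the revised rule (should_stop_draft is a hypothetical helper, not a function in the repository):

def should_stop_draft(draft_prob, th_stop_draft, th_random,
                      step, step_draft, min_step_draft, max_new_tokens):
    # Low-confidence stop is gated on having drafted >= min_step_draft tokens.
    low_confidence = (draft_prob < th_stop_draft and th_random > 0.3 and
                      step_draft + 1 >= min_step_draft)
    # The overall output-token budget still stops drafting regardless of confidence.
    budget_reached = step + step_draft + 2 >= max_new_tokens
    return low_confidence or budget_reached

# With min_step_draft=3, a weak first draft token no longer ends the loop early.
print(should_stop_draft(0.1, 0.8, 1.0, step=0, step_draft=0,
                        min_step_draft=3, max_new_tokens=64))   # False
print(should_stop_draft(0.1, 0.8, 1.0, step=0, step_draft=2,
                        min_step_draft=3, max_new_tokens=64))   # True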