[LLM] Add min_step_draft in speculative decoding (#10142)
* Fix gptj kvcache & position id * Add min_step_draft in speculative decoding * fix style * update
This commit is contained in:
parent
14ba2c5135
commit
23c91cdce6
1 changed files with 7 additions and 2 deletions
|
|
@ -58,7 +58,7 @@ def generate(
|
||||||
for var in ['max_new_tokens', 'max_step_draft', 'th_stop_draft', 'do_sample',
|
for var in ['max_new_tokens', 'max_step_draft', 'th_stop_draft', 'do_sample',
|
||||||
'top_k', 'top_p', 'temperature', 'hf_adjust',
|
'top_k', 'top_p', 'temperature', 'hf_adjust',
|
||||||
'auto_th_stop_draft', 'auto_parameters', 'repetition_penalty',
|
'auto_th_stop_draft', 'auto_parameters', 'repetition_penalty',
|
||||||
'attention_mask']:
|
'attention_mask', 'min_step_draft']:
|
||||||
value = kwargs.pop(var, None)
|
value = kwargs.pop(var, None)
|
||||||
if value is not None:
|
if value is not None:
|
||||||
new_speculative_kwargs[var] = value
|
new_speculative_kwargs[var] = value
|
||||||
|
|
@ -331,11 +331,15 @@ def speculative_generate(self,
|
||||||
auto_th_stop_draft=True,
|
auto_th_stop_draft=True,
|
||||||
auto_parameters=[1, 0.5, 0.9, 1e-2, 0.9],
|
auto_parameters=[1, 0.5, 0.9, 1e-2, 0.9],
|
||||||
hf_adjust=False,
|
hf_adjust=False,
|
||||||
|
min_step_draft=3,
|
||||||
generation_config: Optional[GenerationConfig] = None,
|
generation_config: Optional[GenerationConfig] = None,
|
||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
**sampling_kwargs):
|
**sampling_kwargs):
|
||||||
invalidInputError(draft_model is not None,
|
invalidInputError(draft_model is not None,
|
||||||
"Draft model should be provided.")
|
"Draft model should be provided.")
|
||||||
|
# Clamp min_step_draft to at least 1. Since max_step_draft may be adjusted dynamically,
|
||||||
|
# min_step_draft may end up exceeding max_step_draft.
|
||||||
|
min_step_draft = min_step_draft if min_step_draft >= 1 else 1
|
||||||
|
|
||||||
if generation_config is None:
|
if generation_config is None:
|
||||||
generation_config = self.generation_config
|
generation_config = self.generation_config
|
||||||
|
|
@ -568,7 +572,8 @@ def speculative_generate(self,
|
||||||
# check if draft prob is less than th_stop_draft
|
# check if draft prob is less than th_stop_draft
|
||||||
# Draft number + step >= max output token number
|
# Draft number + step >= max output token number
|
||||||
th_random = 1 if random_probs is None else random_probs[step_draft]
|
th_random = 1 if random_probs is None else random_probs[step_draft]
|
||||||
if (draft_output_probs.item() < th_stop_draft and th_random > 0.3) or \
|
if (draft_output_probs.item() < th_stop_draft and th_random > 0.3 and
|
||||||
|
step_draft + 1 >= min_step_draft) or \
|
||||||
step + step_draft + 2 >= max_new_tokens:
|
step + step_draft + 2 >= max_new_tokens:
|
||||||
break
|
break
|
||||||
if self.device.type == 'xpu':
|
if self.device.type == 'xpu':
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue