Fix speculative decoding bug (#10855)
This commit is contained in:
parent
c9dee6cd0e
commit
92ea54b512
2 changed files with 8 additions and 3 deletions
|
|
@@ -43,7 +43,7 @@ if __name__ == '__main__':
|
||||||
help='Prompt to infer')
|
help='Prompt to infer')
|
||||||
parser.add_argument('--precision', type=str, default='bf16',
|
parser.add_argument('--precision', type=str, default='bf16',
|
||||||
help='Main model Precision')
|
help='Main model Precision')
|
||||||
parser.add_argument('--n_predict', type=int, default=128,
|
parser.add_argument('--n-predict', type=int, default=128,
|
||||||
help='Max tokens to predict')
|
help='Max tokens to predict')
|
||||||
parser.add_argument('--max-draft', type=int, default=8,
|
parser.add_argument('--max-draft', type=int, default=8,
|
||||||
help='Max draft')
|
help='Max draft')
|
||||||
|
|
|
||||||
|
|
@@ -74,8 +74,7 @@ def generate(
|
||||||
for var in ['max_step_draft', 'th_stop_draft', 'hf_adjust',
|
for var in ['max_step_draft', 'th_stop_draft', 'hf_adjust',
|
||||||
'auto_th_stop_draft', 'auto_parameters', 'min_step_draft',
|
'auto_th_stop_draft', 'auto_parameters', 'min_step_draft',
|
||||||
'th_batch_num']:
|
'th_batch_num']:
|
||||||
value = kwargs.pop(var, None)
|
kwargs.pop(var, None)
|
||||||
del self.draft_model
|
|
||||||
return original_generate(self,
|
return original_generate(self,
|
||||||
inputs=inputs,
|
inputs=inputs,
|
||||||
generation_config=generation_config,
|
generation_config=generation_config,
|
||||||
|
|
@@ -100,6 +99,12 @@ def generate(
|
||||||
draft_model=self.draft_model,
|
draft_model=self.draft_model,
|
||||||
**new_speculative_kwargs)
|
**new_speculative_kwargs)
|
||||||
else:
|
else:
|
||||||
|
# When `draft_model` is false, these attributes
|
||||||
|
# related to speculative decoding should be removed
|
||||||
|
for var in ['max_step_draft', 'th_stop_draft', 'hf_adjust',
|
||||||
|
'auto_th_stop_draft', 'auto_parameters', 'min_step_draft',
|
||||||
|
'th_batch_num']:
|
||||||
|
kwargs.pop(var, None)
|
||||||
return original_generate(self,
|
return original_generate(self,
|
||||||
inputs=inputs,
|
inputs=inputs,
|
||||||
generation_config=generation_config,
|
generation_config=generation_config,
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue