Fix speculative decoding bug (#10855)
This commit is contained in:
parent
c9dee6cd0e
commit
92ea54b512
2 changed files with 8 additions and 3 deletions
|
|
@@ -43,7 +43,7 @@ if __name__ == '__main__':
|
|||
help='Prompt to infer')
|
||||
parser.add_argument('--precision', type=str, default='bf16',
|
||||
help='Main model Precision')
|
||||
- parser.add_argument('--n_predict', type=int, default=128,
|
||||
+ parser.add_argument('--n-predict', type=int, default=128,
|
||||
help='Max tokens to predict')
|
||||
parser.add_argument('--max-draft', type=int, default=8,
|
||||
help='Max draft')
|
||||
|
|
|
|||
|
|
@@ -74,8 +74,7 @@ def generate(
|
|||
for var in ['max_step_draft', 'th_stop_draft', 'hf_adjust',
|
||||
'auto_th_stop_draft', 'auto_parameters', 'min_step_draft',
|
||||
'th_batch_num']:
|
||||
- value = kwargs.pop(var, None)
|
||||
- del self.draft_model
|
||||
+ kwargs.pop(var, None)
|
||||
return original_generate(self,
|
||||
inputs=inputs,
|
||||
generation_config=generation_config,
|
||||
|
|
@@ -100,6 +99,12 @@ def generate(
|
|||
draft_model=self.draft_model,
|
||||
**new_speculative_kwargs)
|
||||
else:
|
||||
+ # When `draft_model` is false, these attributes
|
||||
+ # related to speculative decoding should be removed
|
||||
+ for var in ['max_step_draft', 'th_stop_draft', 'hf_adjust',
|
||||
+ 'auto_th_stop_draft', 'auto_parameters', 'min_step_draft',
|
||||
+ 'th_batch_num']:
|
||||
+ kwargs.pop(var, None)
|
||||
return original_generate(self,
|
||||
inputs=inputs,
|
||||
generation_config=generation_config,
|
||||
|
|
|
|||
Loading…
Reference in a new issue