Fix speculative decoding bug (#10855)

This commit is contained in:
ZehuaCao 2024-04-23 14:28:31 +08:00 committed by GitHub
parent c9dee6cd0e
commit 92ea54b512
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 8 additions and 3 deletions

View file

@@ -43,7 +43,7 @@ if __name__ == '__main__':
                         help='Prompt to infer')
     parser.add_argument('--precision', type=str, default='bf16',
                         help='Main model Precision')
-    parser.add_argument('--n_predict', type=int, default=128,
+    parser.add_argument('--n-predict', type=int, default=128,
                         help='Max tokens to predict')
     parser.add_argument('--max-draft', type=int, default=8,
                         help='Max draft')

View file

@@ -74,8 +74,7 @@ def generate(
         for var in ['max_step_draft', 'th_stop_draft', 'hf_adjust',
                     'auto_th_stop_draft', 'auto_parameters', 'min_step_draft',
                     'th_batch_num']:
-            value = kwargs.pop(var, None)
-        del self.draft_model
+            kwargs.pop(var, None)
         return original_generate(self,
                                  inputs=inputs,
                                  generation_config=generation_config,
@@ -100,6 +99,12 @@ def generate(
                 draft_model=self.draft_model,
                 **new_speculative_kwargs)
     else:
+        # When `draft_model` is false, these attributes
+        # related to speculative decoding should be removed
+        for var in ['max_step_draft', 'th_stop_draft', 'hf_adjust',
+                    'auto_th_stop_draft', 'auto_parameters', 'min_step_draft',
+                    'th_batch_num']:
+            kwargs.pop(var, None)
         return original_generate(self,
                                  inputs=inputs,
                                  generation_config=generation_config,