diff --git a/python/llm/example/CPU/Speculative-Decoding/baichuan2/speculative.py b/python/llm/example/CPU/Speculative-Decoding/baichuan2/speculative.py
index fa59c410..947afd57 100644
--- a/python/llm/example/CPU/Speculative-Decoding/baichuan2/speculative.py
+++ b/python/llm/example/CPU/Speculative-Decoding/baichuan2/speculative.py
@@ -43,7 +43,7 @@ if __name__ == '__main__':
                         help='Prompt to infer')
     parser.add_argument('--precision', type=str, default='bf16',
                         help='Main model Precision')
-    parser.add_argument('--n_predict', type=int, default=128,
+    parser.add_argument('--n-predict', type=int, default=128,
                         help='Max tokens to predict')
     parser.add_argument('--max-draft', type=int, default=8,
                         help='Max draft')
diff --git a/python/llm/src/ipex_llm/transformers/speculative.py b/python/llm/src/ipex_llm/transformers/speculative.py
index 1f7a84bd..31309451 100644
--- a/python/llm/src/ipex_llm/transformers/speculative.py
+++ b/python/llm/src/ipex_llm/transformers/speculative.py
@@ -74,8 +74,7 @@ def generate(
             for var in ['max_step_draft', 'th_stop_draft', 'hf_adjust',
                         'auto_th_stop_draft', 'auto_parameters', 'min_step_draft',
                         'th_batch_num']:
-                value = kwargs.pop(var, None)
-            del self.draft_model
+                kwargs.pop(var, None)
             return original_generate(self,
                                      inputs=inputs,
                                      generation_config=generation_config,
@@ -100,6 +99,12 @@ def generate(
                                     draft_model=self.draft_model,
                                     **new_speculative_kwargs)
     else:
+        # When `draft_model` is false, these attributes
+        # related to speculative decoding should be removed
+        for var in ['max_step_draft', 'th_stop_draft', 'hf_adjust',
+                    'auto_th_stop_draft', 'auto_parameters', 'min_step_draft',
+                    'th_batch_num']:
+            kwargs.pop(var, None)
         return original_generate(self,
                                  inputs=inputs,
                                  generation_config=generation_config,
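
Note on the first hunk: renaming `--n_predict` to `--n-predict` makes the flag consistent with the script's other dash-separated options such as `--max-draft`, and it is source-compatible, because argparse converts dashes in long option names to underscores when deriving the attribute name, so the value is still read as `args.n_predict`. A minimal standalone sketch (not part of the example script) showing this:

    import argparse

    # argparse maps '--n-predict' to the attribute 'n_predict' automatically,
    # so only the command-line spelling changes, not the Python access.
    parser = argparse.ArgumentParser()
    parser.add_argument('--n-predict', type=int, default=128,
                        help='Max tokens to predict')

    args = parser.parse_args(['--n-predict', '32'])
    print(args.n_predict)  # -> 32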
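
Note on the second file: the first hunk drops the unused `value =` binding and, more importantly, removes `del self.draft_model`, which would have deleted the attribute and thereby disabled speculative decoding for every later `generate()` call on the same model, not just the current fallback. The second hunk applies the same kwargs cleanup on the plain `else:` path, so speculative-only keyword arguments never reach the original `generate()`, which could otherwise reject them as unknown. A self-contained sketch of the pattern; `Model`, `original_generate`, and `SPEC_KWARGS` are illustrative stand-ins, not ipex_llm's actual API:

    # Hypothetical names for illustration only (not ipex_llm API).
    SPEC_KWARGS = ['max_step_draft', 'th_stop_draft', 'hf_adjust',
                   'auto_th_stop_draft', 'auto_parameters', 'min_step_draft',
                   'th_batch_num']


    def original_generate(prompt, **kwargs):
        # Stand-in for the vanilla generate(); real implementations tend to
        # reject keyword arguments they do not recognize.
        unknown = set(kwargs) & set(SPEC_KWARGS)
        if unknown:
            raise ValueError(f"unexpected kwargs: {unknown}")
        return f"generated({prompt})"


    class Model:
        def __init__(self, draft_model=None):
            self.draft_model = draft_model

        def generate(self, prompt, **kwargs):
            if self.draft_model is None:
                # Fallback path mirroring the patch: pop the speculative-only
                # kwargs (pop(var, None) never raises KeyError) so the
                # original generate() does not see them. Crucially, there is
                # no `del self.draft_model` here -- deleting the attribute
                # would disable speculative decoding for every later call.
                for var in SPEC_KWARGS:
                    kwargs.pop(var, None)
                return original_generate(prompt, **kwargs)
            return f"speculative({prompt})"  # speculative path elided here


    m = Model()  # no draft model attached
    print(m.generate("hello", max_step_draft=8, th_stop_draft=0.8))
    # -> generated(hello)

Popping with a `None` default keeps the cleanup idempotent and safe even when a caller passes none of these kwargs.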