diff --git a/python/llm/src/ipex_llm/transformers/speculative.py b/python/llm/src/ipex_llm/transformers/speculative.py index 74866fee..6581cb9b 100644 --- a/python/llm/src/ipex_llm/transformers/speculative.py +++ b/python/llm/src/ipex_llm/transformers/speculative.py @@ -53,6 +53,28 @@ def generate( **kwargs, ): if hasattr(self, "draft_model"): + from ipex_llm.llm.transformers.convert import get_enable_ipex + _enable_ipex = get_enable_ipex() + if _enable_ipex and inputs.size(1) < 256: + logger.warning( + "IPEX_CPU optimized models have issues for speculative decoding with short prompts" + "(length < 256). Using normal generate() method instead." + ) + for var in ['max_step_draft', 'th_stop_draft', 'hf_adjust', + 'auto_th_stop_draft', 'auto_parameters', 'min_step_draft', + 'th_batch_num']: + value = kwargs.pop(var, None) + del self.draft_model + return original_generate(self, + inputs=inputs, + generation_config=generation_config, + logits_processor=logits_processor, + stopping_criteria=stopping_criteria, + prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, + synced_gpus=synced_gpus, + assistant_model=assistant_model, + streamer=streamer, + **kwargs) # do speculative decoding # TODO: maybe add other way to double check new_speculative_kwargs = {}