LLM: Add length check for IPEX-CPU speculative decoding (#10529)

Add a prompt-length check for IPEX-CPU speculative decoding: when IPEX-CPU optimization is enabled and the input prompt is shorter than 256 tokens, drop the draft model and fall back to the normal generate() path, since IPEX-CPU optimized models have issues with speculative decoding on short prompts.
Xiangyu Tian 2024-03-26 17:47:10 +08:00 committed by GitHub
parent a3b007f3b1
commit 11550d3f25

@@ -53,6 +53,28 @@ def generate(
    **kwargs,
):
    if hasattr(self, "draft_model"):
        from ipex_llm.llm.transformers.convert import get_enable_ipex
        _enable_ipex = get_enable_ipex()
        if _enable_ipex and inputs.size(1) < 256:
            logger.warning(
                "IPEX_CPU optimized models have issues for speculative decoding with short prompts "
                "(length < 256). Using normal generate() method instead."
            )
            # Drop the speculative-decoding-only kwargs before delegating to the
            # original generate(), which does not accept them.
            for var in ['max_step_draft', 'th_stop_draft', 'hf_adjust',
                        'auto_th_stop_draft', 'auto_parameters', 'min_step_draft',
                        'th_batch_num']:
                value = kwargs.pop(var, None)
            del self.draft_model
            return original_generate(self,
                                     inputs=inputs,
                                     generation_config=generation_config,
                                     logits_processor=logits_processor,
                                     stopping_criteria=stopping_criteria,
                                     prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
                                     synced_gpus=synced_gpus,
                                     assistant_model=assistant_model,
                                     streamer=streamer,
                                     **kwargs)
    # do speculative decoding
    # TODO: maybe add other way to double check
    new_speculative_kwargs = {}
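
For context, a minimal, hedged sketch of the caller-side path this change affects. The model path and the loading flags (optimize_model, torch_dtype, speculative_decoding, trust_remote_code) are illustrative assumptions based on ipex-llm's self-speculative decoding examples and are not part of this commit:

import torch
from transformers import AutoTokenizer
from ipex_llm.transformers import AutoModelForCausalLM

# Illustrative model path; any model supported by ipex-llm speculative
# decoding would do (assumption, not taken from this commit).
model_path = "meta-llama/Llama-2-7b-chat-hf"
tokenizer = AutoTokenizer.from_pretrained(model_path)

# speculative_decoding=True attaches a draft model, so the patched generate()
# above normally takes the speculative path (flags assumed from ipex-llm examples).
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    optimize_model=True,
    torch_dtype=torch.bfloat16,
    speculative_decoding=True,
    trust_remote_code=True,
)

prompt = "Hello"  # far fewer than 256 tokens
inputs = tokenizer(prompt, return_tensors="pt")

# With IPEX-CPU optimization enabled (get_enable_ipex() returns True) and
# inputs.size(1) < 256, the new branch drops the draft-specific kwargs,
# deletes the draft model, and serves this call with the original generate().
output = model.generate(inputs.input_ids, max_new_tokens=32)
print(tokenizer.decode(output[0], skip_special_tokens=True))

Note that the fallback deletes self.draft_model, so later calls on the same model object will also take the normal generate() path rather than retrying speculative decoding.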