LLM: Add length check for IPEX-CPU speculative decoding (#10529)
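Add a length check for IPEX-CPU speculative decoding: when IPEX is enabled and the prompt length is below 256, speculative decoding is skipped and the normal generate() method is used instead.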
This commit is contained in:
parent
a3b007f3b1
commit
11550d3f25
1 changed file with 22 additions and 0 deletions
|
|
@@ -53,6 +53,28 @@ def generate(
|
|||
**kwargs,
|
||||
):
|
||||
if hasattr(self, "draft_model"):
|
||||
from ipex_llm.llm.transformers.convert import get_enable_ipex
|
||||
_enable_ipex = get_enable_ipex()
|
||||
if _enable_ipex and inputs.size(1) < 256:
|
||||
logger.warning(
|
||||
"IPEX_CPU optimized models have issues for speculative decoding with short prompts"
|
||||
"(length < 256). Using normal generate() method instead."
|
||||
)
|
||||
for var in ['max_step_draft', 'th_stop_draft', 'hf_adjust',
|
||||
'auto_th_stop_draft', 'auto_parameters', 'min_step_draft',
|
||||
'th_batch_num']:
|
||||
value = kwargs.pop(var, None)
|
||||
del self.draft_model
|
||||
return original_generate(self,
|
||||
inputs=inputs,
|
||||
generation_config=generation_config,
|
||||
logits_processor=logits_processor,
|
||||
stopping_criteria=stopping_criteria,
|
||||
prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
|
||||
synced_gpus=synced_gpus,
|
||||
assistant_model=assistant_model,
|
||||
streamer=streamer,
|
||||
**kwargs)
|
||||
# do speculative decoding
|
||||
# TODO: maybe add other way to double check
|
||||
new_speculative_kwargs = {}
|
||||
|
|
|
|||
Loading…
Reference in a new issue