LLM: Add length check for IPEX-CPU speculative decoding (#10529)
Add length check for IPEX-CPU speculative decoding.
This commit is contained in:
parent
a3b007f3b1
commit
11550d3f25
1 changed file with 22 additions and 0 deletions
|
|
@@ -53,6 +53,28 @@ def generate(
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
if hasattr(self, "draft_model"):
|
if hasattr(self, "draft_model"):
|
||||||
|
from ipex_llm.llm.transformers.convert import get_enable_ipex
|
||||||
|
_enable_ipex = get_enable_ipex()
|
||||||
|
if _enable_ipex and inputs.size(1) < 256:
|
||||||
|
logger.warning(
|
||||||
|
"IPEX_CPU optimized models have issues for speculative decoding with short prompts"
|
||||||
|
"(length < 256). Using normal generate() method instead."
|
||||||
|
)
|
||||||
|
for var in ['max_step_draft', 'th_stop_draft', 'hf_adjust',
|
||||||
|
'auto_th_stop_draft', 'auto_parameters', 'min_step_draft',
|
||||||
|
'th_batch_num']:
|
||||||
|
value = kwargs.pop(var, None)
|
||||||
|
del self.draft_model
|
||||||
|
return original_generate(self,
|
||||||
|
inputs=inputs,
|
||||||
|
generation_config=generation_config,
|
||||||
|
logits_processor=logits_processor,
|
||||||
|
stopping_criteria=stopping_criteria,
|
||||||
|
prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
|
||||||
|
synced_gpus=synced_gpus,
|
||||||
|
assistant_model=assistant_model,
|
||||||
|
streamer=streamer,
|
||||||
|
**kwargs)
|
||||||
# do speculative decoding
|
# do speculative decoding
|
||||||
# TODO: maybe add other way to double check
|
# TODO: maybe add other way to double check
|
||||||
new_speculative_kwargs = {}
|
new_speculative_kwargs = {}
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue