Update IPEX_LLM_PERFORMANCE_MODE with input length threshold (#11908)
* Update IPEX_LLM_PERFORMANCE_MODE with input length threshold * Update based on comments. And and judgement for inputs_embeds * Fix for benchmarking purposes * Update based on comments * Small fix
This commit is contained in:
parent
303a090a6b
commit
24c279e0ae
1 changed files with 11 additions and 1 deletions
|
|
@ -40,6 +40,9 @@ from transformers import GenerationMixin
|
|||
original_generate = GenerationMixin.generate
|
||||
query_group_size = 16
|
||||
|
||||
# may tune it with more tested data
|
||||
PERFORMANCE_MODE_LOOKUP_INPUT_THRESHOLD = 100
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def generate(
|
||||
|
|
@ -57,6 +60,13 @@ def generate(
|
|||
lookahead = kwargs.pop("lookahead", None)
|
||||
perf_mode = os.environ.get("IPEX_LLM_PERFORMANCE_MODE", None)
|
||||
if perf_mode == "1" and lookahead is None:
|
||||
if inputs is not None:
|
||||
if inputs.shape[1] >= PERFORMANCE_MODE_LOOKUP_INPUT_THRESHOLD:
|
||||
lookahead = 2 # default to 2 now
|
||||
else:
|
||||
inputs_embeds = kwargs.get("inputs_embeds", None)
|
||||
if inputs_embeds is not None:
|
||||
if inputs_embeds.shape[1] >= PERFORMANCE_MODE_LOOKUP_INPUT_THRESHOLD:
|
||||
lookahead = 2 # default to 2 now
|
||||
if lookahead:
|
||||
from ipex_llm.transformers.convert import get_enable_ipex
|
||||
|
|
|
|||
Loading…
Reference in a new issue