diff --git a/python/llm/src/ipex_llm/transformers/lookup.py b/python/llm/src/ipex_llm/transformers/lookup.py index e86c05b1..c6fe4847 100644 --- a/python/llm/src/ipex_llm/transformers/lookup.py +++ b/python/llm/src/ipex_llm/transformers/lookup.py @@ -40,6 +40,9 @@ from transformers import GenerationMixin original_generate = GenerationMixin.generate query_group_size = 16 +# may tune it with more tested data +PERFORMANCE_MODE_LOOKUP_INPUT_THRESHOLD = 100 + @torch.no_grad() def generate( @@ -57,7 +60,14 @@ def generate( lookahead = kwargs.pop("lookahead", None) perf_mode = os.environ.get("IPEX_LLM_PERFORMANCE_MODE", None) if perf_mode == "1" and lookahead is None: - lookahead = 2 # default to 2 now + if inputs is not None: + if inputs.shape[1] >= PERFORMANCE_MODE_LOOKUP_INPUT_THRESHOLD: + lookahead = 2 # default to 2 now + else: + inputs_embeds = kwargs.get("inputs_embeds", None) + if inputs_embeds is not None: + if inputs_embeds.shape[1] >= PERFORMANCE_MODE_LOOKUP_INPUT_THRESHOLD: + lookahead = 2 # default to 2 now if lookahead: from ipex_llm.transformers.convert import get_enable_ipex _enable_ipex = get_enable_ipex()