From 24c279e0ae9913bdf5bbcb244c6b4a5a6b366981 Mon Sep 17 00:00:00 2001 From: Yuwen Hu <54161268+Oscilloscope98@users.noreply.github.com> Date: Fri, 23 Aug 2024 20:49:15 +0800 Subject: [PATCH] Update `IPEX_LLM_PERFORMANCE_MODE` with input length threshold (#11908) * Update IPEX_LLM_PERFORMANCE_MODE with input length threshold * Update based on comments. And and judgement for inputs_embeds * Fix for benchmarking purposes * Update based on comments * Small fix --- python/llm/src/ipex_llm/transformers/lookup.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/python/llm/src/ipex_llm/transformers/lookup.py b/python/llm/src/ipex_llm/transformers/lookup.py index e86c05b1..c6fe4847 100644 --- a/python/llm/src/ipex_llm/transformers/lookup.py +++ b/python/llm/src/ipex_llm/transformers/lookup.py @@ -40,6 +40,9 @@ from transformers import GenerationMixin original_generate = GenerationMixin.generate query_group_size = 16 +# may tune it with more tested data +PERFORMANCE_MODE_LOOKUP_INPUT_THRESHOLD = 100 + @torch.no_grad() def generate( @@ -57,7 +60,14 @@ def generate( lookahead = kwargs.pop("lookahead", None) perf_mode = os.environ.get("IPEX_LLM_PERFORMANCE_MODE", None) if perf_mode == "1" and lookahead is None: - lookahead = 2 # default to 2 now + if inputs is not None: + if inputs.shape[1] >= PERFORMANCE_MODE_LOOKUP_INPUT_THRESHOLD: + lookahead = 2 # default to 2 now + else: + inputs_embeds = kwargs.get("inputs_embeds", None) + if inputs_embeds is not None: + if inputs_embeds.shape[1] >= PERFORMANCE_MODE_LOOKUP_INPUT_THRESHOLD: + lookahead = 2 # default to 2 now if lookahead: from ipex_llm.transformers.convert import get_enable_ipex _enable_ipex = get_enable_ipex()