From 6eb55653bae82e917af9cfea260550109c96352f Mon Sep 17 00:00:00 2001
From: Yuwen Hu <54161268+Oscilloscope98@users.noreply.github.com>
Date: Tue, 3 Sep 2024 17:46:16 +0800
Subject: [PATCH] Performance mode strategy update for input_embeds input
 (#11997)

---
 python/llm/src/ipex_llm/transformers/lookup.py | 17 ++++++++++-------
 1 file changed, 10 insertions(+), 7 deletions(-)

diff --git a/python/llm/src/ipex_llm/transformers/lookup.py b/python/llm/src/ipex_llm/transformers/lookup.py
index 60680faf..c5fe81d4 100644
--- a/python/llm/src/ipex_llm/transformers/lookup.py
+++ b/python/llm/src/ipex_llm/transformers/lookup.py
@@ -60,21 +60,24 @@
     lookahead = kwargs.pop("lookahead", None)
     perf_mode = os.environ.get("IPEX_LLM_PERFORMANCE_MODE", None)
 
-    input_ids_shape = None
+    input_tensor_shape = None
+    is_inputs_embeds = False
     if inputs is not None:
-        input_ids_shape = inputs.shape
+        input_tensor_shape = inputs.shape
     else:
         input_ids = kwargs.get("input_ids", None)
         if input_ids is not None:
-            input_ids_shape = input_ids.shape
+            input_tensor_shape = input_ids.shape
         else:
             inputs_embeds = kwargs.get("inputs_embeds", None)
             if inputs_embeds is not None:
-                input_ids_shape = inputs_embeds.shape
+                is_inputs_embeds = True
+                input_tensor_shape = inputs_embeds.shape
 
     if perf_mode == "1" and lookahead is None:
-        if input_ids_shape is not None and \
-                input_ids_shape[1] >= PERFORMANCE_MODE_LOOKUP_INPUT_THRESHOLD:
+        if input_tensor_shape is not None and \
+                input_tensor_shape[1] >= PERFORMANCE_MODE_LOOKUP_INPUT_THRESHOLD \
+                and not is_inputs_embeds:
             lookahead = 2  # default to 2 now
 
     if lookahead:
@@ -85,7 +88,7 @@
             logger.warning("Prompt lookup is currently not supported on CPU with IPEX, "
                            "fallback to original generate.")
             kwargs.pop("max_matching_ngram_size", None)
-        elif input_ids_shape is not None and input_ids_shape[0] > 1:
+        elif input_tensor_shape is not None and input_tensor_shape[0] > 1:
            logger.warning("Prompt lookup is currently not supported with batch inference, "
                           "fallback to original generate.")
            kwargs.pop("max_matching_ngram_size", None)
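
Note: the following is a minimal standalone sketch of the decision this patch
changes, not part of the library itself. It assumes an illustrative threshold
value and a hypothetical helper name (resolve_lookahead); the real constant
PERFORMANCE_MODE_LOOKUP_INPUT_THRESHOLD lives in
ipex_llm/transformers/lookup.py and may differ. The point it demonstrates:
with IPEX_LLM_PERFORMANCE_MODE=1, the default lookahead of 2 now applies only
to sufficiently long token-id inputs, and is skipped when the caller passed
inputs_embeds.

    import os
    from typing import Optional, Tuple

    # Illustrative value only; the real constant is defined in lookup.py.
    PERFORMANCE_MODE_LOOKUP_INPUT_THRESHOLD = 512

    def resolve_lookahead(lookahead: Optional[int],
                          input_tensor_shape: Optional[Tuple[int, ...]],
                          is_inputs_embeds: bool) -> Optional[int]:
        # Hypothetical helper mirroring the patched logic: performance mode
        # enables a default lookahead of 2 only when the input is long enough
        # and was given as token ids rather than as inputs_embeds.
        perf_mode = os.environ.get("IPEX_LLM_PERFORMANCE_MODE", None)
        if perf_mode == "1" and lookahead is None:
            if (input_tensor_shape is not None
                    and input_tensor_shape[1] >= PERFORMANCE_MODE_LOOKUP_INPUT_THRESHOLD
                    and not is_inputs_embeds):
                lookahead = 2  # default to 2 now
        return lookahead

    # A (1, 1024) batch of token ids qualifies; the same shape passed as
    # inputs_embeds no longer does after this patch.
    os.environ["IPEX_LLM_PERFORMANCE_MODE"] = "1"
    assert resolve_lookahead(None, (1, 1024), is_inputs_embeds=False) == 2
    assert resolve_lookahead(None, (1, 1024), is_inputs_embeds=True) is None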