Performance mode strategy update for input_embeds input (#11997)

This commit is contained in:
Yuwen Hu 2024-09-03 17:46:16 +08:00 committed by GitHub
parent 164f47adbd
commit 6eb55653ba
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@@ -60,21 +60,24 @@ def generate(
     lookahead = kwargs.pop("lookahead", None)
     perf_mode = os.environ.get("IPEX_LLM_PERFORMANCE_MODE", None)
-    input_ids_shape = None
+    input_tensor_shape = None
+    is_inputs_embeds = False
     if inputs is not None:
-        input_ids_shape = inputs.shape
+        input_tensor_shape = inputs.shape
     else:
         input_ids = kwargs.get("input_ids", None)
         if input_ids is not None:
-            input_ids_shape = input_ids.shape
+            input_tensor_shape = input_ids.shape
         else:
             inputs_embeds = kwargs.get("inputs_embeds", None)
             if inputs_embeds is not None:
-                input_ids_shape = inputs_embeds.shape
+                is_inputs_embeds = True
+                input_tensor_shape = inputs_embeds.shape
     if perf_mode == "1" and lookahead is None:
-        if input_ids_shape is not None and \
-                input_ids_shape[1] >= PERFORMANCE_MODE_LOOKUP_INPUT_THRESHOLD:
+        if input_tensor_shape is not None and \
+                input_tensor_shape[1] >= PERFORMANCE_MODE_LOOKUP_INPUT_THRESHOLD \
+                and not is_inputs_embeds:
             lookahead = 2 # default to 2 now
     if lookahead:
@@ -85,7 +88,7 @@ def generate(
         logger.warning("Prompt lookup is currently not supported on CPU with IPEX, "
                        "fallback to original generate.")
         kwargs.pop("max_matching_ngram_size", None)
-    elif input_ids_shape is not None and input_ids_shape[0] > 1:
+    elif input_tensor_shape is not None and input_tensor_shape[0] > 1:
         logger.warning("Prompt lookup is currently not supported with batch inference, "
                        "fallback to original generate.")
         kwargs.pop("max_matching_ngram_size", None)