Performance mode strategy update for input_embeds input (#11997)
This commit is contained in:
parent
164f47adbd
commit
6eb55653ba
1 changed file with 10 additions and 7 deletions
|
|
@@ -60,21 +60,24 @@ def generate(
|
||||||
lookahead = kwargs.pop("lookahead", None)
|
lookahead = kwargs.pop("lookahead", None)
|
||||||
perf_mode = os.environ.get("IPEX_LLM_PERFORMANCE_MODE", None)
|
perf_mode = os.environ.get("IPEX_LLM_PERFORMANCE_MODE", None)
|
||||||
|
|
||||||
input_ids_shape = None
|
input_tensor_shape = None
|
||||||
|
is_inputs_embeds = False
|
||||||
if inputs is not None:
|
if inputs is not None:
|
||||||
input_ids_shape = inputs.shape
|
input_tensor_shape = inputs.shape
|
||||||
else:
|
else:
|
||||||
input_ids = kwargs.get("input_ids", None)
|
input_ids = kwargs.get("input_ids", None)
|
||||||
if input_ids is not None:
|
if input_ids is not None:
|
||||||
input_ids_shape = input_ids.shape
|
input_tensor_shape = input_ids.shape
|
||||||
else:
|
else:
|
||||||
inputs_embeds = kwargs.get("inputs_embeds", None)
|
inputs_embeds = kwargs.get("inputs_embeds", None)
|
||||||
if inputs_embeds is not None:
|
if inputs_embeds is not None:
|
||||||
input_ids_shape = inputs_embeds.shape
|
is_inputs_embeds = True
|
||||||
|
input_tensor_shape = inputs_embeds.shape
|
||||||
|
|
||||||
if perf_mode == "1" and lookahead is None:
|
if perf_mode == "1" and lookahead is None:
|
||||||
if input_ids_shape is not None and \
|
if input_tensor_shape is not None and \
|
||||||
input_ids_shape[1] >= PERFORMANCE_MODE_LOOKUP_INPUT_THRESHOLD:
|
input_tensor_shape[1] >= PERFORMANCE_MODE_LOOKUP_INPUT_THRESHOLD \
|
||||||
|
and not is_inputs_embeds:
|
||||||
lookahead = 2 # default to 2 now
|
lookahead = 2 # default to 2 now
|
||||||
|
|
||||||
if lookahead:
|
if lookahead:
|
||||||
|
|
@@ -85,7 +88,7 @@ def generate(
|
||||||
logger.warning("Prompt lookup is currently not supported on CPU with IPEX, "
|
logger.warning("Prompt lookup is currently not supported on CPU with IPEX, "
|
||||||
"fallback to original generate.")
|
"fallback to original generate.")
|
||||||
kwargs.pop("max_matching_ngram_size", None)
|
kwargs.pop("max_matching_ngram_size", None)
|
||||||
elif input_ids_shape is not None and input_ids_shape[0] > 1:
|
elif input_tensor_shape is not None and input_tensor_shape[0] > 1:
|
||||||
logger.warning("Prompt lookup is currently not supported with batch inference, "
|
logger.warning("Prompt lookup is currently not supported with batch inference, "
|
||||||
"fallback to original generate.")
|
"fallback to original generate.")
|
||||||
kwargs.pop("max_matching_ngram_size", None)
|
kwargs.pop("max_matching_ngram_size", None)
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue