Update IPEX_LLM_PERFORMANCE_MODE with input length threshold (#11908)

* Update IPEX_LLM_PERFORMANCE_MODE with input length threshold

* Update based on comments, and add a check for inputs_embeds

* Fix for benchmarking purposes

* Update based on comments

* Small fix
This commit is contained in:
Yuwen Hu 2024-08-23 20:49:15 +08:00 committed by GitHub
parent 303a090a6b
commit 24c279e0ae
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -40,6 +40,9 @@ from transformers import GenerationMixin
original_generate = GenerationMixin.generate
query_group_size = 16
# Provisional value; may be tuned once more test data is available
PERFORMANCE_MODE_LOOKUP_INPUT_THRESHOLD = 100
@torch.no_grad()
def generate(
@ -57,7 +60,14 @@ def generate(
lookahead = kwargs.pop("lookahead", None)
perf_mode = os.environ.get("IPEX_LLM_PERFORMANCE_MODE", None)
if perf_mode == "1" and lookahead is None:
lookahead = 2 # default to 2 now
if inputs is not None:
if inputs.shape[1] >= PERFORMANCE_MODE_LOOKUP_INPUT_THRESHOLD:
lookahead = 2 # default to 2 now
else:
inputs_embeds = kwargs.get("inputs_embeds", None)
if inputs_embeds is not None:
if inputs_embeds.shape[1] >= PERFORMANCE_MODE_LOOKUP_INPUT_THRESHOLD:
lookahead = 2 # default to 2 now
if lookahead:
from ipex_llm.transformers.convert import get_enable_ipex
_enable_ipex = get_enable_ipex()