Update IPEX_LLM_PERFORMANCE_MODE with input length threshold (#11908)
* Update IPEX_LLM_PERFORMANCE_MODE with input length threshold * Update based on comments. Add judgement for inputs_embeds * Fix for benchmarking purposes * Update based on comments * Small fix
This commit is contained in:
parent
303a090a6b
commit
24c279e0ae
1 changed file with 11 additions and 1 deletion
|
|
@ -40,6 +40,9 @@ from transformers import GenerationMixin
|
||||||
original_generate = GenerationMixin.generate
|
original_generate = GenerationMixin.generate
|
||||||
query_group_size = 16
|
query_group_size = 16
|
||||||
|
|
||||||
|
# may tune it with more tested data
|
||||||
|
PERFORMANCE_MODE_LOOKUP_INPUT_THRESHOLD = 100
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def generate(
|
def generate(
|
||||||
|
|
@ -57,7 +60,14 @@ def generate(
|
||||||
lookahead = kwargs.pop("lookahead", None)
|
lookahead = kwargs.pop("lookahead", None)
|
||||||
perf_mode = os.environ.get("IPEX_LLM_PERFORMANCE_MODE", None)
|
perf_mode = os.environ.get("IPEX_LLM_PERFORMANCE_MODE", None)
|
||||||
if perf_mode == "1" and lookahead is None:
|
if perf_mode == "1" and lookahead is None:
|
||||||
lookahead = 2 # default to 2 now
|
if inputs is not None:
|
||||||
|
if inputs.shape[1] >= PERFORMANCE_MODE_LOOKUP_INPUT_THRESHOLD:
|
||||||
|
lookahead = 2 # default to 2 now
|
||||||
|
else:
|
||||||
|
inputs_embeds = kwargs.get("inputs_embeds", None)
|
||||||
|
if inputs_embeds is not None:
|
||||||
|
if inputs_embeds.shape[1] >= PERFORMANCE_MODE_LOOKUP_INPUT_THRESHOLD:
|
||||||
|
lookahead = 2 # default to 2 now
|
||||||
if lookahead:
|
if lookahead:
|
||||||
from ipex_llm.transformers.convert import get_enable_ipex
|
from ipex_llm.transformers.convert import get_enable_ipex
|
||||||
_enable_ipex = get_enable_ipex()
|
_enable_ipex = get_enable_ipex()
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue