Update IPEX_LLM_PERFORMANCE_MODE (#11823)
This commit is contained in:
parent
5a80fd2633
commit
9e9086cc2a
1 changed files with 11 additions and 7 deletions
|
|
@ -487,14 +487,18 @@ def update_past_key_value(past_key_value, key_states, value_states,
|
|||
|
||||
def should_use_compresskv(x: torch.Tensor, prompt_len: int):
|
||||
use_compress_kv = os.environ.get("IPEX_LLM_COMPRESS_KV_CACHE", None)
|
||||
if use_compress_kv is None:
|
||||
return (
|
||||
get_xpu_device_type(x) == "mtl"
|
||||
and prompt_len >= 1800
|
||||
and prompt_len <= 4500
|
||||
)
|
||||
perf_mode = os.environ.get("IPEX_LLM_PERFORMANCE_MODE", None)
|
||||
if perf_mode == "1":
|
||||
return False
|
||||
else:
|
||||
return x.device.type == 'xpu' and use_compress_kv == "1"
|
||||
if use_compress_kv is None:
|
||||
return (
|
||||
get_xpu_device_type(x) == "mtl"
|
||||
and prompt_len >= 1800
|
||||
and prompt_len <= 4500
|
||||
)
|
||||
else:
|
||||
return x.device.type == 'xpu' and use_compress_kv == "1"
|
||||
|
||||
|
||||
def get_compresskv_attn_mask(key_states: torch.Tensor,
|
||||
|
|
|
|||
Loading…
Reference in a new issue