diff --git a/python/llm/src/ipex_llm/transformers/models/aquila.py b/python/llm/src/ipex_llm/transformers/models/aquila.py
index 1b1d252a..02054dcc 100644
--- a/python/llm/src/ipex_llm/transformers/models/aquila.py
+++ b/python/llm/src/ipex_llm/transformers/models/aquila.py
@@ -48,7 +48,9 @@ from ipex_llm.transformers.models.utils import apply_rotary_pos_emb
 from ipex_llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
 from ipex_llm.utils.common import log4Error
 
-KV_CACHE_ALLOC_BLOCK_LENGTH = 256
+import os
+
+KV_CACHE_ALLOC_BLOCK_LENGTH = os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256)
 
 
 def aquila_attention_forward(
diff --git a/python/llm/src/ipex_llm/transformers/models/baichuan.py b/python/llm/src/ipex_llm/transformers/models/baichuan.py
index 0c9e8216..0fef9131 100644
--- a/python/llm/src/ipex_llm/transformers/models/baichuan.py
+++ b/python/llm/src/ipex_llm/transformers/models/baichuan.py
@@ -35,7 +35,9 @@ from ipex_llm.transformers.models.utils import init_fp8_kv_cache, append_fp8_kv_
 from ipex_llm.transformers.models.utils import rotate_half, apply_rotary_pos_emb
 from ipex_llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
 
-KV_CACHE_ALLOC_BLOCK_LENGTH = 256
+import os
+
+KV_CACHE_ALLOC_BLOCK_LENGTH = os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256)
 
 
 def baichuan_attention_forward_7b(
diff --git a/python/llm/src/ipex_llm/transformers/models/baichuan2.py b/python/llm/src/ipex_llm/transformers/models/baichuan2.py
index 38a47592..309972d2 100644
--- a/python/llm/src/ipex_llm/transformers/models/baichuan2.py
+++ b/python/llm/src/ipex_llm/transformers/models/baichuan2.py
@@ -44,8 +44,9 @@ except ImportError:
         "accelerate training use the following command to install Xformers\npip install xformers."
     )
 
+import os
 
-KV_CACHE_ALLOC_BLOCK_LENGTH = 256
+KV_CACHE_ALLOC_BLOCK_LENGTH = os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256)
 
 
 def baichuan_13b_rms_norm_forward(self, hidden_states):
diff --git a/python/llm/src/ipex_llm/transformers/models/bloom.py b/python/llm/src/ipex_llm/transformers/models/bloom.py
index 46489e8b..5c2e658a 100644
--- a/python/llm/src/ipex_llm/transformers/models/bloom.py
+++ b/python/llm/src/ipex_llm/transformers/models/bloom.py
@@ -40,8 +40,9 @@ from torch.nn import functional as F
 from ipex_llm.transformers.models.utils import use_fused_layer_norm
 from ipex_llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
 
+import os
 
-KV_CACHE_ALLOC_BLOCK_LENGTH = 256
+KV_CACHE_ALLOC_BLOCK_LENGTH = os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256)
 
 
 def dropout_add(x: torch.Tensor, residual: torch.Tensor, prob: float, training: bool):
diff --git a/python/llm/src/ipex_llm/transformers/models/chatglm.py b/python/llm/src/ipex_llm/transformers/models/chatglm.py
index ac9a98a1..0cd1cc94 100644
--- a/python/llm/src/ipex_llm/transformers/models/chatglm.py
+++ b/python/llm/src/ipex_llm/transformers/models/chatglm.py
@@ -38,7 +38,9 @@ def apply_rotary_pos_emb_index(q, k, cos, sin, position_id):
     q, k = (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
     return q, k
 
-KV_CACHE_ALLOC_BLOCK_LENGTH = 256
+import os
+
+KV_CACHE_ALLOC_BLOCK_LENGTH = os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256)
 KV_CACHE_ALLOC_MIN_LENGTH = 512
 
 
diff --git a/python/llm/src/ipex_llm/transformers/models/chatglm2.py b/python/llm/src/ipex_llm/transformers/models/chatglm2.py
index 3d69cd18..196473ae 100644
--- a/python/llm/src/ipex_llm/transformers/models/chatglm2.py
+++ b/python/llm/src/ipex_llm/transformers/models/chatglm2.py
@@ -28,7 +28,9 @@ from ipex_llm.transformers.models.utils import init_fp8_kv_cache, append_fp8_kv_
 from ipex_llm.transformers.models.utils import use_esimd_sdp
 
 
-KV_CACHE_ALLOC_BLOCK_LENGTH = 256
+import os
+
+KV_CACHE_ALLOC_BLOCK_LENGTH = os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256)
 KV_CACHE_ALLOC_MIN_LENGTH = 512
 
 
diff --git a/python/llm/src/ipex_llm/transformers/models/chatglm2_32k.py b/python/llm/src/ipex_llm/transformers/models/chatglm2_32k.py
index 94856152..38357e44 100644
--- a/python/llm/src/ipex_llm/transformers/models/chatglm2_32k.py
+++ b/python/llm/src/ipex_llm/transformers/models/chatglm2_32k.py
@@ -23,7 +23,9 @@ import torch.nn.functional as F
 from ipex_llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
 
 
-KV_CACHE_ALLOC_BLOCK_LENGTH = 256
+import os
+
+KV_CACHE_ALLOC_BLOCK_LENGTH = os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256)
 KV_CACHE_ALLOC_MIN_LENGTH = 512
 
 
diff --git a/python/llm/src/ipex_llm/transformers/models/decilm.py b/python/llm/src/ipex_llm/transformers/models/decilm.py
index 67bc5e49..771cf8b9 100644
--- a/python/llm/src/ipex_llm/transformers/models/decilm.py
+++ b/python/llm/src/ipex_llm/transformers/models/decilm.py
@@ -41,7 +41,9 @@ from ipex_llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
 from ipex_llm.transformers.models.llama import should_use_fuse_rope, repeat_kv
 from ipex_llm.utils.common import invalidInputError
 
-KV_CACHE_ALLOC_BLOCK_LENGTH = 256
+import os
+
+KV_CACHE_ALLOC_BLOCK_LENGTH = os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256)
 
 
 def decilm_attention_forward_4_35_2(
diff --git a/python/llm/src/ipex_llm/transformers/models/falcon.py b/python/llm/src/ipex_llm/transformers/models/falcon.py
index 4932aeab..14d08d09 100644
--- a/python/llm/src/ipex_llm/transformers/models/falcon.py
+++ b/python/llm/src/ipex_llm/transformers/models/falcon.py
@@ -41,8 +41,9 @@ from ipex_llm.utils.common import invalidInputError
 from ipex_llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
 import warnings
 
+import os
 
-KV_CACHE_ALLOC_BLOCK_LENGTH = 256
+KV_CACHE_ALLOC_BLOCK_LENGTH = os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256)
 
 
 # Copied from transformers.models.llama.modeling_llama.rotate_half
diff --git a/python/llm/src/ipex_llm/transformers/models/gemma.py b/python/llm/src/ipex_llm/transformers/models/gemma.py
index f6bf2db5..4eb6f5fe 100644
--- a/python/llm/src/ipex_llm/transformers/models/gemma.py
+++ b/python/llm/src/ipex_llm/transformers/models/gemma.py
@@ -43,7 +43,9 @@ from ipex_llm.transformers.models.utils import is_enough_kv_cache_room_4_36, rot
 from ipex_llm.transformers.low_bit_linear import SYM_INT4, FP8E5
 from ipex_llm.transformers.models.utils import decoding_fast_path_qtype_check
 
-KV_CACHE_ALLOC_BLOCK_LENGTH = 256
+import os
+
+KV_CACHE_ALLOC_BLOCK_LENGTH = os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256)
 
 
 def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
diff --git a/python/llm/src/ipex_llm/transformers/models/gptj.py b/python/llm/src/ipex_llm/transformers/models/gptj.py
index 38df3cb1..71bd4f7d 100644
--- a/python/llm/src/ipex_llm/transformers/models/gptj.py
+++ b/python/llm/src/ipex_llm/transformers/models/gptj.py
@@ -26,8 +26,9 @@ from transformers.modeling_outputs import BaseModelOutputWithPast
 from transformers.models.gptj.modeling_gptj import GPTJModel
 from ipex_llm.utils.common import invalidInputError
 
+import os
 
-KV_CACHE_ALLOC_BLOCK_LENGTH = 256
+KV_CACHE_ALLOC_BLOCK_LENGTH = os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256)
 
 
 def _get_embed_positions(self, position_ids):
diff --git a/python/llm/src/ipex_llm/transformers/models/gptneox.py b/python/llm/src/ipex_llm/transformers/models/gptneox.py
index 52466042..4e0129c9 100644
--- a/python/llm/src/ipex_llm/transformers/models/gptneox.py
+++ b/python/llm/src/ipex_llm/transformers/models/gptneox.py
@@ -38,8 +38,9 @@ from ipex_llm.transformers.models.utils import init_kv_cache, extend_kv_cache, \
     append_kv_cache, is_enough_kv_cache_room_4_31
 from ipex_llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
 
+import os
 
-KV_CACHE_ALLOC_BLOCK_LENGTH = 256
+KV_CACHE_ALLOC_BLOCK_LENGTH = os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256)
 
 
 def gptneox_attention_forward(
diff --git a/python/llm/src/ipex_llm/transformers/models/internlm.py b/python/llm/src/ipex_llm/transformers/models/internlm.py
index 038a63d8..fe9f708c 100644
--- a/python/llm/src/ipex_llm/transformers/models/internlm.py
+++ b/python/llm/src/ipex_llm/transformers/models/internlm.py
@@ -48,8 +48,9 @@ from ipex_llm.transformers.models.utils import init_kv_cache, extend_kv_cache, \
 from ipex_llm.transformers.models.utils import apply_rotary_pos_emb
 from ipex_llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
 
+import os
 
-KV_CACHE_ALLOC_BLOCK_LENGTH = 256
+KV_CACHE_ALLOC_BLOCK_LENGTH = os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256)
 
 
 def internlm_attention_forward(
diff --git a/python/llm/src/ipex_llm/transformers/models/mistral.py b/python/llm/src/ipex_llm/transformers/models/mistral.py
index 4ebe4886..c81cafff 100644
--- a/python/llm/src/ipex_llm/transformers/models/mistral.py
+++ b/python/llm/src/ipex_llm/transformers/models/mistral.py
@@ -60,7 +60,10 @@ try:
     from transformers.cache_utils import Cache
 except ImportError:
     Cache = Tuple[torch.Tensor]
-KV_CACHE_ALLOC_BLOCK_LENGTH = 256
+
+import os
+
+KV_CACHE_ALLOC_BLOCK_LENGTH = os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256)
 
 
 def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
diff --git a/python/llm/src/ipex_llm/transformers/models/mixtral.py b/python/llm/src/ipex_llm/transformers/models/mixtral.py
index c25e1425..9bf3af14 100644
--- a/python/llm/src/ipex_llm/transformers/models/mixtral.py
+++ b/python/llm/src/ipex_llm/transformers/models/mixtral.py
@@ -58,8 +58,9 @@ from ipex_llm.transformers.models.utils import use_flash_attention, use_esimd_sd
 from ipex_llm.transformers.models.utils import mlp_fusion_check, SILU
 from ipex_llm.transformers.low_bit_linear import IQ2_XXS
 
+import os
 
-KV_CACHE_ALLOC_BLOCK_LENGTH = 256
+KV_CACHE_ALLOC_BLOCK_LENGTH = os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256)
 
 
 def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
diff --git a/python/llm/src/ipex_llm/transformers/models/mpt.py b/python/llm/src/ipex_llm/transformers/models/mpt.py
index 4d4a191a..f6603d73 100644
--- a/python/llm/src/ipex_llm/transformers/models/mpt.py
+++ b/python/llm/src/ipex_llm/transformers/models/mpt.py
@@ -25,8 +25,9 @@ import torch.nn.functional as F
 from ipex_llm.utils.common import invalidInputError
 from ipex_llm.transformers.models.utils import extend_kv_cache, init_kv_cache, append_kv_cache
 
+import os
 
-KV_CACHE_ALLOC_BLOCK_LENGTH = 256
+KV_CACHE_ALLOC_BLOCK_LENGTH = os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256)
 
 
 def mpt_multihead_attention_forward(self, x, past_key_value=None, attn_bias=None,
diff --git a/python/llm/src/ipex_llm/transformers/models/phixtral.py b/python/llm/src/ipex_llm/transformers/models/phixtral.py
index 66595d5c..8feaabe8 100644
--- a/python/llm/src/ipex_llm/transformers/models/phixtral.py
+++ b/python/llm/src/ipex_llm/transformers/models/phixtral.py
@@ -52,8 +52,9 @@ from ipex_llm.transformers.models.mistral import should_use_fuse_rope, use_decod
 from ipex_llm.transformers.models.utils import use_flash_attention
 from ipex_llm.transformers.models.utils import mlp_fusion_check
 
+import os
 
-KV_CACHE_ALLOC_BLOCK_LENGTH = 256
+KV_CACHE_ALLOC_BLOCK_LENGTH = os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256)
 
 
 def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
diff --git a/python/llm/src/ipex_llm/transformers/models/qwen.py b/python/llm/src/ipex_llm/transformers/models/qwen.py
index 88ab88f3..271607ef 100644
--- a/python/llm/src/ipex_llm/transformers/models/qwen.py
+++ b/python/llm/src/ipex_llm/transformers/models/qwen.py
@@ -54,7 +54,9 @@ flash_attn_unpadded_func = None
 
 logger = logging.get_logger(__name__)
 
-KV_CACHE_ALLOC_BLOCK_LENGTH = 256
+import os
+
+KV_CACHE_ALLOC_BLOCK_LENGTH = os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256)
 
 SUPPORT_TORCH2 = hasattr(torch, '__version__') and int(torch.__version__.split(".")[0]) >= 2
 
diff --git a/python/llm/src/ipex_llm/transformers/models/qwen2.py b/python/llm/src/ipex_llm/transformers/models/qwen2.py
index 66f86692..2369c5a7 100644
--- a/python/llm/src/ipex_llm/transformers/models/qwen2.py
+++ b/python/llm/src/ipex_llm/transformers/models/qwen2.py
@@ -69,7 +69,9 @@ from transformers import logging
 
 logger = logging.get_logger(__name__)
 
-KV_CACHE_ALLOC_BLOCK_LENGTH = 256
+import os
+
+KV_CACHE_ALLOC_BLOCK_LENGTH = os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256)
 
 
 def should_use_fuse_rope(self, query_states, position_ids):
diff --git a/python/llm/src/ipex_llm/transformers/models/qwen_vl.py b/python/llm/src/ipex_llm/transformers/models/qwen_vl.py
index 4e33423b..cfc390b7 100644
--- a/python/llm/src/ipex_llm/transformers/models/qwen_vl.py
+++ b/python/llm/src/ipex_llm/transformers/models/qwen_vl.py
@@ -35,7 +35,9 @@ from ipex_llm.transformers.models.utils import rotate_half
 from ipex_llm.transformers.models.utils import use_esimd_sdp
 from ipex_llm.transformers.models.utils import decoding_fast_path_qtype_check
 
-KV_CACHE_ALLOC_BLOCK_LENGTH = 256
+import os
+
+KV_CACHE_ALLOC_BLOCK_LENGTH = os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256)
 
 
 def apply_rotary_pos_emb(t, freqs):
diff --git a/python/llm/src/ipex_llm/transformers/models/stablelm.py b/python/llm/src/ipex_llm/transformers/models/stablelm.py
index 53372e81..a6cd1bfb 100644
--- a/python/llm/src/ipex_llm/transformers/models/stablelm.py
+++ b/python/llm/src/ipex_llm/transformers/models/stablelm.py
@@ -60,8 +60,9 @@ try:
 except ImportError:
     Cache = Tuple[torch.Tensor]
 
+import os
 
-KV_CACHE_ALLOC_BLOCK_LENGTH = 256
+KV_CACHE_ALLOC_BLOCK_LENGTH = os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256)
 
 
 def merge_qkv(module: torch.nn.Module):
diff --git a/python/llm/src/ipex_llm/transformers/models/yuan.py b/python/llm/src/ipex_llm/transformers/models/yuan.py
index b80ab209..43f86732 100644
--- a/python/llm/src/ipex_llm/transformers/models/yuan.py
+++ b/python/llm/src/ipex_llm/transformers/models/yuan.py
@@ -38,7 +38,9 @@ from ipex_llm.transformers.models.utils import is_enough_kv_cache_room_4_31, SIL
 from ipex_llm.transformers.low_bit_linear import SYM_INT4, FP8E5
 from ipex_llm.transformers.models.utils import decoding_fast_path_qtype_check
 
-KV_CACHE_ALLOC_BLOCK_LENGTH = 256
+import os
+
+KV_CACHE_ALLOC_BLOCK_LENGTH = os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256)
 
 
 def use_decoding_fast_path(proj, use_fuse_rope, enough_kv_room, bs):
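
Reviewer note (not part of the patch): os.environ values are always strings, so the new expression evaluates to a str whenever KV_CACHE_ALLOC_BLOCK_LENGTH is set in the environment and to the int 256 only when it is not. If the constant is consumed in integer arithmetic for KV-cache sizing, an explicit conversion keeps the type stable. A minimal sketch under that assumption (the helper name and fallback behavior below are illustrative, not taken from the patch):

    import os

    def _read_kv_cache_alloc_block_length(default: int = 256) -> int:
        # Illustrative helper, not part of the patch: environment variables are
        # always strings, so normalize the override to int before use.
        value = os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", default)
        try:
            return int(value)
        except (TypeError, ValueError):
            # Ignore a non-integer override and keep the default block length.
            return default

    KV_CACHE_ALLOC_BLOCK_LENGTH = _read_kv_cache_alloc_block_length()

With this reading, an override such as KV_CACHE_ALLOC_BLOCK_LENGTH=512 yields the int 512 rather than the string "512"; whether the conversion is actually required depends on how the constant is used downstream, which this diff does not show.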