Read the value of KV_CACHE_ALLOC_BLOCK_LENGTH from the environment variables (#10707)
* Read the value of KV_CACHE_ALLOC_BLOCK_LENGTH from the environment variables.
* Fix style

parent d1eaea509f
commit 585c174e92

22 changed files with 57 additions and 22 deletions
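Every hunk below applies the same pattern: the hard-coded module constant is replaced by a lookup in the process environment, with 256 kept as the default. A minimal, self-contained sketch of the resulting behaviour (the int() conversion and the example script name are illustrative additions, not part of the diff; note that os.environ.get returns a string whenever the variable is set):

    import os

    # Override at launch time, e.g.  KV_CACHE_ALLOC_BLOCK_LENGTH=512 python generate.py
    # When the variable is unset, the previous hard-coded default of 256 is used.
    KV_CACHE_ALLOC_BLOCK_LENGTH = int(os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256))

    print(f"KV cache block length: {KV_CACHE_ALLOC_BLOCK_LENGTH}")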
@@ -48,7 +48,9 @@ from ipex_llm.transformers.models.utils import apply_rotary_pos_emb
 from ipex_llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
 from ipex_llm.utils.common import log4Error

-KV_CACHE_ALLOC_BLOCK_LENGTH = 256
+import os
+
+KV_CACHE_ALLOC_BLOCK_LENGTH = os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256)


 def aquila_attention_forward(

@@ -35,7 +35,9 @@ from ipex_llm.transformers.models.utils import init_fp8_kv_cache, append_fp8_kv_
 from ipex_llm.transformers.models.utils import rotate_half, apply_rotary_pos_emb
 from ipex_llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu

-KV_CACHE_ALLOC_BLOCK_LENGTH = 256
+import os
+
+KV_CACHE_ALLOC_BLOCK_LENGTH = os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256)


 def baichuan_attention_forward_7b(

@@ -44,8 +44,9 @@ except ImportError:
         "accelerate training use the following command to install Xformers\npip install xformers."
     )

+import os

-KV_CACHE_ALLOC_BLOCK_LENGTH = 256
+KV_CACHE_ALLOC_BLOCK_LENGTH = os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256)


 def baichuan_13b_rms_norm_forward(self, hidden_states):

@@ -40,8 +40,9 @@ from torch.nn import functional as F
 from ipex_llm.transformers.models.utils import use_fused_layer_norm
 from ipex_llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache

+import os

-KV_CACHE_ALLOC_BLOCK_LENGTH = 256
+KV_CACHE_ALLOC_BLOCK_LENGTH = os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256)


 def dropout_add(x: torch.Tensor, residual: torch.Tensor, prob: float, training: bool):

@@ -38,7 +38,9 @@ def apply_rotary_pos_emb_index(q, k, cos, sin, position_id):
     q, k = (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
     return q, k

-KV_CACHE_ALLOC_BLOCK_LENGTH = 256
+import os
+
+KV_CACHE_ALLOC_BLOCK_LENGTH = os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256)
 KV_CACHE_ALLOC_MIN_LENGTH = 512


@@ -28,7 +28,9 @@ from ipex_llm.transformers.models.utils import init_fp8_kv_cache, append_fp8_kv_
 from ipex_llm.transformers.models.utils import use_esimd_sdp


-KV_CACHE_ALLOC_BLOCK_LENGTH = 256
+import os
+
+KV_CACHE_ALLOC_BLOCK_LENGTH = os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256)
 KV_CACHE_ALLOC_MIN_LENGTH = 512


@@ -23,7 +23,9 @@ import torch.nn.functional as F
 from ipex_llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache


-KV_CACHE_ALLOC_BLOCK_LENGTH = 256
+import os
+
+KV_CACHE_ALLOC_BLOCK_LENGTH = os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256)
 KV_CACHE_ALLOC_MIN_LENGTH = 512


@@ -41,7 +41,9 @@ from ipex_llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
 from ipex_llm.transformers.models.llama import should_use_fuse_rope, repeat_kv
 from ipex_llm.utils.common import invalidInputError

-KV_CACHE_ALLOC_BLOCK_LENGTH = 256
+import os
+
+KV_CACHE_ALLOC_BLOCK_LENGTH = os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256)


 def decilm_attention_forward_4_35_2(

@@ -41,8 +41,9 @@ from ipex_llm.utils.common import invalidInputError
 from ipex_llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
 import warnings

+import os

-KV_CACHE_ALLOC_BLOCK_LENGTH = 256
+KV_CACHE_ALLOC_BLOCK_LENGTH = os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256)


 # Copied from transformers.models.llama.modeling_llama.rotate_half

@@ -43,7 +43,9 @@ from ipex_llm.transformers.models.utils import is_enough_kv_cache_room_4_36, rot
 from ipex_llm.transformers.low_bit_linear import SYM_INT4, FP8E5
 from ipex_llm.transformers.models.utils import decoding_fast_path_qtype_check

-KV_CACHE_ALLOC_BLOCK_LENGTH = 256
+import os
+
+KV_CACHE_ALLOC_BLOCK_LENGTH = os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256)


 def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):

@@ -26,8 +26,9 @@ from transformers.modeling_outputs import BaseModelOutputWithPast
 from transformers.models.gptj.modeling_gptj import GPTJModel
 from ipex_llm.utils.common import invalidInputError

+import os

-KV_CACHE_ALLOC_BLOCK_LENGTH = 256
+KV_CACHE_ALLOC_BLOCK_LENGTH = os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256)


 def _get_embed_positions(self, position_ids):

@@ -38,8 +38,9 @@ from ipex_llm.transformers.models.utils import init_kv_cache, extend_kv_cache, \
     append_kv_cache, is_enough_kv_cache_room_4_31
 from ipex_llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu

+import os

-KV_CACHE_ALLOC_BLOCK_LENGTH = 256
+KV_CACHE_ALLOC_BLOCK_LENGTH = os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256)


 def gptneox_attention_forward(

@@ -48,8 +48,9 @@ from ipex_llm.transformers.models.utils import init_kv_cache, extend_kv_cache, \
 from ipex_llm.transformers.models.utils import apply_rotary_pos_emb
 from ipex_llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu

+import os

-KV_CACHE_ALLOC_BLOCK_LENGTH = 256
+KV_CACHE_ALLOC_BLOCK_LENGTH = os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256)


 def internlm_attention_forward(

@@ -60,7 +60,10 @@ try:
     from transformers.cache_utils import Cache
 except ImportError:
     Cache = Tuple[torch.Tensor]
-KV_CACHE_ALLOC_BLOCK_LENGTH = 256
+
+import os
+
+KV_CACHE_ALLOC_BLOCK_LENGTH = os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256)


 def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:

@@ -58,8 +58,9 @@ from ipex_llm.transformers.models.utils import use_flash_attention, use_esimd_sd
 from ipex_llm.transformers.models.utils import mlp_fusion_check, SILU
 from ipex_llm.transformers.low_bit_linear import IQ2_XXS

+import os

-KV_CACHE_ALLOC_BLOCK_LENGTH = 256
+KV_CACHE_ALLOC_BLOCK_LENGTH = os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256)


 def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:

@@ -25,8 +25,9 @@ import torch.nn.functional as F
 from ipex_llm.utils.common import invalidInputError
 from ipex_llm.transformers.models.utils import extend_kv_cache, init_kv_cache, append_kv_cache

+import os

-KV_CACHE_ALLOC_BLOCK_LENGTH = 256
+KV_CACHE_ALLOC_BLOCK_LENGTH = os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256)


 def mpt_multihead_attention_forward(self, x, past_key_value=None, attn_bias=None,

@@ -52,8 +52,9 @@ from ipex_llm.transformers.models.mistral import should_use_fuse_rope, use_decod
 from ipex_llm.transformers.models.utils import use_flash_attention
 from ipex_llm.transformers.models.utils import mlp_fusion_check

+import os

-KV_CACHE_ALLOC_BLOCK_LENGTH = 256
+KV_CACHE_ALLOC_BLOCK_LENGTH = os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256)


 def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:

@@ -54,7 +54,9 @@ flash_attn_unpadded_func = None

 logger = logging.get_logger(__name__)

-KV_CACHE_ALLOC_BLOCK_LENGTH = 256
+import os
+
+KV_CACHE_ALLOC_BLOCK_LENGTH = os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256)
 SUPPORT_TORCH2 = hasattr(torch, '__version__') and int(torch.__version__.split(".")[0]) >= 2


@@ -69,7 +69,9 @@ from transformers import logging

 logger = logging.get_logger(__name__)

-KV_CACHE_ALLOC_BLOCK_LENGTH = 256
+import os
+
+KV_CACHE_ALLOC_BLOCK_LENGTH = os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256)


 def should_use_fuse_rope(self, query_states, position_ids):

@@ -35,7 +35,9 @@ from ipex_llm.transformers.models.utils import rotate_half
 from ipex_llm.transformers.models.utils import use_esimd_sdp
 from ipex_llm.transformers.models.utils import decoding_fast_path_qtype_check

-KV_CACHE_ALLOC_BLOCK_LENGTH = 256
+import os
+
+KV_CACHE_ALLOC_BLOCK_LENGTH = os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256)


 def apply_rotary_pos_emb(t, freqs):

@@ -60,8 +60,9 @@ try:
 except ImportError:
     Cache = Tuple[torch.Tensor]

+import os

-KV_CACHE_ALLOC_BLOCK_LENGTH = 256
+KV_CACHE_ALLOC_BLOCK_LENGTH = os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256)


 def merge_qkv(module: torch.nn.Module):

@@ -38,7 +38,9 @@ from ipex_llm.transformers.models.utils import is_enough_kv_cache_room_4_31, SIL
 from ipex_llm.transformers.low_bit_linear import SYM_INT4, FP8E5
 from ipex_llm.transformers.models.utils import decoding_fast_path_qtype_check

-KV_CACHE_ALLOC_BLOCK_LENGTH = 256
+import os
+
+KV_CACHE_ALLOC_BLOCK_LENGTH = os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256)


 def use_decoding_fast_path(proj, use_fuse_rope, enough_kv_room, bs):

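For context on why this knob matters: constants named like KV_CACHE_ALLOC_BLOCK_LENGTH are typically used to grow the KV cache in coarse blocks instead of per token, so a larger value trades extra memory headroom for fewer re-allocations. The sketch below only illustrates that idea under that assumption; the real allocation logic lives in helpers such as init_kv_cache / extend_kv_cache in ipex_llm.transformers.models.utils and may differ in detail.

    import math
    import os

    KV_CACHE_ALLOC_BLOCK_LENGTH = int(os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256))

    def rounded_cache_len(needed_tokens: int,
                          block_len: int = KV_CACHE_ALLOC_BLOCK_LENGTH) -> int:
        # Reserve whole blocks so the cache is re-allocated roughly once per
        # block_len generated tokens (illustrative helper, not taken from the diff).
        return math.ceil(needed_tokens / block_len) * block_len

    # With the default block length of 256, a 300-token sequence reserves 512 slots.
    assert rounded_cache_len(300, 256) == 512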