Read the value of KV_CACHE_ALLOC_BLOCK_LENGTH from the environment variables (#10707)

* Read the value of KV_CACHE_ALLOC_BLOCK_LENGTH from the environment variables.

* Fix style
Keyan (Kyrie) Zhang 2024-04-09 19:48:46 -07:00 committed by GitHub
parent d1eaea509f
commit 585c174e92
22 changed files with 57 additions and 22 deletions
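
Every one of the 22 files changes in the same way: the hardcoded block length (roughly, the extra sequence length pre-allocated for the KV cache so it does not have to grow on every generated token) is replaced by an environment lookup that falls back to the old value of 256. A minimal sketch of the pattern follows; the int() variant at the end is an assumption, not part of this commit, noted because os.environ.get returns a str whenever the variable is actually set.

import os

# Pattern applied across the model files in this commit: allow overriding the
# KV cache allocation block length at runtime, keeping 256 as the default.
KV_CACHE_ALLOC_BLOCK_LENGTH = os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256)

# Hypothetical stricter variant (not in this commit): coerce to int so a value
# set in the environment behaves like the numeric default in later arithmetic.
KV_CACHE_ALLOC_BLOCK_LENGTH = int(os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256))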

@@ -48,7 +48,9 @@ from ipex_llm.transformers.models.utils import apply_rotary_pos_emb
 from ipex_llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
 from ipex_llm.utils.common import log4Error
-KV_CACHE_ALLOC_BLOCK_LENGTH = 256
+import os
+KV_CACHE_ALLOC_BLOCK_LENGTH = os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256)
 def aquila_attention_forward(

@@ -35,7 +35,9 @@ from ipex_llm.transformers.models.utils import init_fp8_kv_cache, append_fp8_kv_
 from ipex_llm.transformers.models.utils import rotate_half, apply_rotary_pos_emb
 from ipex_llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
-KV_CACHE_ALLOC_BLOCK_LENGTH = 256
+import os
+KV_CACHE_ALLOC_BLOCK_LENGTH = os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256)
 def baichuan_attention_forward_7b(

@@ -44,8 +44,9 @@ except ImportError:
 "accelerate training use the following command to install Xformers\npip install xformers."
 )
 import os
-KV_CACHE_ALLOC_BLOCK_LENGTH = 256
+KV_CACHE_ALLOC_BLOCK_LENGTH = os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256)
 def baichuan_13b_rms_norm_forward(self, hidden_states):

@@ -40,8 +40,9 @@ from torch.nn import functional as F
 from ipex_llm.transformers.models.utils import use_fused_layer_norm
 from ipex_llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
 import os
-KV_CACHE_ALLOC_BLOCK_LENGTH = 256
+KV_CACHE_ALLOC_BLOCK_LENGTH = os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256)
 def dropout_add(x: torch.Tensor, residual: torch.Tensor, prob: float, training: bool):

@@ -38,7 +38,9 @@ def apply_rotary_pos_emb_index(q, k, cos, sin, position_id):
     q, k = (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
     return q, k
-KV_CACHE_ALLOC_BLOCK_LENGTH = 256
+import os
+KV_CACHE_ALLOC_BLOCK_LENGTH = os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256)
 KV_CACHE_ALLOC_MIN_LENGTH = 512

@@ -28,7 +28,9 @@ from ipex_llm.transformers.models.utils import init_fp8_kv_cache, append_fp8_kv_
 from ipex_llm.transformers.models.utils import use_esimd_sdp
-KV_CACHE_ALLOC_BLOCK_LENGTH = 256
+import os
+KV_CACHE_ALLOC_BLOCK_LENGTH = os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256)
 KV_CACHE_ALLOC_MIN_LENGTH = 512

@@ -23,7 +23,9 @@ import torch.nn.functional as F
 from ipex_llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
-KV_CACHE_ALLOC_BLOCK_LENGTH = 256
+import os
+KV_CACHE_ALLOC_BLOCK_LENGTH = os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256)
 KV_CACHE_ALLOC_MIN_LENGTH = 512

@@ -41,7 +41,9 @@ from ipex_llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
 from ipex_llm.transformers.models.llama import should_use_fuse_rope, repeat_kv
 from ipex_llm.utils.common import invalidInputError
-KV_CACHE_ALLOC_BLOCK_LENGTH = 256
+import os
+KV_CACHE_ALLOC_BLOCK_LENGTH = os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256)
 def decilm_attention_forward_4_35_2(

@@ -41,8 +41,9 @@ from ipex_llm.utils.common import invalidInputError
 from ipex_llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
 import warnings
 import os
-KV_CACHE_ALLOC_BLOCK_LENGTH = 256
+KV_CACHE_ALLOC_BLOCK_LENGTH = os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256)
 # Copied from transformers.models.llama.modeling_llama.rotate_half

@@ -43,7 +43,9 @@ from ipex_llm.transformers.models.utils import is_enough_kv_cache_room_4_36, rot
 from ipex_llm.transformers.low_bit_linear import SYM_INT4, FP8E5
 from ipex_llm.transformers.models.utils import decoding_fast_path_qtype_check
-KV_CACHE_ALLOC_BLOCK_LENGTH = 256
+import os
+KV_CACHE_ALLOC_BLOCK_LENGTH = os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256)
 def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):

@@ -26,8 +26,9 @@ from transformers.modeling_outputs import BaseModelOutputWithPast
 from transformers.models.gptj.modeling_gptj import GPTJModel
 from ipex_llm.utils.common import invalidInputError
 import os
-KV_CACHE_ALLOC_BLOCK_LENGTH = 256
+KV_CACHE_ALLOC_BLOCK_LENGTH = os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256)
 def _get_embed_positions(self, position_ids):

@@ -38,8 +38,9 @@ from ipex_llm.transformers.models.utils import init_kv_cache, extend_kv_cache, \
 append_kv_cache, is_enough_kv_cache_room_4_31
 from ipex_llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
 import os
-KV_CACHE_ALLOC_BLOCK_LENGTH = 256
+KV_CACHE_ALLOC_BLOCK_LENGTH = os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256)
 def gptneox_attention_forward(

@@ -48,8 +48,9 @@ from ipex_llm.transformers.models.utils import init_kv_cache, extend_kv_cache, \
 from ipex_llm.transformers.models.utils import apply_rotary_pos_emb
 from ipex_llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
 import os
-KV_CACHE_ALLOC_BLOCK_LENGTH = 256
+KV_CACHE_ALLOC_BLOCK_LENGTH = os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256)
 def internlm_attention_forward(

@@ -60,7 +60,10 @@ try:
     from transformers.cache_utils import Cache
 except ImportError:
     Cache = Tuple[torch.Tensor]
-KV_CACHE_ALLOC_BLOCK_LENGTH = 256
+import os
+KV_CACHE_ALLOC_BLOCK_LENGTH = os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256)
 def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:

@@ -58,8 +58,9 @@ from ipex_llm.transformers.models.utils import use_flash_attention, use_esimd_sd
 from ipex_llm.transformers.models.utils import mlp_fusion_check, SILU
 from ipex_llm.transformers.low_bit_linear import IQ2_XXS
 import os
-KV_CACHE_ALLOC_BLOCK_LENGTH = 256
+KV_CACHE_ALLOC_BLOCK_LENGTH = os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256)
 def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:

@@ -25,8 +25,9 @@ import torch.nn.functional as F
 from ipex_llm.utils.common import invalidInputError
 from ipex_llm.transformers.models.utils import extend_kv_cache, init_kv_cache, append_kv_cache
 import os
-KV_CACHE_ALLOC_BLOCK_LENGTH = 256
+KV_CACHE_ALLOC_BLOCK_LENGTH = os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256)
 def mpt_multihead_attention_forward(self, x, past_key_value=None, attn_bias=None,

@@ -52,8 +52,9 @@ from ipex_llm.transformers.models.mistral import should_use_fuse_rope, use_decod
 from ipex_llm.transformers.models.utils import use_flash_attention
 from ipex_llm.transformers.models.utils import mlp_fusion_check
 import os
-KV_CACHE_ALLOC_BLOCK_LENGTH = 256
+KV_CACHE_ALLOC_BLOCK_LENGTH = os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256)
 def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:

@@ -54,7 +54,9 @@ flash_attn_unpadded_func = None
 logger = logging.get_logger(__name__)
-KV_CACHE_ALLOC_BLOCK_LENGTH = 256
+import os
+KV_CACHE_ALLOC_BLOCK_LENGTH = os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256)
 SUPPORT_TORCH2 = hasattr(torch, '__version__') and int(torch.__version__.split(".")[0]) >= 2

@@ -69,7 +69,9 @@ from transformers import logging
 logger = logging.get_logger(__name__)
-KV_CACHE_ALLOC_BLOCK_LENGTH = 256
+import os
+KV_CACHE_ALLOC_BLOCK_LENGTH = os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256)
 def should_use_fuse_rope(self, query_states, position_ids):

@@ -35,7 +35,9 @@ from ipex_llm.transformers.models.utils import rotate_half
 from ipex_llm.transformers.models.utils import use_esimd_sdp
 from ipex_llm.transformers.models.utils import decoding_fast_path_qtype_check
-KV_CACHE_ALLOC_BLOCK_LENGTH = 256
+import os
+KV_CACHE_ALLOC_BLOCK_LENGTH = os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256)
 def apply_rotary_pos_emb(t, freqs):

@@ -60,8 +60,9 @@ try:
 except ImportError:
     Cache = Tuple[torch.Tensor]
 import os
-KV_CACHE_ALLOC_BLOCK_LENGTH = 256
+KV_CACHE_ALLOC_BLOCK_LENGTH = os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256)
 def merge_qkv(module: torch.nn.Module):

@@ -38,7 +38,9 @@ from ipex_llm.transformers.models.utils import is_enough_kv_cache_room_4_31, SIL
 from ipex_llm.transformers.low_bit_linear import SYM_INT4, FP8E5
 from ipex_llm.transformers.models.utils import decoding_fast_path_qtype_check
-KV_CACHE_ALLOC_BLOCK_LENGTH = 256
+import os
+KV_CACHE_ALLOC_BLOCK_LENGTH = os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256)
 def use_decoding_fast_path(proj, use_fuse_rope, enough_kv_room, bs):
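
With this change the block length becomes a runtime knob rather than a code edit. A small usage sketch, assuming the override is applied before the ipex_llm model modules are imported (each file reads the variable once at import time); the value 512 is only an example.

import os

# Set before importing ipex_llm.transformers.models.*, since the constant is
# read once at module import time.
# Shell equivalent: KV_CACHE_ALLOC_BLOCK_LENGTH=512 python your_script.py
os.environ["KV_CACHE_ALLOC_BLOCK_LENGTH"] = "512"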