Read the value of KV_CACHE_ALLOC_BLOCK_LENGTH from the environment variables (#10707)
* Read the value of KV_CACHE_ALLOC_BLOCK_LENGTH from the environment variables. * Fix style
This commit is contained in:
		
							parent
							
								
									d1eaea509f
								
							
						
					
					
						commit
						585c174e92
					
				
					 22 changed files with 57 additions and 22 deletions
				
			
		| 
						 | 
				
			
			@ -48,7 +48,9 @@ from ipex_llm.transformers.models.utils import apply_rotary_pos_emb
 | 
			
		|||
from ipex_llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
 | 
			
		||||
from ipex_llm.utils.common import log4Error
 | 
			
		||||
 | 
			
		||||
KV_CACHE_ALLOC_BLOCK_LENGTH = 256
 | 
			
		||||
import os
 | 
			
		||||
 | 
			
		||||
KV_CACHE_ALLOC_BLOCK_LENGTH = os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def aquila_attention_forward(
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -35,7 +35,9 @@ from ipex_llm.transformers.models.utils import init_fp8_kv_cache, append_fp8_kv_
 | 
			
		|||
from ipex_llm.transformers.models.utils import rotate_half, apply_rotary_pos_emb
 | 
			
		||||
from ipex_llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
 | 
			
		||||
 | 
			
		||||
KV_CACHE_ALLOC_BLOCK_LENGTH = 256
 | 
			
		||||
import os
 | 
			
		||||
 | 
			
		||||
KV_CACHE_ALLOC_BLOCK_LENGTH = os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def baichuan_attention_forward_7b(
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -44,8 +44,9 @@ except ImportError:
 | 
			
		|||
        "accelerate training use the following command to install Xformers\npip install xformers."
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
import os
 | 
			
		||||
 | 
			
		||||
KV_CACHE_ALLOC_BLOCK_LENGTH = 256
 | 
			
		||||
KV_CACHE_ALLOC_BLOCK_LENGTH = os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def baichuan_13b_rms_norm_forward(self, hidden_states):
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -40,8 +40,9 @@ from torch.nn import functional as F
 | 
			
		|||
from ipex_llm.transformers.models.utils import use_fused_layer_norm
 | 
			
		||||
from ipex_llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
 | 
			
		||||
 | 
			
		||||
import os
 | 
			
		||||
 | 
			
		||||
KV_CACHE_ALLOC_BLOCK_LENGTH = 256
 | 
			
		||||
KV_CACHE_ALLOC_BLOCK_LENGTH = os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def dropout_add(x: torch.Tensor, residual: torch.Tensor, prob: float, training: bool):
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -38,7 +38,9 @@ def apply_rotary_pos_emb_index(q, k, cos, sin, position_id):
 | 
			
		|||
    q, k = (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
 | 
			
		||||
    return q, k
 | 
			
		||||
 | 
			
		||||
KV_CACHE_ALLOC_BLOCK_LENGTH = 256
 | 
			
		||||
import os
 | 
			
		||||
 | 
			
		||||
KV_CACHE_ALLOC_BLOCK_LENGTH = os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256)
 | 
			
		||||
KV_CACHE_ALLOC_MIN_LENGTH = 512
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -28,7 +28,9 @@ from ipex_llm.transformers.models.utils import init_fp8_kv_cache, append_fp8_kv_
 | 
			
		|||
from ipex_llm.transformers.models.utils import use_esimd_sdp
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
KV_CACHE_ALLOC_BLOCK_LENGTH = 256
 | 
			
		||||
import os
 | 
			
		||||
 | 
			
		||||
KV_CACHE_ALLOC_BLOCK_LENGTH = os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256)
 | 
			
		||||
KV_CACHE_ALLOC_MIN_LENGTH = 512
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -23,7 +23,9 @@ import torch.nn.functional as F
 | 
			
		|||
from ipex_llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
KV_CACHE_ALLOC_BLOCK_LENGTH = 256
 | 
			
		||||
import os
 | 
			
		||||
 | 
			
		||||
KV_CACHE_ALLOC_BLOCK_LENGTH = os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256)
 | 
			
		||||
KV_CACHE_ALLOC_MIN_LENGTH = 512
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -41,7 +41,9 @@ from ipex_llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
 | 
			
		|||
from ipex_llm.transformers.models.llama import should_use_fuse_rope, repeat_kv
 | 
			
		||||
from ipex_llm.utils.common import invalidInputError
 | 
			
		||||
 | 
			
		||||
KV_CACHE_ALLOC_BLOCK_LENGTH = 256
 | 
			
		||||
import os
 | 
			
		||||
 | 
			
		||||
KV_CACHE_ALLOC_BLOCK_LENGTH = os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def decilm_attention_forward_4_35_2(
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -41,8 +41,9 @@ from ipex_llm.utils.common import invalidInputError
 | 
			
		|||
from ipex_llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
 | 
			
		||||
import warnings
 | 
			
		||||
 | 
			
		||||
import os
 | 
			
		||||
 | 
			
		||||
KV_CACHE_ALLOC_BLOCK_LENGTH = 256
 | 
			
		||||
KV_CACHE_ALLOC_BLOCK_LENGTH = os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# Copied from transformers.models.llama.modeling_llama.rotate_half
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -43,7 +43,9 @@ from ipex_llm.transformers.models.utils import is_enough_kv_cache_room_4_36, rot
 | 
			
		|||
from ipex_llm.transformers.low_bit_linear import SYM_INT4, FP8E5
 | 
			
		||||
from ipex_llm.transformers.models.utils import decoding_fast_path_qtype_check
 | 
			
		||||
 | 
			
		||||
KV_CACHE_ALLOC_BLOCK_LENGTH = 256
 | 
			
		||||
import os
 | 
			
		||||
 | 
			
		||||
KV_CACHE_ALLOC_BLOCK_LENGTH = os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -26,8 +26,9 @@ from transformers.modeling_outputs import BaseModelOutputWithPast
 | 
			
		|||
from transformers.models.gptj.modeling_gptj import GPTJModel
 | 
			
		||||
from ipex_llm.utils.common import invalidInputError
 | 
			
		||||
 | 
			
		||||
import os
 | 
			
		||||
 | 
			
		||||
KV_CACHE_ALLOC_BLOCK_LENGTH = 256
 | 
			
		||||
KV_CACHE_ALLOC_BLOCK_LENGTH = os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _get_embed_positions(self, position_ids):
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -38,8 +38,9 @@ from ipex_llm.transformers.models.utils import init_kv_cache, extend_kv_cache, \
 | 
			
		|||
    append_kv_cache, is_enough_kv_cache_room_4_31
 | 
			
		||||
from ipex_llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
 | 
			
		||||
 | 
			
		||||
import os
 | 
			
		||||
 | 
			
		||||
KV_CACHE_ALLOC_BLOCK_LENGTH = 256
 | 
			
		||||
KV_CACHE_ALLOC_BLOCK_LENGTH = os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def gptneox_attention_forward(
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -48,8 +48,9 @@ from ipex_llm.transformers.models.utils import init_kv_cache, extend_kv_cache, \
 | 
			
		|||
from ipex_llm.transformers.models.utils import apply_rotary_pos_emb
 | 
			
		||||
from ipex_llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
 | 
			
		||||
 | 
			
		||||
import os
 | 
			
		||||
 | 
			
		||||
KV_CACHE_ALLOC_BLOCK_LENGTH = 256
 | 
			
		||||
KV_CACHE_ALLOC_BLOCK_LENGTH = os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def internlm_attention_forward(
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -60,7 +60,10 @@ try:
 | 
			
		|||
    from transformers.cache_utils import Cache
 | 
			
		||||
except ImportError:
 | 
			
		||||
    Cache = Tuple[torch.Tensor]
 | 
			
		||||
KV_CACHE_ALLOC_BLOCK_LENGTH = 256
 | 
			
		||||
 | 
			
		||||
import os
 | 
			
		||||
 | 
			
		||||
KV_CACHE_ALLOC_BLOCK_LENGTH = os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -58,8 +58,9 @@ from ipex_llm.transformers.models.utils import use_flash_attention, use_esimd_sd
 | 
			
		|||
from ipex_llm.transformers.models.utils import mlp_fusion_check, SILU
 | 
			
		||||
from ipex_llm.transformers.low_bit_linear import IQ2_XXS
 | 
			
		||||
 | 
			
		||||
import os
 | 
			
		||||
 | 
			
		||||
KV_CACHE_ALLOC_BLOCK_LENGTH = 256
 | 
			
		||||
KV_CACHE_ALLOC_BLOCK_LENGTH = os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -25,8 +25,9 @@ import torch.nn.functional as F
 | 
			
		|||
from ipex_llm.utils.common import invalidInputError
 | 
			
		||||
from ipex_llm.transformers.models.utils import extend_kv_cache, init_kv_cache, append_kv_cache
 | 
			
		||||
 | 
			
		||||
import os
 | 
			
		||||
 | 
			
		||||
KV_CACHE_ALLOC_BLOCK_LENGTH = 256
 | 
			
		||||
KV_CACHE_ALLOC_BLOCK_LENGTH = os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def mpt_multihead_attention_forward(self, x, past_key_value=None, attn_bias=None,
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -52,8 +52,9 @@ from ipex_llm.transformers.models.mistral import should_use_fuse_rope, use_decod
 | 
			
		|||
from ipex_llm.transformers.models.utils import use_flash_attention
 | 
			
		||||
from ipex_llm.transformers.models.utils import mlp_fusion_check
 | 
			
		||||
 | 
			
		||||
import os
 | 
			
		||||
 | 
			
		||||
KV_CACHE_ALLOC_BLOCK_LENGTH = 256
 | 
			
		||||
KV_CACHE_ALLOC_BLOCK_LENGTH = os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -54,7 +54,9 @@ flash_attn_unpadded_func = None
 | 
			
		|||
 | 
			
		||||
logger = logging.get_logger(__name__)
 | 
			
		||||
 | 
			
		||||
KV_CACHE_ALLOC_BLOCK_LENGTH = 256
 | 
			
		||||
import os
 | 
			
		||||
 | 
			
		||||
KV_CACHE_ALLOC_BLOCK_LENGTH = os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256)
 | 
			
		||||
SUPPORT_TORCH2 = hasattr(torch, '__version__') and int(torch.__version__.split(".")[0]) >= 2
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -69,7 +69,9 @@ from transformers import logging
 | 
			
		|||
 | 
			
		||||
logger = logging.get_logger(__name__)
 | 
			
		||||
 | 
			
		||||
KV_CACHE_ALLOC_BLOCK_LENGTH = 256
 | 
			
		||||
import os
 | 
			
		||||
 | 
			
		||||
KV_CACHE_ALLOC_BLOCK_LENGTH = os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def should_use_fuse_rope(self, query_states, position_ids):
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -35,7 +35,9 @@ from ipex_llm.transformers.models.utils import rotate_half
 | 
			
		|||
from ipex_llm.transformers.models.utils import use_esimd_sdp
 | 
			
		||||
from ipex_llm.transformers.models.utils import decoding_fast_path_qtype_check
 | 
			
		||||
 | 
			
		||||
KV_CACHE_ALLOC_BLOCK_LENGTH = 256
 | 
			
		||||
import os
 | 
			
		||||
 | 
			
		||||
KV_CACHE_ALLOC_BLOCK_LENGTH = os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def apply_rotary_pos_emb(t, freqs):
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -60,8 +60,9 @@ try:
 | 
			
		|||
except ImportError:
 | 
			
		||||
    Cache = Tuple[torch.Tensor]
 | 
			
		||||
 | 
			
		||||
import os
 | 
			
		||||
 | 
			
		||||
KV_CACHE_ALLOC_BLOCK_LENGTH = 256
 | 
			
		||||
KV_CACHE_ALLOC_BLOCK_LENGTH = os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def merge_qkv(module: torch.nn.Module):
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -38,7 +38,9 @@ from ipex_llm.transformers.models.utils import is_enough_kv_cache_room_4_31, SIL
 | 
			
		|||
from ipex_llm.transformers.low_bit_linear import SYM_INT4, FP8E5
 | 
			
		||||
from ipex_llm.transformers.models.utils import decoding_fast_path_qtype_check
 | 
			
		||||
 | 
			
		||||
KV_CACHE_ALLOC_BLOCK_LENGTH = 256
 | 
			
		||||
import os
 | 
			
		||||
 | 
			
		||||
KV_CACHE_ALLOC_BLOCK_LENGTH = os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def use_decoding_fast_path(proj, use_fuse_rope, enough_kv_room, bs):
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in a new issue