LLM: fix get env KV_CACHE_ALLOC_BLOCK_LENGTH type. (#10771)

Cengguang Zhang 2024-04-16 09:32:30 +08:00 committed by GitHub
parent 7297036c03
commit 3e2662c87e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
23 changed files with 23 additions and 23 deletions
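
Every file gets the same one-line change: os.environ.get returns a string whenever the variable is set in the environment (only the fallback 256 was an int), so arithmetic on KV_CACHE_ALLOC_BLOCK_LENGTH would break as soon as the variable was actually exported. Wrapping the lookup in int(...) normalizes both paths. A minimal standalone sketch of the difference:

import os

# Without the fix: os.environ.get returns a str when the variable is set,
# so downstream arithmetic mixes str and int.
os.environ["KV_CACHE_ALLOC_BLOCK_LENGTH"] = "512"
raw = os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256)
print(type(raw))  # <class 'str'>
# raw + 128 would raise: TypeError: can only concatenate str (not "int") to str

# With the fix: the value is an int whether it comes from the
# environment or from the default.
KV_CACHE_ALLOC_BLOCK_LENGTH = int(os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256))
print(KV_CACHE_ALLOC_BLOCK_LENGTH + 128)  # 640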

@@ -50,7 +50,7 @@ from ipex_llm.utils.common import log4Error
 import os
-KV_CACHE_ALLOC_BLOCK_LENGTH = os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256)
+KV_CACHE_ALLOC_BLOCK_LENGTH = int(os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256))
 def aquila_attention_forward(

@@ -37,7 +37,7 @@ from ipex_llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
 import os
-KV_CACHE_ALLOC_BLOCK_LENGTH = os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256)
+KV_CACHE_ALLOC_BLOCK_LENGTH = int(os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256))
 def baichuan_attention_forward_7b(

@@ -46,7 +46,7 @@ except ImportError:
 import os
-KV_CACHE_ALLOC_BLOCK_LENGTH = os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256)
+KV_CACHE_ALLOC_BLOCK_LENGTH = int(os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256))
 def baichuan_13b_rms_norm_forward(self, hidden_states):

@@ -42,7 +42,7 @@ from ipex_llm.transformers.models.utils import init_kv_cache, extend_kv_cache, a
 import os
-KV_CACHE_ALLOC_BLOCK_LENGTH = os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256)
+KV_CACHE_ALLOC_BLOCK_LENGTH = int(os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256))
 def dropout_add(x: torch.Tensor, residual: torch.Tensor, prob: float, training: bool):

@@ -40,7 +40,7 @@ def apply_rotary_pos_emb_index(q, k, cos, sin, position_id):
 import os
-KV_CACHE_ALLOC_BLOCK_LENGTH = os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256)
+KV_CACHE_ALLOC_BLOCK_LENGTH = int(os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256))
 KV_CACHE_ALLOC_MIN_LENGTH = 512

@@ -30,7 +30,7 @@ from ipex_llm.transformers.models.utils import use_esimd_sdp
 import os
-KV_CACHE_ALLOC_BLOCK_LENGTH = os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256)
+KV_CACHE_ALLOC_BLOCK_LENGTH = int(os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256))
 KV_CACHE_ALLOC_MIN_LENGTH = 512

@@ -25,7 +25,7 @@ from ipex_llm.transformers.models.utils import init_kv_cache, extend_kv_cache, a
 import os
-KV_CACHE_ALLOC_BLOCK_LENGTH = os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256)
+KV_CACHE_ALLOC_BLOCK_LENGTH = int(os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256))
 KV_CACHE_ALLOC_MIN_LENGTH = 512

@@ -43,7 +43,7 @@ from ipex_llm.utils.common import invalidInputError
 import os
-KV_CACHE_ALLOC_BLOCK_LENGTH = os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256)
+KV_CACHE_ALLOC_BLOCK_LENGTH = int(os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256))
 def decilm_attention_forward_4_35_2(

@@ -43,7 +43,7 @@ import warnings
 import os
-KV_CACHE_ALLOC_BLOCK_LENGTH = os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256)
+KV_CACHE_ALLOC_BLOCK_LENGTH = int(os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256))
 # Copied from transformers.models.llama.modeling_llama.rotate_half

@@ -45,7 +45,7 @@ from ipex_llm.transformers.models.utils import decoding_fast_path_qtype_check
 import os
-KV_CACHE_ALLOC_BLOCK_LENGTH = os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256)
+KV_CACHE_ALLOC_BLOCK_LENGTH = int(os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256))
 def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):

@@ -28,7 +28,7 @@ from ipex_llm.utils.common import invalidInputError
 import os
-KV_CACHE_ALLOC_BLOCK_LENGTH = os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256)
+KV_CACHE_ALLOC_BLOCK_LENGTH = int(os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256))
 def _get_embed_positions(self, position_ids):

@@ -40,7 +40,7 @@ from ipex_llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
 import os
-KV_CACHE_ALLOC_BLOCK_LENGTH = os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256)
+KV_CACHE_ALLOC_BLOCK_LENGTH = int(os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256))
 def gptneox_attention_forward(

@@ -50,7 +50,7 @@ from ipex_llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
 import os
-KV_CACHE_ALLOC_BLOCK_LENGTH = os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256)
+KV_CACHE_ALLOC_BLOCK_LENGTH = int(os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256))
 def internlm_attention_forward(

@@ -83,7 +83,7 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
                                          n_rep, slen, head_dim)
     return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
-KV_CACHE_ALLOC_BLOCK_LENGTH = os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256)
+KV_CACHE_ALLOC_BLOCK_LENGTH = int(os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256))
 _ipex_version = None

@@ -63,7 +63,7 @@ except ImportError:
 import os
-KV_CACHE_ALLOC_BLOCK_LENGTH = os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256)
+KV_CACHE_ALLOC_BLOCK_LENGTH = int(os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256))
 def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:

@@ -60,7 +60,7 @@ from ipex_llm.transformers.low_bit_linear import IQ2_XXS
 import os
-KV_CACHE_ALLOC_BLOCK_LENGTH = os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256)
+KV_CACHE_ALLOC_BLOCK_LENGTH = int(os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256))
 def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:

@@ -27,7 +27,7 @@ from ipex_llm.transformers.models.utils import extend_kv_cache, init_kv_cache, a
 import os
-KV_CACHE_ALLOC_BLOCK_LENGTH = os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256)
+KV_CACHE_ALLOC_BLOCK_LENGTH = int(os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256))
 def mpt_multihead_attention_forward(self, x, past_key_value=None, attn_bias=None,

@@ -54,7 +54,7 @@ from ipex_llm.transformers.models.utils import mlp_fusion_check
 import os
-KV_CACHE_ALLOC_BLOCK_LENGTH = os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256)
+KV_CACHE_ALLOC_BLOCK_LENGTH = int(os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256))
 def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:

@@ -56,7 +56,7 @@ logger = logging.get_logger(__name__)
 import os
-KV_CACHE_ALLOC_BLOCK_LENGTH = os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256)
+KV_CACHE_ALLOC_BLOCK_LENGTH = int(os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256))
 SUPPORT_TORCH2 = hasattr(torch, '__version__') and int(torch.__version__.split(".")[0]) >= 2

@@ -71,7 +71,7 @@ logger = logging.get_logger(__name__)
 import os
-KV_CACHE_ALLOC_BLOCK_LENGTH = os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256)
+KV_CACHE_ALLOC_BLOCK_LENGTH = int(os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256))
 def should_use_fuse_rope(self, query_states, position_ids):

@@ -37,7 +37,7 @@ from ipex_llm.transformers.models.utils import decoding_fast_path_qtype_check
 import os
-KV_CACHE_ALLOC_BLOCK_LENGTH = os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256)
+KV_CACHE_ALLOC_BLOCK_LENGTH = int(os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256))
 def apply_rotary_pos_emb(t, freqs):

@@ -62,7 +62,7 @@ except ImportError:
 import os
-KV_CACHE_ALLOC_BLOCK_LENGTH = os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256)
+KV_CACHE_ALLOC_BLOCK_LENGTH = int(os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256))
 def merge_qkv(module: torch.nn.Module):

@@ -40,7 +40,7 @@ from ipex_llm.transformers.models.utils import decoding_fast_path_qtype_check
 import os
-KV_CACHE_ALLOC_BLOCK_LENGTH = os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256)
+KV_CACHE_ALLOC_BLOCK_LENGTH = int(os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256))
 def use_decoding_fast_path(proj, use_fuse_rope, enough_kv_room, bs):