remove some useless code (#12035)
This commit is contained in:
parent
d2e1b9aaff
commit
6cedb601e4
3 changed files with 7 additions and 20 deletions
|
|
@ -47,8 +47,7 @@ from ipex_llm.transformers.models.utils import init_fp8_kv_cache, append_fp8_kv_
|
||||||
from ipex_llm.transformers.models.utils import is_enough_kv_cache_room_4_31, \
|
from ipex_llm.transformers.models.utils import is_enough_kv_cache_room_4_31, \
|
||||||
apply_rotary_pos_emb, is_enough_kv_cache_room_4_36
|
apply_rotary_pos_emb, is_enough_kv_cache_room_4_36
|
||||||
from ipex_llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
|
from ipex_llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
|
||||||
from ipex_llm.transformers.models.utils import use_flash_attention, use_sdp, use_sdp_fp8, \
|
from ipex_llm.transformers.models.utils import use_flash_attention, use_sdp, use_sdp_causal
|
||||||
use_sdp_causal
|
|
||||||
from ipex_llm.transformers.models.utils import mlp_fusion_check, fp16_fusion_check
|
from ipex_llm.transformers.models.utils import mlp_fusion_check, fp16_fusion_check
|
||||||
from ipex_llm.transformers.models.utils import use_decoding_fast_path, get_q_proj_or_qkv_proj
|
from ipex_llm.transformers.models.utils import use_decoding_fast_path, get_q_proj_or_qkv_proj
|
||||||
from transformers.modeling_outputs import BaseModelOutputWithPast
|
from transformers.modeling_outputs import BaseModelOutputWithPast
|
||||||
|
|
@ -599,7 +598,7 @@ def llama_attention_forward_4_31_quantized(
|
||||||
kv_seq_len = key_states.shape[-2]
|
kv_seq_len = key_states.shape[-2]
|
||||||
past_key_value = (key_states, value_states)
|
past_key_value = (key_states, value_states)
|
||||||
|
|
||||||
if not use_sdp_fp8(q_len, key_states.shape[2], query_states):
|
if not use_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
|
||||||
key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
|
key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
|
||||||
query_states.dtype)
|
query_states.dtype)
|
||||||
# repeat k/v heads if n_kv_heads < n_heads
|
# repeat k/v heads if n_kv_heads < n_heads
|
||||||
|
|
@ -1282,7 +1281,7 @@ def llama_attention_forward_4_41_quantized(
|
||||||
key_states, value_states = past_key_value.update(key_states, value_states,
|
key_states, value_states = past_key_value.update(key_states, value_states,
|
||||||
self.layer_idx, cache_kwargs)
|
self.layer_idx, cache_kwargs)
|
||||||
kv_seq_len = key_states.shape[-2]
|
kv_seq_len = key_states.shape[-2]
|
||||||
if not use_sdp_fp8(q_len, key_states.shape[2], query_states):
|
if not use_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
|
||||||
key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
|
key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
|
||||||
query_states.dtype)
|
query_states.dtype)
|
||||||
key_states = repeat_kv(key_states, self.num_key_value_groups)\
|
key_states = repeat_kv(key_states, self.num_key_value_groups)\
|
||||||
|
|
@ -1873,7 +1872,7 @@ def llama_attention_forward_4_38_quantized(
|
||||||
key_states, value_states = past_key_value.update(key_states, value_states,
|
key_states, value_states = past_key_value.update(key_states, value_states,
|
||||||
self.layer_idx, cache_kwargs)
|
self.layer_idx, cache_kwargs)
|
||||||
kv_seq_len = key_states.shape[-2]
|
kv_seq_len = key_states.shape[-2]
|
||||||
if not use_sdp_fp8(q_len, key_states.shape[2], query_states):
|
if not use_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
|
||||||
key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
|
key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
|
||||||
query_states.dtype)
|
query_states.dtype)
|
||||||
key_states = repeat_kv(key_states, self.num_key_value_groups)\
|
key_states = repeat_kv(key_states, self.num_key_value_groups)\
|
||||||
|
|
|
||||||
|
|
@ -51,9 +51,7 @@ from ipex_llm.transformers.models.utils import init_fp8_kv_cache, append_fp8_kv_
|
||||||
from ipex_llm.transformers.models.utils import apply_rotary_pos_emb
|
from ipex_llm.transformers.models.utils import apply_rotary_pos_emb
|
||||||
from ipex_llm.transformers.models.utils import is_enough_kv_cache_room_4_31, \
|
from ipex_llm.transformers.models.utils import is_enough_kv_cache_room_4_31, \
|
||||||
is_enough_kv_cache_room_4_36
|
is_enough_kv_cache_room_4_36
|
||||||
from ipex_llm.transformers.low_bit_linear import SYM_INT4, FP8E5, IQ2_XXS
|
from ipex_llm.transformers.models.utils import use_flash_attention, use_sdp, use_sdp_causal
|
||||||
from ipex_llm.transformers.models.utils import use_flash_attention, use_sdp, use_sdp_fp8, \
|
|
||||||
use_sdp_causal
|
|
||||||
from ipex_llm.transformers.models.utils import use_decoding_fast_path
|
from ipex_llm.transformers.models.utils import use_decoding_fast_path
|
||||||
from ipex_llm.transformers.models.llama import llama_decoding_fast_path_qtype_check
|
from ipex_llm.transformers.models.llama import llama_decoding_fast_path_qtype_check
|
||||||
from ipex_llm.transformers.models.llama import should_use_xetla_mm_qkv
|
from ipex_llm.transformers.models.llama import should_use_xetla_mm_qkv
|
||||||
|
|
@ -409,7 +407,7 @@ def mistral_attention_forward_quantized(
|
||||||
kv_seq_len = key_states.shape[-2]
|
kv_seq_len = key_states.shape[-2]
|
||||||
past_key_value = (key_states, value_states)
|
past_key_value = (key_states, value_states)
|
||||||
|
|
||||||
if not use_sdp_fp8(q_len, key_states.shape[2], query_states):
|
if not use_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
|
||||||
key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
|
key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
|
||||||
query_states.dtype)
|
query_states.dtype)
|
||||||
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3))
|
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3))
|
||||||
|
|
@ -845,7 +843,7 @@ def mistral_attention_forward_4_36_quantized(
|
||||||
key_states, value_states = past_key_value.update(key_states, value_states,
|
key_states, value_states = past_key_value.update(key_states, value_states,
|
||||||
self.layer_idx, cache_kwargs)
|
self.layer_idx, cache_kwargs)
|
||||||
kv_seq_len = key_states.shape[-2]
|
kv_seq_len = key_states.shape[-2]
|
||||||
if not use_sdp_fp8(q_len, key_states.shape[2], query_states):
|
if not use_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
|
||||||
key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
|
key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
|
||||||
query_states.dtype)
|
query_states.dtype)
|
||||||
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3))
|
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3))
|
||||||
|
|
|
||||||
|
|
@ -22,7 +22,6 @@ from ipex_llm.ggml.quantize import ggml_tensor_qtype
|
||||||
from ipex_llm.transformers.utils import get_ipex_version, get_xpu_device_type
|
from ipex_llm.transformers.utils import get_ipex_version, get_xpu_device_type
|
||||||
from ipex_llm.transformers.low_bit_linear import SYM_INT4, SYM_INT8, FP8E5, IQ2_XXS, FP4, FP8E4,\
|
from ipex_llm.transformers.low_bit_linear import SYM_INT4, SYM_INT8, FP8E5, IQ2_XXS, FP4, FP8E4,\
|
||||||
FP6, ASYM_INT4
|
FP6, ASYM_INT4
|
||||||
from ipex_llm.transformers.convert import is_deepspeed_available
|
|
||||||
|
|
||||||
FP8_KV_ALLOC_LENGTH = 512
|
FP8_KV_ALLOC_LENGTH = 512
|
||||||
KV_CACHE_ALLOC_BLOCK_LENGTH = int(os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256))
|
KV_CACHE_ALLOC_BLOCK_LENGTH = int(os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256))
|
||||||
|
|
@ -335,15 +334,6 @@ def use_sdp(q_len, kv_len, head_dim, query_states):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def use_sdp_fp8(q_len, kv_len, query_states):
|
|
||||||
return (
|
|
||||||
query_states.device.type == "xpu"
|
|
||||||
and query_states.dtype in [torch.float, torch.half] # fp32/fp16
|
|
||||||
and q_len != kv_len # next token
|
|
||||||
and q_len <= 32 # lookup
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def use_sdp_causal(q_len, kv_len, head_dim, query_states, training):
|
def use_sdp_causal(q_len, kv_len, head_dim, query_states, training):
|
||||||
return (
|
return (
|
||||||
q_len == kv_len # first token
|
q_len == kv_len # first token
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue