remove unused code (#12635)

Yishuo Wang 2025-01-02 13:31:09 +08:00 committed by GitHub
parent 534566e290
commit 81211fd010
4 changed files with 47 additions and 79 deletions


@@ -29,7 +29,7 @@ from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp
     should_use_compresskv
 from ipex_llm.transformers.models.utils import update_past_key_value
 from ipex_llm.transformers.models.utils import should_use_fuse_rope
-from ipex_llm.transformers.models.utils import use_flash_attention, use_sdp
+from ipex_llm.transformers.models.utils import use_sdp
 from ipex_llm.transformers.models.utils import apply_rotary_pos_emb, SILU
 from ipex_llm.transformers.models.utils import mlp_fusion_check
 from ipex_llm.transformers.models.utils import is_enough_kv_cache_room_4_36
@@ -301,12 +301,6 @@ def baichuan_attention_forward_7b(
     # IPEX-LLM OPT: sdp
     attn_weights = None
-    if use_flash_attention(query_states, key_states, attention_mask):
-        attn_output = F.scaled_dot_product_attention(query_states.to(dtype=torch.float16),
-                                                     key_states.to(dtype=torch.float16),
-                                                     value_states.to(dtype=torch.float16),
-                                                     is_causal=True).to(hidden_states.dtype)
-    else:
-        attn_output = scaled_dot_product_attention(
-            query_states, key_states, value_states,
-            attention_mask, q_len == kv_seq_len
+    attn_output = scaled_dot_product_attention(
+        query_states, key_states, value_states,
+        attention_mask, q_len == kv_seq_len


@@ -23,7 +23,7 @@ import torch.utils.checkpoint
 import torch.nn.functional as F
 from typing import Optional, Tuple
 from ipex_llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
-from ipex_llm.transformers.models.utils import use_flash_attention, use_sdp
+from ipex_llm.transformers.models.utils import use_sdp
 def rotate_half(x):
@@ -41,7 +41,7 @@ def apply_rotary_pos_emb_index(q, k, cos, sin, position_id):
 def glm_sdpa(query, key, value, attention_mask=None, is_causal=False):
-    if use_flash_attention(query, key, attention_mask) or query.device.type == 'cpu':
+    if query.device.type == 'cpu':
         context_layer = F.scaled_dot_product_attention(query.to(key.dtype),
                                                        key,
                                                        value,


@@ -33,7 +33,6 @@ from ipex_llm.transformers.models.utils import update_past_key_value, should_use
 from ipex_llm.transformers.models.utils import use_quantize_kv_cache
 from ipex_llm.transformers.models.utils import rotate_half, SILU
 from ipex_llm.transformers.models.utils import mlp_fusion_check
-from ipex_llm.transformers.models.utils import use_flash_attention
 from ipex_llm.utils.common import invalidInputError
 from transformers.modeling_outputs import BaseModelOutputWithPast
@@ -116,14 +115,9 @@ def qwen_attention_forward(
     past_key_value = (key_states.transpose(1, 2),
                       value_states.transpose(1, 2)) if use_cache else None
-    # IPEX-LLM OPT: sdp
+    # IPEX-LLM OPT: sdpa
     attn_weights = None
-    if use_flash_attention(query_states, key_states, attention_mask):
-        attn_output = F.scaled_dot_product_attention(query_states.to(dtype=torch.float16),
-                                                     key_states.to(dtype=torch.float16),
-                                                     value_states.to(dtype=torch.float16),
-                                                     is_causal=True).to(hidden_states.dtype)
-    else:
-        if q_len > 1 and q_len != kv_seq_len:
-            causal_mask = torch.tril(
-                torch.ones((kv_seq_len, kv_seq_len), dtype=torch.bool, device=query_states.device)
+    if q_len > 1 and q_len != kv_seq_len:
+        causal_mask = torch.tril(
+            torch.ones((kv_seq_len, kv_seq_len), dtype=torch.bool, device=query_states.device)
@@ -219,15 +213,9 @@ def qwen_attention_forward_registered(
     past_key_value = (key_states.transpose(1, 2),
                       value_states.transpose(1, 2)) if use_cache else None
-    # IPEX-LLM OPT: sdp
+    # IPEX-LLM OPT: sdpa
     attn_weights = None
-    if use_flash_attention(query_states, key_states, attention_mask):
-        attn_output = F.scaled_dot_product_attention(query_states.to(dtype=torch.float16),
-                                                     key_states.to(dtype=torch.float16),
-                                                     value_states.to(dtype=torch.float16),
-                                                     is_causal=True).to(hidden_states.dtype)
-    else:
-        if q_len > 1 and q_len != kv_seq_len:
-            causal_mask = registered_causal_mask[
-                :, :, kv_seq_len - q_len:kv_seq_len, :kv_seq_len
+    if q_len > 1 and q_len != kv_seq_len:
+        causal_mask = registered_causal_mask[
+            :, :, kv_seq_len - q_len:kv_seq_len, :kv_seq_len


@@ -38,12 +38,10 @@
 #
 import os
-import math
 from typing import Optional, Tuple, Union, List
 import torch
 from torch.nn import CrossEntropyLoss
-from torch.nn.functional import scaled_dot_product_attention as sdpa
 from ipex_llm.transformers.models.common import merge_qkv_base
 from ipex_llm.transformers.models.common import scaled_dot_product_attention
@@ -51,13 +49,12 @@ from ipex_llm.transformers.models.utils import SILU, mlp_fusion_check
 from ipex_llm.transformers.models.utils import should_use_fuse_rope
 from ipex_llm.transformers.models.utils import use_quantize_kv_cache, \
     should_use_compresskv, is_enough_kv_cache_room_4_36
-from ipex_llm.transformers.models.utils import use_flash_attention
 from ipex_llm.transformers.kv import DynamicFp8Cache, DynamicNormalCache, \
     DynamicCompressCache, DynamicCompressFp8Cache
 from ipex_llm.utils.common import invalidInputError
 from transformers.models.qwen2.modeling_qwen2 import Qwen2Attention, Qwen2MLP
-from transformers.models.qwen2.modeling_qwen2 import apply_rotary_pos_emb, repeat_kv
+from transformers.models.qwen2.modeling_qwen2 import apply_rotary_pos_emb
 from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
 from transformers.cache_utils import Cache
 from transformers import logging
@@ -580,17 +577,6 @@ def qwen2_attention_forward(
                                                    self.layer_idx, None)
     attn_weights = None
-    if use_flash_attention(query_states, key_states, attention_mask):
-        if attention_mask is not None:
-            attention_mask = attention_mask[:, :, :, :kv_seq_len]
-        # repeat k/v heads if n_kv_heads < n_heads
-        key_states = repeat_kv(key_states, self.num_key_value_groups)
-        value_states = repeat_kv(value_states, self.num_key_value_groups)
-        attn_output = sdpa(query_states.to(device, dtype=torch.float16),
-                           key_states.to(device, dtype=torch.float16),
-                           value_states.to(device, dtype=torch.float16),
-                           is_causal=True).to(hidden_states.dtype)
-    else:
-        attn_output = scaled_dot_product_attention(
-            query_states, key_states, value_states,
-            attention_mask, q_len == kv_seq_len
+    attn_output = scaled_dot_product_attention(
+        query_states, key_states, value_states,
+        attention_mask, q_len == kv_seq_len
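
In all four files the removed branch has the same shape: a use_flash_attention(...) gate that cast the tensors to float16 and called PyTorch's built-in SDPA with is_causal=True, falling back to the ipex_llm scaled_dot_product_attention helper otherwise; after this commit only the helper path remains. For reference, a minimal standalone sketch of the two SDPA modes involved (plain PyTorch only, with made-up shapes; this is not the ipex_llm helper):

# Standalone sketch, plain PyTorch only; shapes are made up for illustration.
# Mode 1 is what the removed fp16 branch relied on (is_causal=True); mode 2 is
# the surviving path, which passes an explicit attention mask through.
import torch
import torch.nn.functional as F

batch, heads, q_len, kv_len, head_dim = 1, 2, 4, 4, 8
query = torch.randn(batch, heads, q_len, head_dim)
key = torch.randn(batch, heads, kv_len, head_dim)
value = torch.randn(batch, heads, kv_len, head_dim)

# Mode 1: let SDPA build the causal mask internally.
out_causal = F.scaled_dot_product_attention(query, key, value, is_causal=True)

# Mode 2: build an explicit additive causal mask and pass it in.
mask = torch.zeros(q_len, kv_len).masked_fill(
    torch.triu(torch.ones(q_len, kv_len), diagonal=1).bool(), float("-inf"))
out_masked = F.scaled_dot_product_attention(query, key, value, attn_mask=mask)

# With q_len == kv_len the two calls compute the same attention
# (any difference is backend numerics).
print(torch.allclose(out_causal, out_masked, atol=1e-5))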