use new sdp and fp32 sdp (#11007)

parent 8010af700f
commit 170e3d65e0

11 changed files with 62 additions and 62 deletions

@@ -27,7 +27,7 @@ from torch import nn
 import torch.nn.functional as F
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from ipex_llm.utils.common import invalidInputError
-from ipex_llm.transformers.models.utils import use_flash_attention, use_esimd_sdp
+from ipex_llm.transformers.models.utils import use_flash_attention, use_sdp
 from ipex_llm.transformers.models.utils import init_kv_cache, extend_kv_cache, \
     append_kv_cache, is_enough_kv_cache_room_4_31
 from ipex_llm.transformers.models.utils import init_fp8_kv_cache, append_fp8_kv_cache, \
@@ -276,11 +276,9 @@ def baichuan_attention_forward_7b_origin(
                                                      is_causal=True)
         attn_weights = None
     elif not self.training and not hidden_states.requires_grad and \
-            use_esimd_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
-        import linear_fp16_esimd
-        attn_output = linear_fp16_esimd.sdp_forward(query_states,
-                                                    key_states,
-                                                    value_states)
+            use_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
+        import linear_q4_0
+        attn_output = linear_q4_0.sdp(query_states, key_states, value_states, attention_mask)
         attn_output = attn_output.view(query_states.shape)
         attn_weights = None
     else:
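
The same substitution repeats in each model file below: the ESIMD-only gates (use_esimd_sdp, use_new_esimd_sdp_fp16) and the two kernel entry points (linear_fp16_esimd.sdp_forward, linear_q4_0.sdp_fp16) give way to a single use_sdp gate and a single linear_q4_0.sdp(query, key, value, attention_mask) call that now also accepts fp32 inputs. A minimal sketch of the resulting decode-time dispatch, assuming an ipex-llm install; the wrapper function and the pure-PyTorch fallback are illustrative, only the use_sdp gate and the linear_q4_0.sdp call are taken from the diff:

import torch.nn.functional as F

from ipex_llm.transformers.models.utils import use_sdp  # gate added by this PR


def decode_attention(query_states, key_states, value_states, attention_mask, training=False):
    # Tensors follow the layout used in the diff: [batch, num_heads, seq_len, head_dim].
    q_len, kv_len = query_states.shape[2], key_states.shape[2]
    head_dim = query_states.shape[-1]

    if not training and not query_states.requires_grad and \
            use_sdp(q_len, kv_len, head_dim, query_states):
        # Fused IPEX-LLM kernel (XPU only); after this PR it handles fp16 and fp32 inputs.
        import linear_q4_0
        attn_output = linear_q4_0.sdp(query_states, key_states, value_states, attention_mask)
        return attn_output.view(query_states.shape)

    # Everything the gate rejects falls back to stock PyTorch SDPA
    # (assumes key/value heads already match the query heads).
    return F.scaled_dot_product_attention(
        query_states, key_states, value_states, attn_mask=attention_mask)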

@@ -50,7 +50,7 @@ from ipex_llm.transformers.models.utils import extend_kv_cache, append_kv_cache
 from transformers.models.cohere.modeling_cohere import apply_rotary_pos_emb
 from ipex_llm.transformers.models.utils import is_enough_kv_cache_room_4_36
 from ipex_llm.transformers.models.utils import use_decoding_fast_path
-from ipex_llm.transformers.models.utils import use_flash_attention, use_esimd_sdp
+from ipex_llm.transformers.models.utils import use_flash_attention, use_sdp
 from transformers.models.cohere.modeling_cohere import apply_rotary_pos_emb
 from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache
 from ipex_llm.transformers.kv import DynamicFp8Cache
@@ -420,9 +420,13 @@ def cohere_attention_forward_origin(
                                                      is_causal=True)
         attn_weights = None
     elif not self.training and not hidden_states.requires_grad and \
-            use_esimd_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
+            use_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
         import linear_q4_0
-        attn_output = linear_q4_0.sdp_fp16(query_states, key_states, value_states, attention_mask)
+        if attention_mask is not None:
+            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+        else:
+            causal_mask = None
+        attn_output = linear_q4_0.sdp(query_states, key_states, value_states, causal_mask)
         attn_output = attn_output.view(query_states.shape)
         attn_weights = None
     else:
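
The cohere hunk is the one place in this diff that narrows the 4D attention mask to the live key length before handing it to the fused kernel. A small stand-alone illustration of that slice; the tensor sizes here are made up, only the slicing expression mirrors the hunk:

import torch

# Hypothetical decode step: 1 new query token, a KV cache holding 33 positions,
# and a mask that transformers padded out to a longer static length (64 here).
bsz, num_heads, head_dim = 1, 8, 128
attention_mask = torch.zeros(bsz, 1, 1, 64)             # [batch, 1, q_len, padded_len]
key_states = torch.randn(bsz, num_heads, 33, head_dim)  # [batch, heads, kv_len, head_dim]

# Same slice as in the hunk above: trim the mask to the real key length
# before passing it to linear_q4_0.sdp.
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
print(causal_mask.shape)  # torch.Size([1, 1, 1, 33])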

@@ -46,8 +46,7 @@ from ipex_llm.transformers.models.utils import init_fp8_kv_cache, append_fp8_kv_
 from ipex_llm.transformers.models.utils import is_enough_kv_cache_room_4_31, \
     apply_rotary_pos_emb, is_enough_kv_cache_room_4_36
 from ipex_llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
-from ipex_llm.transformers.models.utils import use_flash_attention, use_new_esimd_sdp_fp16, \
-    use_sdp_fp8
+from ipex_llm.transformers.models.utils import use_flash_attention, use_sdp, use_sdp_fp8
 from ipex_llm.transformers.models.utils import mlp_fusion_check, fp16_fusion_check
 from ipex_llm.transformers.models.utils import use_decoding_fast_path
 from transformers.modeling_outputs import BaseModelOutputWithPast
@@ -673,9 +672,9 @@ def llama_attention_forward_4_31_original(
                                                      is_causal=True)
         attn_weights = None
     elif not self.training and not hidden_states.requires_grad and \
-            use_new_esimd_sdp_fp16(q_len, key_states.shape[2], self.head_dim, query_states):
+            use_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
         import linear_q4_0
-        attn_output = linear_q4_0.sdp_fp16(query_states, key_states, value_states, attention_mask)
+        attn_output = linear_q4_0.sdp(query_states, key_states, value_states, attention_mask)
         attn_output = attn_output.view(query_states.shape)
         attn_weights = None
     else:
@@ -1348,9 +1347,9 @@ def llama_attention_forward_4_36_original(
                                                      is_causal=True)
         attn_weights = None
     elif not self.training and not hidden_states.requires_grad and \
-            use_new_esimd_sdp_fp16(q_len, key_states.shape[2], self.head_dim, query_states):
+            use_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
         import linear_q4_0
-        attn_output = linear_q4_0.sdp_fp16(query_states, key_states, value_states, attention_mask)
+        attn_output = linear_q4_0.sdp(query_states, key_states, value_states, attention_mask)
         attn_output = attn_output.view(query_states.shape)
         attn_weights = None
     else:

@@ -52,8 +52,7 @@ from ipex_llm.transformers.models.utils import apply_rotary_pos_emb, \
 from ipex_llm.transformers.models.utils import is_enough_kv_cache_room_4_31, \
     is_enough_kv_cache_room_4_36
 from ipex_llm.transformers.low_bit_linear import SYM_INT4, FP8E5, IQ2_XXS
-from ipex_llm.transformers.models.utils import use_flash_attention, use_new_esimd_sdp_fp16, \
-    use_sdp_fp8
+from ipex_llm.transformers.models.utils import use_flash_attention, use_sdp, use_sdp_fp8
 from ipex_llm.transformers.models.utils import use_decoding_fast_path
 from ipex_llm.transformers.models.llama import llama_decoding_fast_path_qtype_check
 from ipex_llm.transformers.models.llama import should_use_xetla_mm_qkv
@@ -591,10 +590,10 @@ def mistral_attention_forward_original(
         attn_weights = None
         attn_output = attn_output.transpose(1, 2).contiguous()
         attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
-    elif use_new_esimd_sdp_fp16(q_len, key_states.shape[2], self.head_dim, query_states):
+    elif use_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
         # new fp16 sdp doesn't require repeat_kv
         import linear_q4_0
-        attn_output = linear_q4_0.sdp_fp16(query_states, key_states, value_states, attention_mask)
+        attn_output = linear_q4_0.sdp(query_states, key_states, value_states, attention_mask)
         attn_output = attn_output.view(query_states.shape)
         attn_weights = None
         attn_output = attn_output.transpose(1, 2).contiguous()
@@ -1032,10 +1031,10 @@ def mistral_attention_forward_4_36_original(
         attn_weights = None
         attn_output = attn_output.transpose(1, 2).contiguous()
         attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
-    elif use_new_esimd_sdp_fp16(q_len, key_states.shape[2], self.head_dim, query_states):
+    elif use_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
         # new fp16 sdp doesn't require repeat_kv
         import linear_q4_0
-        attn_output = linear_q4_0.sdp_fp16(query_states, key_states, value_states, attention_mask)
+        attn_output = linear_q4_0.sdp(query_states, key_states, value_states, attention_mask)
         attn_output = attn_output.view(query_states.shape)
         attn_weights = None
         attn_output = attn_output.transpose(1, 2).contiguous()

@@ -55,7 +55,7 @@ from ipex_llm.transformers.models.utils import apply_rotary_pos_emb,\
     apply_rotary_pos_emb_cache_freq_xpu, is_enough_kv_cache_room_4_36
 from ipex_llm.transformers.models.mistral import should_use_fuse_rope
 from ipex_llm.transformers.models.utils import use_decoding_fast_path
-from ipex_llm.transformers.models.utils import use_flash_attention, use_esimd_sdp
+from ipex_llm.transformers.models.utils import use_flash_attention, use_sdp
 from ipex_llm.transformers.models.utils import mlp_fusion_check, SILU
 from ipex_llm.transformers.low_bit_linear import IQ2_XXS
@@ -332,12 +332,9 @@ def mixtral_attention_forward(
                                                      value_states,
                                                      is_causal=True)
         attn_weights = None
-    elif use_esimd_sdp(query_states.shape[2], key_states.shape[2],
-                       self.head_dim, query_states):
-        import linear_fp16_esimd
-        attn_output = linear_fp16_esimd.sdp_forward(query_states,
-                                                    key_states,
-                                                    value_states)
+    elif use_sdp(query_states.shape[2], key_states.shape[2], self.head_dim, query_states):
+        import linear_q4_0
+        attn_output = linear_q4_0.sdp(query_states, key_states, value_states, attention_mask)
        attn_output = attn_output.view(query_states.shape)
         attn_weights = None
     else:

@@ -40,9 +40,8 @@ from ipex_llm.transformers.models.utils import (
     apply_rotary_pos_emb_cache_freq_xpu
 )
 from ipex_llm.transformers.models.utils import mlp_fusion_check, SILU
-from ipex_llm.transformers.models.utils import use_new_esimd_sdp_fp16, use_quantize_kv_cache
+from ipex_llm.transformers.models.utils import use_sdp, use_sdp_causal, use_quantize_kv_cache
 from ipex_llm.transformers.models.utils import use_sdp_fp8, restore_fp8_kv_cache
-from ipex_llm.transformers.models.utils import use_sdp_causal
 from ipex_llm.transformers.kv import DynamicNormalCache, DynamicFp8Cache
 from typing import Optional, Tuple, List
@@ -142,9 +141,9 @@ def attention_forward(
         import linear_q4_0
         attn_output = linear_q4_0.sdp_fp8(query_states, key_states, value_states, attention_mask)
     elif (isinstance(past_key_value, DynamicNormalCache) and
-          use_new_esimd_sdp_fp16(q_len, kv_seq_len, self.head_dim, query_states)):
+          use_sdp(q_len, kv_seq_len, self.head_dim, query_states)):
         import linear_q4_0
-        attn_output = linear_q4_0.sdp_fp16(query_states, key_states, value_states, attention_mask)
+        attn_output = linear_q4_0.sdp(query_states, key_states, value_states, attention_mask)
     else:
         if isinstance(past_key_value, DynamicFp8Cache):
             key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
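
In this file the kernel choice also depends on the KV-cache type: a DynamicFp8Cache keeps linear_q4_0.sdp_fp8, while a DynamicNormalCache now goes through the unified use_sdp gate and linear_q4_0.sdp. A condensed sketch of that dispatch; the wrapper name pick_fused_sdp and the use_sdp_fp8 check on the fp8 branch are assumptions (the hunk only shows the tail of that branch), while the gate and kernel names come from the diff:

from ipex_llm.transformers.kv import DynamicNormalCache, DynamicFp8Cache
from ipex_llm.transformers.models.utils import use_sdp, use_sdp_fp8


def pick_fused_sdp(query_states, key_states, value_states, attention_mask,
                   past_key_value, head_dim):
    """Return a fused attention output, or None to signal the eager fallback."""
    q_len, kv_seq_len = query_states.shape[2], key_states.shape[2]

    if isinstance(past_key_value, DynamicFp8Cache) and \
            use_sdp_fp8(q_len, kv_seq_len, query_states):
        import linear_q4_0
        # fp8 KV cache keeps the fp8-aware kernel.
        return linear_q4_0.sdp_fp8(query_states, key_states, value_states, attention_mask)

    if isinstance(past_key_value, DynamicNormalCache) and \
            use_sdp(q_len, kv_seq_len, head_dim, query_states):
        import linear_q4_0
        # normal fp16/fp32 KV cache uses the unified kernel introduced here.
        return linear_q4_0.sdp(query_states, key_states, value_states, attention_mask)

    return None  # caller runs the eager attention path instead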

@@ -42,8 +42,7 @@ from ipex_llm.transformers.models.utils import init_fp8_kv_cache, append_fp8_kv_
 from ipex_llm.transformers.models.utils import rotate_half, SILU
 from ipex_llm.transformers.models.utils import mlp_fusion_check
 from ipex_llm.transformers.models.utils import apply_rotary_pos_emb_cache_freq_xpu
-from ipex_llm.transformers.models.utils import use_flash_attention, use_new_esimd_sdp_fp16, \
-    use_sdp_fp8
+from ipex_llm.transformers.models.utils import use_flash_attention, use_sdp, use_sdp_fp8
 from ipex_llm.transformers.models.utils import use_decoding_fast_path
 from ipex_llm.utils.common import invalidInputError, invalidOperationError
 from ipex_llm.ggml.quantize import ggml_tensor_qtype
@@ -291,9 +290,9 @@ def qwen_attention_forward_original(
         attn_output = attn_output.transpose(1, 2)
         attn_weights = None
     elif not self.training and not hidden_states.requires_grad and \
-            use_new_esimd_sdp_fp16(q_len, key.shape[2], self.head_dim, query):
+            use_sdp(q_len, key.shape[2], self.head_dim, query):
         import linear_q4_0
-        attn_output = linear_q4_0.sdp_fp16(query, key, value, attention_mask)
+        attn_output = linear_q4_0.sdp(query, key, value, attention_mask)
         attn_output = attn_output.view(query.shape)
         attn_output = attn_output.transpose(1, 2)
         attn_weight = None

@@ -52,7 +52,7 @@ from ipex_llm.transformers.models.utils import is_enough_kv_cache_room_4_36
 from ipex_llm.transformers.models.utils import apply_rotary_pos_emb_cache_freq_xpu
 from ipex_llm.transformers.kv import DynamicFp8Cache
 from ipex_llm.utils.common import invalidInputError
-from ipex_llm.transformers.models.utils import use_flash_attention, use_esimd_sdp
+from ipex_llm.transformers.models.utils import use_flash_attention, use_sdp
 from transformers.models.qwen2.modeling_qwen2 import Qwen2Model, apply_rotary_pos_emb
 from transformers.models.qwen2.modeling_qwen2 import _prepare_4d_causal_attention_mask_for_sdpa
 from transformers.models.qwen2.modeling_qwen2 import _prepare_4d_causal_attention_mask
@@ -565,11 +565,9 @@ def qwen2_attention_forward_origin(
                                                      is_causal=True)
         attn_weights = None
     elif not self.training and not hidden_states.requires_grad and \
-            use_esimd_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
-        import linear_fp16_esimd
-        attn_output = linear_fp16_esimd.sdp_forward(query_states,
-                                                    key_states,
-                                                    value_states)
+            use_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
+        import linear_q4_0
+        attn_output = linear_q4_0.sdp(query_states, key_states, value_states, attention_mask)
         attn_output = attn_output.view(query_states.shape)
         attn_weights = None
     else:

@@ -32,7 +32,7 @@ import torch.utils.checkpoint
 from transformers.utils import logging
 from ipex_llm.transformers.models.utils import extend_kv_cache, init_kv_cache, append_kv_cache
 from ipex_llm.transformers.models.utils import rotate_half
-from ipex_llm.transformers.models.utils import use_esimd_sdp
+from ipex_llm.transformers.models.utils import use_sdp
 from ipex_llm.transformers.models.utils import use_decoding_fast_path
 import os
@@ -207,11 +207,9 @@ def qwen_attention_forward_vl(
     query = query.permute(0, 2, 1, 3)
     if not self.training and not hidden_states.requires_grad and \
-            use_esimd_sdp(q_len, key.shape[2], self.head_dim, query):
-        import linear_fp16_esimd
-        attn_output = linear_fp16_esimd.sdp_forward(query,
-                                                    key,
-                                                    value)
+            use_sdp(q_len, key.shape[2], self.head_dim, query):
+        import linear_q4_0
+        attn_output = linear_q4_0.sdp(query, key, value, attention_mask)
         attn_output = attn_output.view(query.shape)
         attn_output = attn_output.transpose(1, 2)
         attn_weight = None

@@ -53,7 +53,7 @@ from ipex_llm.transformers.models.utils import apply_rotary_pos_emb, \
 from ipex_llm.transformers.models.utils import init_fp8_kv_cache, append_fp8_kv_cache, \
     restore_fp8_kv_cache, use_quantize_kv_cache
 from ipex_llm.transformers.models.utils import is_enough_kv_cache_room_4_36
-from ipex_llm.transformers.models.utils import use_flash_attention, use_esimd_sdp
+from ipex_llm.transformers.models.utils import use_flash_attention, use_sdp
 from ipex_llm.transformers.models.mistral import should_use_fuse_rope, repeat_kv
 try:
     from transformers.cache_utils import Cache
@@ -266,11 +266,9 @@ def stablelm_attention_forward_original(
                                                      is_causal=True)
         attn_weights = None
     elif not self.training and not hidden_states.requires_grad and \
-            use_esimd_sdp(q_len, key_states.shape[2], self.head_dim, query_states, attention_mask):
-        import linear_fp16_esimd
-        attn_output = linear_fp16_esimd.sdp_forward(query_states,
-                                                    key_states,
-                                                    value_states)
+            use_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
+        import linear_q4_0
+        attn_output = linear_q4_0.sdp(query_states, key_states, value_states, attention_mask)
         attn_output = attn_output.view(query_states.shape)
         attn_weights = None
     else:

@@ -375,20 +375,31 @@ def use_new_esimd_sdp_fp16(q_len, k_len, head_dim, query_states):
     return True


-def use_sdp_fp8(q_len, k_len, query_states):
-    if query_states.device.type != "xpu":
-        return False
-    if q_len == k_len:
-        # sdp_fp8 only support rest token now
-        return False
-    return True
+def use_sdp(q_len, kv_len, head_dim, query_states):
+    return (
+        query_states.device.type == "xpu"
+        and query_states.dtype in [torch.float, torch.half]  # fp32/fp16
+        and head_dim in [64, 96, 128]
+        and q_len != kv_len  # next token
+        and q_len <= 32  # lookup
+    )
+
+
+def use_sdp_fp8(q_len, kv_len, query_states):
+    return (
+        query_states.device.type == "xpu"
+        and query_states.dtype in [torch.float, torch.half]  # fp32/fp16
+        and q_len != kv_len  # next token
+        and q_len <= 32  # lookup
+    )


 def use_sdp_causal(q_len, kv_len, query_states, training):
     return (
         q_len == kv_len  # first token
         and query_states.device.type == "xpu"  # GPU
+        and not query_states.requires_grad and not training  # no training
         and query_states.dtype in [torch.float, torch.half]  # fp32/fp16
-        and not query_states.requires_grad and not training  # not training
     )
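
Taken together, the new gates split the work by phase: use_sdp and use_sdp_fp8 accept only the next-token / short-lookahead case (q_len != kv_len, q_len <= 32), while use_sdp_causal accepts only the first-token prefill case (q_len == kv_len). A small probe of that behaviour; it assumes an XPU device with ipex-llm installed (on any other device every gate returns False), and the tensor sizes are arbitrary:

import torch
from ipex_llm.transformers.models.utils import use_sdp, use_sdp_causal

head_dim = 128

# Decode step: one new query token attending over a 33-entry KV cache.
q_decode = torch.randn(1, 32, 1, head_dim, dtype=torch.half, device="xpu")
print(use_sdp(1, 33, head_dim, q_decode))               # True  -> linear_q4_0.sdp
print(use_sdp_causal(1, 33, q_decode, training=False))  # False

# Prefill: 33 query tokens over 33 keys (first-token path).
q_prefill = torch.randn(1, 32, 33, head_dim, dtype=torch.half, device="xpu")
print(use_sdp(33, 33, head_dim, q_prefill))               # False
print(use_sdp_causal(33, 33, q_prefill, training=False))  # True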