use new sdp and fp32 sdp (#11007)

This commit is contained in:
Yishuo Wang 2024-05-14 14:29:18 +08:00 committed by GitHub
parent 8010af700f
commit 170e3d65e0
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 62 additions and 62 deletions

View file

@ -27,7 +27,7 @@ from torch import nn
import torch.nn.functional as F
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ipex_llm.utils.common import invalidInputError
from ipex_llm.transformers.models.utils import use_flash_attention, use_esimd_sdp
from ipex_llm.transformers.models.utils import use_flash_attention, use_sdp
from ipex_llm.transformers.models.utils import init_kv_cache, extend_kv_cache, \
append_kv_cache, is_enough_kv_cache_room_4_31
from ipex_llm.transformers.models.utils import init_fp8_kv_cache, append_fp8_kv_cache, \
@ -276,11 +276,9 @@ def baichuan_attention_forward_7b_origin(
is_causal=True)
attn_weights = None
elif not self.training and not hidden_states.requires_grad and \
use_esimd_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
import linear_fp16_esimd
attn_output = linear_fp16_esimd.sdp_forward(query_states,
key_states,
value_states)
use_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
import linear_q4_0
attn_output = linear_q4_0.sdp(query_states, key_states, value_states, attention_mask)
attn_output = attn_output.view(query_states.shape)
attn_weights = None
else:

View file

@ -50,7 +50,7 @@ from ipex_llm.transformers.models.utils import extend_kv_cache, append_kv_cache
from transformers.models.cohere.modeling_cohere import apply_rotary_pos_emb
from ipex_llm.transformers.models.utils import is_enough_kv_cache_room_4_36
from ipex_llm.transformers.models.utils import use_decoding_fast_path
from ipex_llm.transformers.models.utils import use_flash_attention, use_esimd_sdp
from ipex_llm.transformers.models.utils import use_flash_attention, use_sdp
from transformers.models.cohere.modeling_cohere import apply_rotary_pos_emb
from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache
from ipex_llm.transformers.kv import DynamicFp8Cache
@ -420,9 +420,13 @@ def cohere_attention_forward_origin(
is_causal=True)
attn_weights = None
elif not self.training and not hidden_states.requires_grad and \
use_esimd_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
use_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
import linear_q4_0
attn_output = linear_q4_0.sdp_fp16(query_states, key_states, value_states, attention_mask)
if attention_mask is not None:
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
else:
causal_mask = None
attn_output = linear_q4_0.sdp(query_states, key_states, value_states, causal_mask)
attn_output = attn_output.view(query_states.shape)
attn_weights = None
else:

View file

@ -46,8 +46,7 @@ from ipex_llm.transformers.models.utils import init_fp8_kv_cache, append_fp8_kv_
from ipex_llm.transformers.models.utils import is_enough_kv_cache_room_4_31, \
apply_rotary_pos_emb, is_enough_kv_cache_room_4_36
from ipex_llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
from ipex_llm.transformers.models.utils import use_flash_attention, use_new_esimd_sdp_fp16, \
use_sdp_fp8
from ipex_llm.transformers.models.utils import use_flash_attention, use_sdp, use_sdp_fp8
from ipex_llm.transformers.models.utils import mlp_fusion_check, fp16_fusion_check
from ipex_llm.transformers.models.utils import use_decoding_fast_path
from transformers.modeling_outputs import BaseModelOutputWithPast
@ -673,9 +672,9 @@ def llama_attention_forward_4_31_original(
is_causal=True)
attn_weights = None
elif not self.training and not hidden_states.requires_grad and \
use_new_esimd_sdp_fp16(q_len, key_states.shape[2], self.head_dim, query_states):
use_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
import linear_q4_0
attn_output = linear_q4_0.sdp_fp16(query_states, key_states, value_states, attention_mask)
attn_output = linear_q4_0.sdp(query_states, key_states, value_states, attention_mask)
attn_output = attn_output.view(query_states.shape)
attn_weights = None
else:
@ -1348,9 +1347,9 @@ def llama_attention_forward_4_36_original(
is_causal=True)
attn_weights = None
elif not self.training and not hidden_states.requires_grad and \
use_new_esimd_sdp_fp16(q_len, key_states.shape[2], self.head_dim, query_states):
use_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
import linear_q4_0
attn_output = linear_q4_0.sdp_fp16(query_states, key_states, value_states, attention_mask)
attn_output = linear_q4_0.sdp(query_states, key_states, value_states, attention_mask)
attn_output = attn_output.view(query_states.shape)
attn_weights = None
else:

View file

@ -52,8 +52,7 @@ from ipex_llm.transformers.models.utils import apply_rotary_pos_emb, \
from ipex_llm.transformers.models.utils import is_enough_kv_cache_room_4_31, \
is_enough_kv_cache_room_4_36
from ipex_llm.transformers.low_bit_linear import SYM_INT4, FP8E5, IQ2_XXS
from ipex_llm.transformers.models.utils import use_flash_attention, use_new_esimd_sdp_fp16, \
use_sdp_fp8
from ipex_llm.transformers.models.utils import use_flash_attention, use_sdp, use_sdp_fp8
from ipex_llm.transformers.models.utils import use_decoding_fast_path
from ipex_llm.transformers.models.llama import llama_decoding_fast_path_qtype_check
from ipex_llm.transformers.models.llama import should_use_xetla_mm_qkv
@ -591,10 +590,10 @@ def mistral_attention_forward_original(
attn_weights = None
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
elif use_new_esimd_sdp_fp16(q_len, key_states.shape[2], self.head_dim, query_states):
elif use_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
# new fp16 sdp doesn't require repeat_kv
import linear_q4_0
attn_output = linear_q4_0.sdp_fp16(query_states, key_states, value_states, attention_mask)
attn_output = linear_q4_0.sdp(query_states, key_states, value_states, attention_mask)
attn_output = attn_output.view(query_states.shape)
attn_weights = None
attn_output = attn_output.transpose(1, 2).contiguous()
@ -1032,10 +1031,10 @@ def mistral_attention_forward_4_36_original(
attn_weights = None
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
elif use_new_esimd_sdp_fp16(q_len, key_states.shape[2], self.head_dim, query_states):
elif use_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
# new fp16 sdp doesn't require repeat_kv
import linear_q4_0
attn_output = linear_q4_0.sdp_fp16(query_states, key_states, value_states, attention_mask)
attn_output = linear_q4_0.sdp(query_states, key_states, value_states, attention_mask)
attn_output = attn_output.view(query_states.shape)
attn_weights = None
attn_output = attn_output.transpose(1, 2).contiguous()

View file

@ -55,7 +55,7 @@ from ipex_llm.transformers.models.utils import apply_rotary_pos_emb,\
apply_rotary_pos_emb_cache_freq_xpu, is_enough_kv_cache_room_4_36
from ipex_llm.transformers.models.mistral import should_use_fuse_rope
from ipex_llm.transformers.models.utils import use_decoding_fast_path
from ipex_llm.transformers.models.utils import use_flash_attention, use_esimd_sdp
from ipex_llm.transformers.models.utils import use_flash_attention, use_sdp
from ipex_llm.transformers.models.utils import mlp_fusion_check, SILU
from ipex_llm.transformers.low_bit_linear import IQ2_XXS
@ -332,12 +332,9 @@ def mixtral_attention_forward(
value_states,
is_causal=True)
attn_weights = None
elif use_esimd_sdp(query_states.shape[2], key_states.shape[2],
self.head_dim, query_states):
import linear_fp16_esimd
attn_output = linear_fp16_esimd.sdp_forward(query_states,
key_states,
value_states)
elif use_sdp(query_states.shape[2], key_states.shape[2], self.head_dim, query_states):
import linear_q4_0
attn_output = linear_q4_0.sdp(query_states, key_states, value_states, attention_mask)
attn_output = attn_output.view(query_states.shape)
attn_weights = None
else:

View file

@ -40,9 +40,8 @@ from ipex_llm.transformers.models.utils import (
apply_rotary_pos_emb_cache_freq_xpu
)
from ipex_llm.transformers.models.utils import mlp_fusion_check, SILU
from ipex_llm.transformers.models.utils import use_new_esimd_sdp_fp16, use_quantize_kv_cache
from ipex_llm.transformers.models.utils import use_sdp, use_sdp_causal, use_quantize_kv_cache
from ipex_llm.transformers.models.utils import use_sdp_fp8, restore_fp8_kv_cache
from ipex_llm.transformers.models.utils import use_sdp_causal
from ipex_llm.transformers.kv import DynamicNormalCache, DynamicFp8Cache
from typing import Optional, Tuple, List
@ -142,9 +141,9 @@ def attention_forward(
import linear_q4_0
attn_output = linear_q4_0.sdp_fp8(query_states, key_states, value_states, attention_mask)
elif (isinstance(past_key_value, DynamicNormalCache) and
use_new_esimd_sdp_fp16(q_len, kv_seq_len, self.head_dim, query_states)):
use_sdp(q_len, kv_seq_len, self.head_dim, query_states)):
import linear_q4_0
attn_output = linear_q4_0.sdp_fp16(query_states, key_states, value_states, attention_mask)
attn_output = linear_q4_0.sdp(query_states, key_states, value_states, attention_mask)
else:
if isinstance(past_key_value, DynamicFp8Cache):
key_states, value_states = restore_fp8_kv_cache(key_states, value_states,

View file

@ -42,8 +42,7 @@ from ipex_llm.transformers.models.utils import init_fp8_kv_cache, append_fp8_kv_
from ipex_llm.transformers.models.utils import rotate_half, SILU
from ipex_llm.transformers.models.utils import mlp_fusion_check
from ipex_llm.transformers.models.utils import apply_rotary_pos_emb_cache_freq_xpu
from ipex_llm.transformers.models.utils import use_flash_attention, use_new_esimd_sdp_fp16, \
use_sdp_fp8
from ipex_llm.transformers.models.utils import use_flash_attention, use_sdp, use_sdp_fp8
from ipex_llm.transformers.models.utils import use_decoding_fast_path
from ipex_llm.utils.common import invalidInputError, invalidOperationError
from ipex_llm.ggml.quantize import ggml_tensor_qtype
@ -291,9 +290,9 @@ def qwen_attention_forward_original(
attn_output = attn_output.transpose(1, 2)
attn_weights = None
elif not self.training and not hidden_states.requires_grad and \
use_new_esimd_sdp_fp16(q_len, key.shape[2], self.head_dim, query):
use_sdp(q_len, key.shape[2], self.head_dim, query):
import linear_q4_0
attn_output = linear_q4_0.sdp_fp16(query, key, value, attention_mask)
attn_output = linear_q4_0.sdp(query, key, value, attention_mask)
attn_output = attn_output.view(query.shape)
attn_output = attn_output.transpose(1, 2)
attn_weight = None

View file

@ -52,7 +52,7 @@ from ipex_llm.transformers.models.utils import is_enough_kv_cache_room_4_36
from ipex_llm.transformers.models.utils import apply_rotary_pos_emb_cache_freq_xpu
from ipex_llm.transformers.kv import DynamicFp8Cache
from ipex_llm.utils.common import invalidInputError
from ipex_llm.transformers.models.utils import use_flash_attention, use_esimd_sdp
from ipex_llm.transformers.models.utils import use_flash_attention, use_sdp
from transformers.models.qwen2.modeling_qwen2 import Qwen2Model, apply_rotary_pos_emb
from transformers.models.qwen2.modeling_qwen2 import _prepare_4d_causal_attention_mask_for_sdpa
from transformers.models.qwen2.modeling_qwen2 import _prepare_4d_causal_attention_mask
@ -565,11 +565,9 @@ def qwen2_attention_forward_origin(
is_causal=True)
attn_weights = None
elif not self.training and not hidden_states.requires_grad and \
use_esimd_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
import linear_fp16_esimd
attn_output = linear_fp16_esimd.sdp_forward(query_states,
key_states,
value_states)
use_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
import linear_q4_0
attn_output = linear_q4_0.sdp(query_states, key_states, value_states, attention_mask)
attn_output = attn_output.view(query_states.shape)
attn_weights = None
else:

View file

@ -32,7 +32,7 @@ import torch.utils.checkpoint
from transformers.utils import logging
from ipex_llm.transformers.models.utils import extend_kv_cache, init_kv_cache, append_kv_cache
from ipex_llm.transformers.models.utils import rotate_half
from ipex_llm.transformers.models.utils import use_esimd_sdp
from ipex_llm.transformers.models.utils import use_sdp
from ipex_llm.transformers.models.utils import use_decoding_fast_path
import os
@ -207,11 +207,9 @@ def qwen_attention_forward_vl(
query = query.permute(0, 2, 1, 3)
if not self.training and not hidden_states.requires_grad and \
use_esimd_sdp(q_len, key.shape[2], self.head_dim, query):
import linear_fp16_esimd
attn_output = linear_fp16_esimd.sdp_forward(query,
key,
value)
use_sdp(q_len, key.shape[2], self.head_dim, query):
import linear_q4_0
attn_output = linear_q4_0.sdp(query, key, value, attention_mask)
attn_output = attn_output.view(query.shape)
attn_output = attn_output.transpose(1, 2)
attn_weight = None

View file

@ -53,7 +53,7 @@ from ipex_llm.transformers.models.utils import apply_rotary_pos_emb, \
from ipex_llm.transformers.models.utils import init_fp8_kv_cache, append_fp8_kv_cache, \
restore_fp8_kv_cache, use_quantize_kv_cache
from ipex_llm.transformers.models.utils import is_enough_kv_cache_room_4_36
from ipex_llm.transformers.models.utils import use_flash_attention, use_esimd_sdp
from ipex_llm.transformers.models.utils import use_flash_attention, use_sdp
from ipex_llm.transformers.models.mistral import should_use_fuse_rope, repeat_kv
try:
from transformers.cache_utils import Cache
@ -266,11 +266,9 @@ def stablelm_attention_forward_original(
is_causal=True)
attn_weights = None
elif not self.training and not hidden_states.requires_grad and \
use_esimd_sdp(q_len, key_states.shape[2], self.head_dim, query_states, attention_mask):
import linear_fp16_esimd
attn_output = linear_fp16_esimd.sdp_forward(query_states,
key_states,
value_states)
use_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
import linear_q4_0
attn_output = linear_q4_0.sdp(query_states, key_states, value_states, attention_mask)
attn_output = attn_output.view(query_states.shape)
attn_weights = None
else:

View file

@ -375,20 +375,31 @@ def use_new_esimd_sdp_fp16(q_len, k_len, head_dim, query_states):
return True
def use_sdp_fp8(q_len, k_len, query_states):
if query_states.device.type != "xpu":
return False
if q_len == k_len:
# sdp_fp8 only support rest token now
return False
return True
def use_sdp(q_len, kv_len, head_dim, query_states):
return (
query_states.device.type == "xpu"
and query_states.dtype in [torch.float, torch.half] # fp32/fp16
and head_dim in [64, 96, 128]
and q_len != kv_len # next token
and q_len <= 32 # lookup
)
def use_sdp_fp8(q_len, kv_len, query_states):
return (
query_states.device.type == "xpu"
and query_states.dtype in [torch.float, torch.half] # fp32/fp16
and q_len != kv_len # next token
and q_len <= 32 # lookup
)
def use_sdp_causal(q_len, kv_len, query_states, training):
return (
q_len == kv_len # first token
and query_states.device.type == "xpu" # GPU
and not query_states.requires_grad and not training # no training
and query_states.dtype in [torch.float, torch.half] # fp32/fp16
and not query_states.requires_grad and not training # not training
)