add sdp fp8 for llama (#11671)

* add sdp fp8 for llama

* fix style

* refactor
Author: Ruonan Wang
Date:   2024-07-29 08:46:22 +03:00
Committed by: GitHub
Commit: c11d5301d7 (parent 7f88ce23cd)

@@ -46,7 +46,8 @@ from ipex_llm.transformers.models.utils import init_fp8_kv_cache, append_fp8_kv_
 from ipex_llm.transformers.models.utils import is_enough_kv_cache_room_4_31, \
     apply_rotary_pos_emb, is_enough_kv_cache_room_4_36
 from ipex_llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
-from ipex_llm.transformers.models.utils import use_flash_attention, use_sdp, use_sdp_fp8
+from ipex_llm.transformers.models.utils import use_flash_attention, use_sdp, use_sdp_fp8, \
+    use_sdp_causal
 from ipex_llm.transformers.models.utils import mlp_fusion_check, fp16_fusion_check
 from ipex_llm.transformers.models.utils import use_decoding_fast_path
 from transformers.modeling_outputs import BaseModelOutputWithPast
@@ -1678,8 +1679,17 @@ def llama_attention_forward_4_38_quantized(
     if len(past_key_value.key_cache) <= self.layer_idx:
         repeated_key_states = repeat_kv(key_states, self.num_key_value_groups)
         repeated_value_states = repeat_kv(value_states, self.num_key_value_groups)
-        if should_split_qkv_tensor(query_states, bsz, self.num_heads,
-                                   q_len, kv_seq_len, output_attentions):
+        if use_cache:
+            cache_kwargs = None
+            key_states, value_states = past_key_value.update(key_states, value_states,
+                                                             self.layer_idx, cache_kwargs)
+        if use_cache and use_sdp_causal(q_len, kv_seq_len, self.head_dim,
+                                        query_states, self.training):
+            import xe_addons
+            attn_output = xe_addons.sdp_fp8_causal(query_states, key_states,
+                                                   value_states, attention_mask)
+        elif should_split_qkv_tensor(query_states, bsz, self.num_heads,
+                                     q_len, kv_seq_len, output_attentions):
             attn_output, _ = native_sdp_split_qkv_tensor(query_states, repeated_key_states,
                                                          repeated_value_states,
                                                          attention_mask, cache_position,
@@ -1719,10 +1729,6 @@ def llama_attention_forward_4_38_quantized(
             attn_weights = nn.functional.softmax(attn_weights, dim=-1,
                                                  dtype=torch.float32).to(query_states.dtype)
             attn_output = torch.matmul(attn_weights, repeated_value_states)
-        if use_cache:
-            cache_kwargs = None
-            key_states, value_states = past_key_value.update(key_states, value_states,
-                                                             self.layer_idx, cache_kwargs)
     else:
         cache_kwargs = None  # Specific to RoPE models
         key_states, value_states = past_key_value.update(key_states, value_states,
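
For context, a condensed, illustrative sketch of the control flow the patched first-token branch now follows. It is not part of the commit: the helper names (repeat_kv, use_sdp_causal, should_split_qkv_tensor, past_key_value.update, xe_addons.sdp_fp8_causal) are taken from the diff above, while the wrapper function and its argument list are hypothetical and assume the module-level imports of llama.py.

    # Hypothetical condensed view of the new dispatch; the real logic is inlined
    # in llama_attention_forward_4_38_quantized (see the diff above).
    def first_token_attention_fp8(self, query_states, key_states, value_states,
                                  attention_mask, past_key_value, use_cache,
                                  q_len, kv_seq_len):
        repeated_key_states = repeat_kv(key_states, self.num_key_value_groups)
        repeated_value_states = repeat_kv(value_states, self.num_key_value_groups)

        # The cache update now runs *before* attention (it used to run after the
        # native SDP fallback), so key_states/value_states below are the tensors
        # returned by the FP8 KV cache.
        if use_cache:
            cache_kwargs = None
            key_states, value_states = past_key_value.update(key_states, value_states,
                                                             self.layer_idx, cache_kwargs)

        if use_cache and use_sdp_causal(q_len, kv_seq_len, self.head_dim,
                                        query_states, self.training):
            # New fast path: fused FP8 causal scaled-dot-product attention.
            import xe_addons
            return xe_addons.sdp_fp8_causal(query_states, key_states,
                                            value_states, attention_mask)

        # Otherwise the pre-existing fallbacks run unchanged on the repeated fp16
        # tensors: native_sdp_split_qkv_tensor when should_split_qkv_tensor(...)
        # is true, else the plain matmul + softmax path shown in the last hunk.
        ...

Moving the cache update ahead of the attention call is also why the last hunk deletes the old trailing "if use_cache:" block: leaving it in place would update the KV cache twice for the same tokens.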