add sdp fp8 for llama (#11671)
* add sdp fp8 for llama
* fix style
* refactor
parent 7f88ce23cd
commit c11d5301d7

1 changed file with 13 additions and 7 deletions
@@ -46,7 +46,8 @@ from ipex_llm.transformers.models.utils import init_fp8_kv_cache, append_fp8_kv_
 from ipex_llm.transformers.models.utils import is_enough_kv_cache_room_4_31, \
     apply_rotary_pos_emb, is_enough_kv_cache_room_4_36
 from ipex_llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
-from ipex_llm.transformers.models.utils import use_flash_attention, use_sdp, use_sdp_fp8
+from ipex_llm.transformers.models.utils import use_flash_attention, use_sdp, use_sdp_fp8, \
+    use_sdp_causal
 from ipex_llm.transformers.models.utils import mlp_fusion_check, fp16_fusion_check
 from ipex_llm.transformers.models.utils import use_decoding_fast_path
 from transformers.modeling_outputs import BaseModelOutputWithPast
@@ -1678,8 +1679,17 @@ def llama_attention_forward_4_38_quantized(
     if len(past_key_value.key_cache) <= self.layer_idx:
         repeated_key_states = repeat_kv(key_states, self.num_key_value_groups)
         repeated_value_states = repeat_kv(value_states, self.num_key_value_groups)
-        if should_split_qkv_tensor(query_states, bsz, self.num_heads,
-                                   q_len, kv_seq_len, output_attentions):
+        if use_cache:
+            cache_kwargs = None
+            key_states, value_states = past_key_value.update(key_states, value_states,
+                                                             self.layer_idx, cache_kwargs)
+        if use_cache and use_sdp_causal(q_len, kv_seq_len, self.head_dim,
+                                        query_states, self.training):
+            import xe_addons
+            attn_output = xe_addons.sdp_fp8_causal(query_states, key_states,
+                                                   value_states, attention_mask)
+        elif should_split_qkv_tensor(query_states, bsz, self.num_heads,
+                                     q_len, kv_seq_len, output_attentions):
             attn_output, _ = native_sdp_split_qkv_tensor(query_states, repeated_key_states,
                                                          repeated_value_states,
                                                          attention_mask, cache_position,
@@ -1719,10 +1729,6 @@ def llama_attention_forward_4_38_quantized(
                 attn_weights = nn.functional.softmax(attn_weights, dim=-1,
                                                      dtype=torch.float32).to(query_states.dtype)
             attn_output = torch.matmul(attn_weights, repeated_value_states)
-        if use_cache:
-            cache_kwargs = None
-            key_states, value_states = past_key_value.update(key_states, value_states,
-                                                             self.layer_idx, cache_kwargs)
     else:
         cache_kwargs = None  # Specific to RoPE models
         key_states, value_states = past_key_value.update(key_states, value_states,
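To make the new control flow easier to follow, here is a minimal sketch of the prefill dispatch order this diff gives llama_attention_forward_4_38_quantized, assuming the gating checks behave as their names suggest. use_sdp_causal, xe_addons.sdp_fp8_causal, and native_sdp_split_qkv_tensor are names taken from the diff itself; pick_prefill_sdp_path and its boolean arguments are hypothetical stand-ins used only for illustration, not ipex-llm API.

# Illustrative sketch only (not ipex-llm's actual code) of the first-token
# (prefill) dispatch order after this change. The boolean parameters stand in
# for the real use_sdp_causal() and should_split_qkv_tensor() checks, whose
# internals are not reproduced here; the returned strings just name the path.

def pick_prefill_sdp_path(use_cache: bool,
                          sdp_causal_applicable: bool,
                          should_split_qkv: bool) -> str:
    # The diff moves past_key_value.update() ahead of the attention compute,
    # so the fused kernel can read key/value states from the FP8 KV cache.
    if use_cache and sdp_causal_applicable:
        return "xe_addons.sdp_fp8_causal"      # new fused FP8 causal SDP path
    if should_split_qkv:
        return "native_sdp_split_qkv_tensor"   # existing split-QKV fallback
    return "native sdp (softmax + matmul)"     # plain fallback, unchanged


# e.g. a causal prefill with the KV cache enabled now takes the fused path
assert pick_prefill_sdp_path(True, True, False) == "xe_addons.sdp_fp8_causal"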