LLM: support qkv fusion for fp8e5 (#9878)

* update
* add mistral
* meet code review

parent cb32b985ec
commit 53531ae4ee

2 changed files with 6 additions and 5 deletions
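For context, below is a minimal, self-contained sketch of the gate this commit relaxes: the single-batch decoding fast path (fused qkv + rotary embedding) previously required SYM_INT4 weights and now also accepts FP8E5. The constant values are placeholders for illustration, not the real qtype ids exported by bigdl.llm.transformers.low_bit_linear.

# Minimal sketch, not the BigDL implementation: SYM_INT4 / FP8E5 are
# placeholder values standing in for the constants exported by
# bigdl.llm.transformers.low_bit_linear.
SYM_INT4 = 2
FP8E5 = 19


def use_decoding_fast_path(q_type, use_fuse_rope, enough_kv_room, bs):
    # The fused single-token decoding kernel only applies to the supported
    # low-bit formats (q4_0 and, after this commit, fp8e5), with fused rope,
    # enough room left in the KV cache, and exactly one token being decoded.
    return q_type in [SYM_INT4, FP8E5] and \
        use_fuse_rope and enough_kv_room and bs == 1


if __name__ == "__main__":
    print(use_decoding_fast_path(FP8E5, True, True, 1))   # True after this commit
    print(use_decoding_fast_path(FP8E5, True, True, 4))   # False: not single-token decoding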
				
			
@@ -46,7 +46,7 @@ from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xp
 from bigdl.llm.transformers.models.utils import use_flash_attention, use_esimd_sdp
 from bigdl.llm.transformers.models.utils import mlp_fusion_check
 from transformers.modeling_outputs import BaseModelOutputWithPast
-from bigdl.llm.transformers.low_bit_linear import SYM_INT4
+from bigdl.llm.transformers.low_bit_linear import SYM_INT4, FP8E5
 from bigdl.llm.ggml.quantize import ggml_tensor_qtype
 from bigdl.llm.utils.common import invalidInputError
@@ -144,9 +144,9 @@ def llama_attention_forward_4_31(
     use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
     enough_kv_room = is_enough_kv_cache_room_4_31(past_key_value, seq_len=q_len)
     qtype = getattr(self.q_proj, "qtype", None)
-    is_q4_0 = qtype == SYM_INT4
+    qtype_check = qtype in [SYM_INT4, FP8E5]
     no_tp = not self.config.pretraining_tp > 1
-    decoding_fast_path = (no_tp and is_q4_0 and use_fuse_rope and
+    decoding_fast_path = (no_tp and qtype_check and use_fuse_rope and
                           enough_kv_room and bsz * q_len == 1)

     # single batch decoding fast path
@@ -47,7 +47,7 @@ from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb,\
     apply_rotary_pos_emb_no_cache_xpu
 from bigdl.llm.transformers.models.utils import is_enough_kv_cache_room_4_31,\
     is_enough_kv_cache_room_4_36
-from bigdl.llm.transformers.low_bit_linear import SYM_INT4
+from bigdl.llm.transformers.low_bit_linear import SYM_INT4, FP8E5
 from bigdl.llm.transformers.models.utils import use_flash_attention

 KV_CACHE_ALLOC_BLOCK_LENGTH = 256
@@ -75,7 +75,8 @@ def should_use_fuse_rope(self, hidden_states, position_ids):


 def use_decoding_fast_path(q_type, use_fuse_rope, enough_kv_room, bs):
-    return q_type == SYM_INT4 and use_fuse_rope and enough_kv_room and bs == 1
+    return q_type in [SYM_INT4, FP8E5] and \
+        use_fuse_rope and enough_kv_room and bs == 1


 def compute_attn_outputs_weights(query_states, key_states, value_states, bsz, q_len, kv_seq_len,