LLM: support qkv fusion for fp8e5 (#9878)

* update

* add mistral

* meet code review

Ruonan Wang authored 2024-01-10 17:50:00 +08:00, committed by GitHub
commit 53531ae4ee
parent cb32b985ec
2 changed files with 6 additions and 5 deletions


@@ -46,7 +46,7 @@ from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
 from bigdl.llm.transformers.models.utils import use_flash_attention, use_esimd_sdp
 from bigdl.llm.transformers.models.utils import mlp_fusion_check
 from transformers.modeling_outputs import BaseModelOutputWithPast
-from bigdl.llm.transformers.low_bit_linear import SYM_INT4
+from bigdl.llm.transformers.low_bit_linear import SYM_INT4, FP8E5
 from bigdl.llm.ggml.quantize import ggml_tensor_qtype
 from bigdl.llm.utils.common import invalidInputError
@@ -144,9 +144,9 @@ def llama_attention_forward_4_31(
     use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
     enough_kv_room = is_enough_kv_cache_room_4_31(past_key_value, seq_len=q_len)
     qtype = getattr(self.q_proj, "qtype", None)
-    is_q4_0 = qtype == SYM_INT4
+    qtype_check = qtype in [SYM_INT4, FP8E5]
     no_tp = not self.config.pretraining_tp > 1
-    decoding_fast_path = (no_tp and is_q4_0 and use_fuse_rope and
+    decoding_fast_path = (no_tp and qtype_check and use_fuse_rope and
                           enough_kv_room and bsz * q_len == 1)
     # single batch decoding fast path
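
For reference, the widened gate can be sketched standalone as below. SYM_INT4 and FP8E5 here are placeholder stand-ins for the constants imported from bigdl.llm.transformers.low_bit_linear (the numeric values are assumptions, not BigDL's real enum values), and the helper name is hypothetical; it only mirrors the condition built inside llama_attention_forward_4_31.

# Placeholder qtype constants; the real values come from
# bigdl.llm.transformers.low_bit_linear.
SYM_INT4 = 2
FP8E5 = 19

def decoding_fast_path_enabled(qtype, use_fuse_rope, enough_kv_room,
                               bsz, q_len, pretraining_tp=1):
    # Mirrors the condition above: qtype may now be SYM_INT4 or FP8E5.
    qtype_check = qtype in [SYM_INT4, FP8E5]
    no_tp = not pretraining_tp > 1
    return (no_tp and qtype_check and use_fuse_rope and
            enough_kv_room and bsz * q_len == 1)

# FP8E5 weights now take the single-batch decoding fast path too.
assert decoding_fast_path_enabled(FP8E5, True, True, bsz=1, q_len=1)
assert not decoding_fast_path_enabled(FP8E5, True, True, bsz=1, q_len=8)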


@@ -47,7 +47,7 @@ from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb,\
     apply_rotary_pos_emb_no_cache_xpu
 from bigdl.llm.transformers.models.utils import is_enough_kv_cache_room_4_31,\
     is_enough_kv_cache_room_4_36
-from bigdl.llm.transformers.low_bit_linear import SYM_INT4
+from bigdl.llm.transformers.low_bit_linear import SYM_INT4, FP8E5
 from bigdl.llm.transformers.models.utils import use_flash_attention
 
 KV_CACHE_ALLOC_BLOCK_LENGTH = 256
@@ -75,7 +75,8 @@ def should_use_fuse_rope(self, hidden_states, position_ids):
 
 
 def use_decoding_fast_path(q_type, use_fuse_rope, enough_kv_room, bs):
-    return q_type == SYM_INT4 and use_fuse_rope and enough_kv_room and bs == 1
+    return q_type in [SYM_INT4, FP8E5] and \
+        use_fuse_rope and enough_kv_room and bs == 1
 
 
 def compute_attn_outputs_weights(query_states, key_states, value_states, bsz, q_len, kv_seq_len,
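
The helper above is the shared form of the same gate; a minimal standalone check of its new behaviour is sketched below. As before, the SYM_INT4 and FP8E5 values are placeholders for the low_bit_linear constants, and the surrounding assertions are illustrative only, not part of the commit.

# Placeholder qtype constants standing in for the imports from
# bigdl.llm.transformers.low_bit_linear.
SYM_INT4 = 2
FP8E5 = 19

def use_decoding_fast_path(q_type, use_fuse_rope, enough_kv_room, bs):
    # FP8E5 is now accepted alongside SYM_INT4.
    return q_type in [SYM_INT4, FP8E5] and \
        use_fuse_rope and enough_kv_room and bs == 1

# Both low-bit formats enable the fast path for batch-size-1 decoding;
# any other qtype (here 0) still falls back to the regular path.
assert use_decoding_fast_path(SYM_INT4, True, True, bs=1)
assert use_decoding_fast_path(FP8E5, True, True, bs=1)
assert not use_decoding_fast_path(0, True, True, bs=1)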