LLM: support qkv fusion for fp8e5 (#9878)
* update
* add mistral
* meet code review
parent cb32b985ec
commit 53531ae4ee
2 changed files with 6 additions and 5 deletions
@@ -46,7 +46,7 @@ from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
 from bigdl.llm.transformers.models.utils import use_flash_attention, use_esimd_sdp
 from bigdl.llm.transformers.models.utils import mlp_fusion_check
 from transformers.modeling_outputs import BaseModelOutputWithPast
-from bigdl.llm.transformers.low_bit_linear import SYM_INT4
+from bigdl.llm.transformers.low_bit_linear import SYM_INT4, FP8E5
 from bigdl.llm.ggml.quantize import ggml_tensor_qtype
 from bigdl.llm.utils.common import invalidInputError

@@ -144,9 +144,9 @@ def llama_attention_forward_4_31(
     use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
     enough_kv_room = is_enough_kv_cache_room_4_31(past_key_value, seq_len=q_len)
     qtype = getattr(self.q_proj, "qtype", None)
-    is_q4_0 = qtype == SYM_INT4
+    qtype_check = qtype in [SYM_INT4, FP8E5]
     no_tp = not self.config.pretraining_tp > 1
-    decoding_fast_path = (no_tp and is_q4_0 and use_fuse_rope and
+    decoding_fast_path = (no_tp and qtype_check and use_fuse_rope and
                           enough_kv_room and bsz * q_len == 1)

     # single batch decoding fast path
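Net effect of the hunks above: the single-batch decoding fast path in llama_attention_forward_4_31 is no longer gated on sym_int4 weights alone; fp8e5-quantized projections now qualify as well. Below is a minimal, self-contained sketch of that gate for illustration only; the constant values and the helper name can_use_decoding_fast_path are stand-ins, not BigDL's API.

# Sketch of the decoding-fast-path gate after this commit.
# SYM_INT4 / FP8E5 are placeholder values here; the real constants are
# imported from bigdl.llm.transformers.low_bit_linear (see the diff above).
SYM_INT4 = "sym_int4"
FP8E5 = "fp8_e5"


def can_use_decoding_fast_path(qtype, use_fuse_rope, enough_kv_room,
                               bsz, q_len, pretraining_tp=1):
    no_tp = not pretraining_tp > 1
    qtype_check = qtype in [SYM_INT4, FP8E5]   # fp8e5 now qualifies too
    return (no_tp and qtype_check and use_fuse_rope and
            enough_kv_room and bsz * q_len == 1)


# Token-by-token decoding (batch size 1, one new token) with fp8e5 weights:
print(can_use_decoding_fast_path(FP8E5, use_fuse_rope=True,
                                 enough_kv_room=True, bsz=1, q_len=1))  # True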
@@ -47,7 +47,7 @@ from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb,\
     apply_rotary_pos_emb_no_cache_xpu
 from bigdl.llm.transformers.models.utils import is_enough_kv_cache_room_4_31,\
     is_enough_kv_cache_room_4_36
-from bigdl.llm.transformers.low_bit_linear import SYM_INT4
+from bigdl.llm.transformers.low_bit_linear import SYM_INT4, FP8E5
 from bigdl.llm.transformers.models.utils import use_flash_attention

 KV_CACHE_ALLOC_BLOCK_LENGTH = 256
@@ -75,7 +75,8 @@ def should_use_fuse_rope(self, hidden_states, position_ids):


 def use_decoding_fast_path(q_type, use_fuse_rope, enough_kv_room, bs):
-    return q_type == SYM_INT4 and use_fuse_rope and enough_kv_room and bs == 1
+    return q_type in [SYM_INT4, FP8E5] and \
+        use_fuse_rope and enough_kv_room and bs == 1


 def compute_attn_outputs_weights(query_states, key_states, value_states, bsz, q_len, kv_seq_len,
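The second file applies the same relaxation inside its use_decoding_fast_path helper. A hedged sketch of how such a helper is consumed during decoding, assuming the qtype is read off the q_proj low-bit linear as in the LLaMA hunk; DummyProj and the sentinel constants are illustrative stand-ins, not BigDL objects.

# Self-contained illustration of the updated helper plus a call site.
# The sentinel constants and DummyProj are stand-ins for bigdl.llm's real
# low-bit linear layers and their qtype codes.
SYM_INT4 = object()
FP8E5 = object()


def use_decoding_fast_path(q_type, use_fuse_rope, enough_kv_room, bs):
    # Same predicate as the new return statement in the diff above.
    return q_type in [SYM_INT4, FP8E5] and \
        use_fuse_rope and enough_kv_room and bs == 1


class DummyProj:
    qtype = FP8E5  # e.g. a projection layer quantized to fp8e5


q_proj = DummyProj()
qtype = getattr(q_proj, "qtype", None)
print(use_decoding_fast_path(qtype, use_fuse_rope=True,
                             enough_kv_room=True, bs=1))  # True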