diff --git a/python/llm/src/bigdl/llm/transformers/models/llama.py b/python/llm/src/bigdl/llm/transformers/models/llama.py
index 181caacf..836bc2bb 100644
--- a/python/llm/src/bigdl/llm/transformers/models/llama.py
+++ b/python/llm/src/bigdl/llm/transformers/models/llama.py
@@ -46,7 +46,7 @@ from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xp
 from bigdl.llm.transformers.models.utils import use_flash_attention, use_esimd_sdp
 from bigdl.llm.transformers.models.utils import mlp_fusion_check
 from transformers.modeling_outputs import BaseModelOutputWithPast
-from bigdl.llm.transformers.low_bit_linear import SYM_INT4
+from bigdl.llm.transformers.low_bit_linear import SYM_INT4, FP8E5
 from bigdl.llm.ggml.quantize import ggml_tensor_qtype
 from bigdl.llm.utils.common import invalidInputError
 
@@ -144,9 +144,9 @@ def llama_attention_forward_4_31(
     use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
     enough_kv_room = is_enough_kv_cache_room_4_31(past_key_value, seq_len=q_len)
     qtype = getattr(self.q_proj, "qtype", None)
-    is_q4_0 = qtype == SYM_INT4
+    qtype_check = qtype in [SYM_INT4, FP8E5]
     no_tp = not self.config.pretraining_tp > 1
-    decoding_fast_path = (no_tp and is_q4_0 and use_fuse_rope and
+    decoding_fast_path = (no_tp and qtype_check and use_fuse_rope and
                           enough_kv_room and bsz * q_len == 1)
 
     # single batch decoding fast path
diff --git a/python/llm/src/bigdl/llm/transformers/models/mistral.py b/python/llm/src/bigdl/llm/transformers/models/mistral.py
index e769e51c..3c653c2a 100644
--- a/python/llm/src/bigdl/llm/transformers/models/mistral.py
+++ b/python/llm/src/bigdl/llm/transformers/models/mistral.py
@@ -47,7 +47,7 @@ from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb,\
     apply_rotary_pos_emb_no_cache_xpu
 from bigdl.llm.transformers.models.utils import is_enough_kv_cache_room_4_31,\
     is_enough_kv_cache_room_4_36
-from bigdl.llm.transformers.low_bit_linear import SYM_INT4
+from bigdl.llm.transformers.low_bit_linear import SYM_INT4, FP8E5
 from bigdl.llm.transformers.models.utils import use_flash_attention
 
 KV_CACHE_ALLOC_BLOCK_LENGTH = 256
@@ -75,7 +75,8 @@ def should_use_fuse_rope(self, hidden_states, position_ids):
 
 
 def use_decoding_fast_path(q_type, use_fuse_rope, enough_kv_room, bs):
-    return q_type == SYM_INT4 and use_fuse_rope and enough_kv_room and bs == 1
+    return q_type in [SYM_INT4, FP8E5] and \
+        use_fuse_rope and enough_kv_room and bs == 1
 
 
 def compute_attn_outputs_weights(query_states, key_states, value_states, bsz, q_len, kv_seq_len,
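
Below is a minimal, self-contained sketch of the gating logic this patch changes: the single-batch decoding fast path is now taken when the projection's quantization type is either SYM_INT4 or FP8E5, on top of the existing fused-rope, KV-cache-room, and batch-size-one checks. This is an illustration only, not bigdl-llm code; the constant values and the assert-based usage below are placeholders.

# Illustrative sketch of the updated fast-path gate; the constant values are
# placeholders, the real ones come from bigdl.llm.transformers.low_bit_linear.
SYM_INT4 = "sym_int4"
FP8E5 = "fp8_e5m2"

def use_decoding_fast_path(q_type, use_fuse_rope, enough_kv_room, bs):
    # Take the fast path only for supported low-bit weights (INT4, or FP8-E5
    # after this change), with fused rotary embeddings, free KV-cache space,
    # and a single sequence decoding one token at a time.
    return (q_type in (SYM_INT4, FP8E5)
            and use_fuse_rope
            and enough_kv_room
            and bs == 1)

# An FP8E5-quantized projection now qualifies; before this patch only SYM_INT4 did.
assert use_decoding_fast_path(FP8E5, True, True, 1)
assert not use_decoding_fast_path(FP8E5, False, True, 1)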