LLM: Update qkv fusion for GGUF-IQ2 (#10271)
* first commit
* update mistral
* fix transformers==4.36.0
* fix
* disable qk for mixtral now
* fix style
parent 6fb65bb9d2
commit a9fd20b6ba

4 changed files with 54 additions and 10 deletions
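In short: the single-token decoding fast path, previously gated on the SYM_INT4 and FP8E5 quantization types, now also accepts GGUF-IQ2 weights (the IQ2_XXS qtype), and the fused qkv call sites pass self.v_proj.weight.qtype alongside self.q_proj.weight.qtype, since under IQ2 the value projection may use a different format (the mixtral comments below note it uses q4_0). For mixtral, the dedicated IQ2 qk fusion is disabled for now. A condensed sketch of the gate, in the shape of the use_decoding_fast_path helper updated in the mistral hunk further down:

from bigdl.llm.transformers.low_bit_linear import SYM_INT4, FP8E5, IQ2_XXS

def use_decoding_fast_path(q_type, use_fuse_rope, enough_kv_room, bs):
    # Fused qkv decoding applies only to batch size 1, with fused rope and
    # enough pre-allocated kv cache room; IQ2_XXS is the newly accepted qtype.
    return q_type in [SYM_INT4, FP8E5, IQ2_XXS] and \
        use_fuse_rope and enough_kv_room and bs == 1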
@@ -53,6 +53,7 @@ if __name__ == '__main__':
     # https://huggingface.co/datasets/ikawrakow/imatrix-from-wiki-train/tree/main.
     model = AutoModelForCausalLM.from_pretrained(model_path,
                                                  load_in_low_bit='gguf_iq2_xxs',
+                                                 torch_dtype=torch.float16,
                                                  trust_remote_code=True,
                                                  imatrix='llama-v2-7b.imatrix').to("xpu")
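The hunk above is the from_pretrained call in the GGUF-IQ2 GPU example. For context, a minimal end-to-end sketch of how that call is typically wrapped is shown below; the checkpoint path, prompt, and generation arguments are placeholders rather than part of this commit:

import torch
from transformers import AutoTokenizer
from bigdl.llm.transformers import AutoModelForCausalLM

model_path = 'meta-llama/Llama-2-7b-chat-hf'  # placeholder checkpoint

# Load with GGUF-IQ2 (IQ2_XXS) low-bit weights; the imatrix file comes from the
# dataset linked in the comment above.
model = AutoModelForCausalLM.from_pretrained(model_path,
                                             load_in_low_bit='gguf_iq2_xxs',
                                             torch_dtype=torch.float16,
                                             trust_remote_code=True,
                                             imatrix='llama-v2-7b.imatrix').to("xpu")
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

input_ids = tokenizer("What is AI?", return_tensors="pt").input_ids.to("xpu")
with torch.inference_mode():
    output = model.generate(input_ids, max_new_tokens=32)
print(tokenizer.decode(output[0], skip_special_tokens=True))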
@@ -48,7 +48,7 @@ from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
 from bigdl.llm.transformers.models.utils import use_flash_attention, use_esimd_sdp
 from bigdl.llm.transformers.models.utils import mlp_fusion_check, fp16_fusion_check
 from transformers.modeling_outputs import BaseModelOutputWithPast
-from bigdl.llm.transformers.low_bit_linear import SYM_INT4, FP8E5
+from bigdl.llm.transformers.low_bit_linear import SYM_INT4, FP8E5, IQ2_XXS
 from bigdl.llm.ggml.quantize import ggml_tensor_qtype
 from bigdl.llm.utils.common import invalidInputError
@@ -292,7 +292,7 @@ def llama_attention_forward_4_31_quantized(
     use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
     enough_kv_room = is_enough_kv_cache_room_4_31(past_key_value, seq_len=q_len)
     qtype = getattr(self.q_proj, "qtype", None)
-    qtype_check = qtype in [SYM_INT4, FP8E5]
+    qtype_check = qtype in [SYM_INT4, FP8E5, IQ2_XXS]
     no_tp = not self.config.pretraining_tp > 1
     decoding_fast_path = (no_tp and qtype_check and use_fuse_rope
                           and enough_kv_room and bsz * q_len == 1)
@@ -320,6 +320,7 @@ def llama_attention_forward_4_31_quantized(
                 position_ids,
                 tmp_cache_k, tmp_cache_v,
                 self.q_proj.weight.qtype,
+                self.v_proj.weight.qtype,
                 0,
                 self.head_dim)
     else:
@@ -484,7 +485,7 @@ def llama_attention_forward_4_31_original(
     use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
     enough_kv_room = is_enough_kv_cache_room_4_31(past_key_value, seq_len=q_len)
     qtype = getattr(self.q_proj, "qtype", None)
-    qtype_check = qtype in [SYM_INT4, FP8E5]
+    qtype_check = qtype in [SYM_INT4, FP8E5, IQ2_XXS]
     no_tp = not self.config.pretraining_tp > 1
     decoding_fast_path = (no_tp and qtype_check and use_fuse_rope and
                           enough_kv_room and bsz * q_len == 1)
@@ -507,6 +508,7 @@ def llama_attention_forward_4_31_original(
                 position_ids,
                 cache_k, cache_v,
                 self.q_proj.weight.qtype,
+                self.v_proj.weight.qtype,
                 kv_seq_len,
                 self.head_dim)
         kv_seq_len += 1
@@ -719,9 +721,10 @@ def llama_attention_selective_batching_forward_4_31(
     # TODO: decoding fast path
     use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
     enough_kv_room = past_key_value is not None and is_enough_kv_cache_room_4_31(past_key_value[0])
-    is_q4_0 = self.q_proj.qtype == SYM_INT4
+    qtype = getattr(self.q_proj, "qtype", None)
+    qtype_check = qtype in [SYM_INT4, FP8E5, IQ2_XXS]
     no_tp = not self.config.pretraining_tp > 1
-    decoding_fast_path = (no_tp and is_q4_0 and use_fuse_rope and
+    decoding_fast_path = (no_tp and qtype_check and use_fuse_rope and
                           bsz * q_len == 1)
     decoding_fast_path = decoding_fast_path and not self.q_proj.enable_xetla
@@ -756,6 +759,7 @@ def llama_attention_selective_batching_forward_4_31(
                 position_ids,
                 past_k, past_v,
                 self.q_proj.weight.qtype,
+                self.v_proj.weight.qtype,
                 kv_seq_len,
                 self.head_dim)
         kv_seq_len += 1
@@ -912,9 +916,9 @@ def llama_attention_forward_4_36(
     use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
     enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx, seq_len=q_len)
     qtype = getattr(self.q_proj, "qtype", None)
-    is_q4_0 = qtype == SYM_INT4
+    qtype_check = qtype in [SYM_INT4, FP8E5, IQ2_XXS]
     no_tp = not self.config.pretraining_tp > 1
-    decoding_fast_path = (no_tp and is_q4_0 and use_fuse_rope and
+    decoding_fast_path = (no_tp and qtype_check and use_fuse_rope and
                           enough_kv_room and bsz * q_len == 1)
     decoding_fast_path = decoding_fast_path and not self.q_proj.enable_xetla
@@ -935,6 +939,7 @@ def llama_attention_forward_4_36(
                 position_ids,
                 cache_k, cache_v,
                 self.q_proj.weight.qtype,
+                self.v_proj.weight.qtype,
                 kv_seq_len,
                 self.head_dim)
         kv_seq_len += 1
@@ -49,7 +49,7 @@ from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb, \
     apply_rotary_pos_emb_no_cache_xpu
 from bigdl.llm.transformers.models.utils import is_enough_kv_cache_room_4_31, \
     is_enough_kv_cache_room_4_36
-from bigdl.llm.transformers.low_bit_linear import SYM_INT4, FP8E5
+from bigdl.llm.transformers.low_bit_linear import SYM_INT4, FP8E5, IQ2_XXS
 from bigdl.llm.transformers.models.utils import use_flash_attention, use_esimd_sdp

 KV_CACHE_ALLOC_BLOCK_LENGTH = 256
@@ -77,7 +77,7 @@ def should_use_fuse_rope(self, hidden_states, position_ids):


 def use_decoding_fast_path(q_type, use_fuse_rope, enough_kv_room, bs):
-    return q_type in [SYM_INT4, FP8E5] and \
+    return q_type in [SYM_INT4, FP8E5, IQ2_XXS] and \
         use_fuse_rope and enough_kv_room and bs == 1
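use_decoding_fast_path is shared with mixtral (it is imported in the mixtral hunk below), and the bsz * q_len) fragment in that hunk is the tail of its call site. Roughly, the gate is computed as in the sketch below; the surrounding lines are not part of this diff, so the exact arguments (e.g. which kv-cache-room check is used) are assumptions:

from bigdl.llm.transformers.models.mistral import should_use_fuse_rope, use_decoding_fast_path
from bigdl.llm.transformers.models.utils import is_enough_kv_cache_room_4_36

def compute_decoding_fast_path(self, hidden_states, position_ids, past_key_value, bsz, q_len):
    # Hypothetical extraction of the gate used at the start of the attention forward.
    use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
    enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx, seq_len=q_len)
    fast_path = use_decoding_fast_path(self.q_proj.qtype,
                                       use_fuse_rope,
                                       enough_kv_room,
                                       bsz * q_len)
    return fast_path and not self.q_proj.enable_xetla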
@@ -188,6 +188,7 @@ def mistral_attention_forward_quantized(
                 position_ids,
                 tmp_cache_k, tmp_cache_v,
                 self.q_proj.weight.qtype,
+                self.v_proj.weight.qtype,
                 0,
                 self.head_dim)
     else:
@@ -360,6 +361,7 @@ def mistral_attention_forward_original(
                 position_ids,
                 cache_k, cache_v,
                 self.q_proj.weight.qtype,
+                self.v_proj.weight.qtype,
                 kv_seq_len,
                 self.head_dim)
         kv_seq_len += 1
@@ -512,6 +514,7 @@ def mistral_attention_forward_4_36(
                 position_ids,
                 cache_k, cache_v,
                 self.q_proj.weight.qtype,
+                self.v_proj.weight.qtype,
                 kv_seq_len,
                 self.head_dim)
         kv_seq_len += 1
@@ -51,6 +51,7 @@ from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb,\
 from bigdl.llm.transformers.models.mistral import should_use_fuse_rope, use_decoding_fast_path
 from bigdl.llm.transformers.models.utils import use_flash_attention, use_esimd_sdp
 from bigdl.llm.transformers.models.utils import mlp_fusion_check
+from bigdl.llm.transformers.low_bit_linear import IQ2_XXS


 KV_CACHE_ALLOC_BLOCK_LENGTH = 256
@@ -155,7 +156,7 @@ def mixtral_attention_forward(
                                                 bsz * q_len)
     decoding_fast_path = decoding_fast_path and not self.q_proj.enable_xetla

-    if decoding_fast_path:
+    if decoding_fast_path and self.q_proj.qtype != IQ2_XXS:
         hidden_states = hidden_states.view(1, -1)
         cache_k = past_key_value.key_cache[self.layer_idx]
         cache_v = past_key_value.value_cache[self.layer_idx]
@@ -168,6 +169,7 @@ def mixtral_attention_forward(
                 position_ids,
                 cache_k, cache_v,
                 self.q_proj.weight.qtype,
+                self.v_proj.weight.qtype,
                 kv_seq_len,
                 self.head_dim)
         kv_seq_len += 1
@@ -176,7 +178,40 @@ def mixtral_attention_forward(
             past_key_value.seen_tokens = kv_seq_len
         past_key_value.key_cache[self.layer_idx] = key_states
         past_key_value.value_cache[self.layer_idx] = value_states
+    # disable it for now as it will cause output change for unknown reason
+    # elif decoding_fast_path and self.q_proj.qtype == IQ2_XXS:
+    #     # this path self.v_proj uses q4_0
+    #     hidden_states = hidden_states.view(1, -1)
+    #     cache_k = past_key_value.key_cache[self.layer_idx]
+    #     cache_v = past_key_value.value_cache[self.layer_idx]
+    #     kv_seq_len = cache_k.shape[-2]
+    #     import linear_q4_0
+    #     query_states, key_states = linear_q4_0.forward_qk(hidden_states,
+    #                                                       self.q_proj.weight,
+    #                                                       self.k_proj.weight,
+    #                                                       position_ids,
+    #                                                       cache_k,
+    #                                                       self.q_proj.weight.qtype,
+    #                                                       kv_seq_len,
+    #                                                       self.head_dim,
+    #                                                       10000)
+    #     kv_seq_len += 1
+    #     # update past_key_value's seen_tokens and kv caches.
+    #     if self.layer_idx == 0:
+    #         past_key_value.seen_tokens = kv_seq_len
+    #     # update value_states
+    #     value_states = self.v_proj(hidden_states)
+    #     value_states = value_states.view(bsz, q_len,
+    #                                      self.num_key_value_heads, self.head_dim).transpose(1, 2)
+    #     new_size = (cache_v.size(0),
+    #                 cache_v.size(1),
+    #                 cache_v.size(2) + value_states.size(2),
+    #                 cache_v.size(3))
+    #     new_cache_v = cache_v.as_strided(new_size, cache_v.stride(), storage_offset=0)
+    #     new_cache_v[:, :, cache_v.size(2):cache_v.size(2)+value_states.size(2), :] = value_states

+    #     past_key_value.key_cache[self.layer_idx] = key_states
+    #     past_key_value.value_cache[self.layer_idx] = new_cache_v
     else:
         query_states = self.q_proj(hidden_states)
         key_states = self.k_proj(hidden_states)