LLM: Update qkv fusion for GGUF-IQ2 (#10271)

* first commit

* update mistral

* fix transformers==4.36.0

* fix

* disable qk for mixtral now

* fix style

Ruonan Wang, 2024-02-29 12:49:53 +08:00, committed by GitHub
parent 6fb65bb9d2
commit a9fd20b6ba
4 changed files with 54 additions and 10 deletions

Changed file: GGUF-IQ2 example script (generate example)

@@ -53,6 +53,7 @@ if __name__ == '__main__':
# https://huggingface.co/datasets/ikawrakow/imatrix-from-wiki-train/tree/main.
model = AutoModelForCausalLM.from_pretrained(model_path,
load_in_low_bit='gguf_iq2_xxs',
torch_dtype=torch.float16,
trust_remote_code=True,
imatrix='llama-v2-7b.imatrix').to("xpu")
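
For context, a minimal end-to-end sketch of how the loading call above is typically used. Only the from_pretrained arguments shown in the hunk come from this commit; the model path, tokenizer, prompt, and generation settings below are illustrative placeholders.

```python
# Hedged usage sketch: model_path, the prompt, and generate() settings are placeholders,
# not part of the commit; the low-bit loading arguments mirror the example hunk above.
import torch
from transformers import AutoTokenizer
from bigdl.llm.transformers import AutoModelForCausalLM

model_path = "meta-llama/Llama-2-7b-chat-hf"  # placeholder checkpoint

model = AutoModelForCausalLM.from_pretrained(model_path,
                                             load_in_low_bit='gguf_iq2_xxs',
                                             torch_dtype=torch.float16,
                                             trust_remote_code=True,
                                             imatrix='llama-v2-7b.imatrix').to("xpu")
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

with torch.inference_mode():
    input_ids = tokenizer.encode("What is AI?", return_tensors="pt").to("xpu")
    output = model.generate(input_ids, max_new_tokens=32)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```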

Changed file: bigdl/llm/transformers/models/llama.py

@@ -48,7 +48,7 @@ from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
from bigdl.llm.transformers.models.utils import use_flash_attention, use_esimd_sdp
from bigdl.llm.transformers.models.utils import mlp_fusion_check, fp16_fusion_check
from transformers.modeling_outputs import BaseModelOutputWithPast
- from bigdl.llm.transformers.low_bit_linear import SYM_INT4, FP8E5
+ from bigdl.llm.transformers.low_bit_linear import SYM_INT4, FP8E5, IQ2_XXS
from bigdl.llm.ggml.quantize import ggml_tensor_qtype
from bigdl.llm.utils.common import invalidInputError
@@ -292,7 +292,7 @@ def llama_attention_forward_4_31_quantized(
use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
enough_kv_room = is_enough_kv_cache_room_4_31(past_key_value, seq_len=q_len)
qtype = getattr(self.q_proj, "qtype", None)
- qtype_check = qtype in [SYM_INT4, FP8E5]
+ qtype_check = qtype in [SYM_INT4, FP8E5, IQ2_XXS]
no_tp = not self.config.pretraining_tp > 1
decoding_fast_path = (no_tp and qtype_check and use_fuse_rope
and enough_kv_room and bsz * q_len == 1)
@@ -320,6 +320,7 @@ def llama_attention_forward_4_31_quantized(
position_ids,
tmp_cache_k, tmp_cache_v,
self.q_proj.weight.qtype,
+ self.v_proj.weight.qtype,
0,
self.head_dim)
else:
@@ -484,7 +485,7 @@ def llama_attention_forward_4_31_original(
use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
enough_kv_room = is_enough_kv_cache_room_4_31(past_key_value, seq_len=q_len)
qtype = getattr(self.q_proj, "qtype", None)
- qtype_check = qtype in [SYM_INT4, FP8E5]
+ qtype_check = qtype in [SYM_INT4, FP8E5, IQ2_XXS]
no_tp = not self.config.pretraining_tp > 1
decoding_fast_path = (no_tp and qtype_check and use_fuse_rope and
enough_kv_room and bsz * q_len == 1)
@@ -507,6 +508,7 @@ def llama_attention_forward_4_31_original(
position_ids,
cache_k, cache_v,
self.q_proj.weight.qtype,
+ self.v_proj.weight.qtype,
kv_seq_len,
self.head_dim)
kv_seq_len += 1
@@ -719,9 +721,10 @@ def llama_attention_selective_batching_forward_4_31(
# TODO: decoding fast path
use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
enough_kv_room = past_key_value is not None and is_enough_kv_cache_room_4_31(past_key_value[0])
- is_q4_0 = self.q_proj.qtype == SYM_INT4
+ qtype = getattr(self.q_proj, "qtype", None)
+ qtype_check = qtype in [SYM_INT4, FP8E5, IQ2_XXS]
no_tp = not self.config.pretraining_tp > 1
- decoding_fast_path = (no_tp and is_q4_0 and use_fuse_rope and
+ decoding_fast_path = (no_tp and qtype_check and use_fuse_rope and
bsz * q_len == 1)
decoding_fast_path = decoding_fast_path and not self.q_proj.enable_xetla
@@ -756,6 +759,7 @@ def llama_attention_selective_batching_forward_4_31(
position_ids,
past_k, past_v,
self.q_proj.weight.qtype,
+ self.v_proj.weight.qtype,
kv_seq_len,
self.head_dim)
kv_seq_len += 1
@@ -912,9 +916,9 @@ def llama_attention_forward_4_36(
use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx, seq_len=q_len)
qtype = getattr(self.q_proj, "qtype", None)
- is_q4_0 = qtype == SYM_INT4
+ qtype_check = qtype in [SYM_INT4, FP8E5, IQ2_XXS]
no_tp = not self.config.pretraining_tp > 1
- decoding_fast_path = (no_tp and is_q4_0 and use_fuse_rope and
+ decoding_fast_path = (no_tp and qtype_check and use_fuse_rope and
enough_kv_room and bsz * q_len == 1)
decoding_fast_path = decoding_fast_path and not self.q_proj.enable_xetla
@@ -935,6 +939,7 @@ def llama_attention_forward_4_36(
position_ids,
cache_k, cache_v,
self.q_proj.weight.qtype,
+ self.v_proj.weight.qtype,
kv_seq_len,
self.head_dim)
kv_seq_len += 1
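
All four llama attention forwards above repeat the same gating logic; as a reading aid, here it is consolidated into one sketch. The helper name is ours, while the condition and attribute names are taken from the hunks (IQ2_XXS membership is the part this commit adds).

```python
# Sketch only, not bigdl-llm's own helper: the decoding fast path now also accepts IQ2_XXS.
from bigdl.llm.transformers.low_bit_linear import SYM_INT4, FP8E5, IQ2_XXS

def _decoding_fast_path_ok(self, use_fuse_rope, enough_kv_room, bsz, q_len):
    qtype = getattr(self.q_proj, "qtype", None)
    qtype_check = qtype in [SYM_INT4, FP8E5, IQ2_XXS]   # IQ2_XXS newly allowed
    no_tp = not self.config.pretraining_tp > 1          # no pretraining tensor parallelism
    fast = (no_tp and qtype_check and use_fuse_rope
            and enough_kv_room and bsz * q_len == 1)    # single-token decode only
    return fast and not self.q_proj.enable_xetla
```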

Changed file: bigdl/llm/transformers/models/mistral.py

@@ -49,7 +49,7 @@ from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb, \
apply_rotary_pos_emb_no_cache_xpu
from bigdl.llm.transformers.models.utils import is_enough_kv_cache_room_4_31, \
is_enough_kv_cache_room_4_36
- from bigdl.llm.transformers.low_bit_linear import SYM_INT4, FP8E5
+ from bigdl.llm.transformers.low_bit_linear import SYM_INT4, FP8E5, IQ2_XXS
from bigdl.llm.transformers.models.utils import use_flash_attention, use_esimd_sdp
KV_CACHE_ALLOC_BLOCK_LENGTH = 256
@@ -77,7 +77,7 @@ def should_use_fuse_rope(self, hidden_states, position_ids):
def use_decoding_fast_path(q_type, use_fuse_rope, enough_kv_room, bs):
- return q_type in [SYM_INT4, FP8E5] and \
+ return q_type in [SYM_INT4, FP8E5, IQ2_XXS] and \
use_fuse_rope and enough_kv_room and bs == 1
@@ -188,6 +188,7 @@ def mistral_attention_forward_quantized(
position_ids,
tmp_cache_k, tmp_cache_v,
self.q_proj.weight.qtype,
+ self.v_proj.weight.qtype,
0,
self.head_dim)
else:
@@ -360,6 +361,7 @@ def mistral_attention_forward_original(
position_ids,
cache_k, cache_v,
self.q_proj.weight.qtype,
+ self.v_proj.weight.qtype,
kv_seq_len,
self.head_dim)
kv_seq_len += 1
@@ -512,6 +514,7 @@ def mistral_attention_forward_4_36(
position_ids,
cache_k, cache_v,
self.q_proj.weight.qtype,
+ self.v_proj.weight.qtype,
kv_seq_len,
self.head_dim)
kv_seq_len += 1

Changed file: bigdl/llm/transformers/models/mixtral.py

@@ -51,6 +51,7 @@ from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb,\
from bigdl.llm.transformers.models.mistral import should_use_fuse_rope, use_decoding_fast_path
from bigdl.llm.transformers.models.utils import use_flash_attention, use_esimd_sdp
from bigdl.llm.transformers.models.utils import mlp_fusion_check
+ from bigdl.llm.transformers.low_bit_linear import IQ2_XXS
KV_CACHE_ALLOC_BLOCK_LENGTH = 256
@@ -155,7 +156,7 @@ def mixtral_attention_forward(
bsz * q_len)
decoding_fast_path = decoding_fast_path and not self.q_proj.enable_xetla
- if decoding_fast_path:
+ if decoding_fast_path and self.q_proj.qtype != IQ2_XXS:
hidden_states = hidden_states.view(1, -1)
cache_k = past_key_value.key_cache[self.layer_idx]
cache_v = past_key_value.value_cache[self.layer_idx]
@@ -168,6 +169,7 @@ def mixtral_attention_forward(
position_ids,
cache_k, cache_v,
self.q_proj.weight.qtype,
+ self.v_proj.weight.qtype,
kv_seq_len,
self.head_dim)
kv_seq_len += 1
@@ -176,7 +178,40 @@ def mixtral_attention_forward(
past_key_value.seen_tokens = kv_seq_len
past_key_value.key_cache[self.layer_idx] = key_states
past_key_value.value_cache[self.layer_idx] = value_states
+ # disable it for now as it will cause output changes for an unknown reason
+ # elif decoding_fast_path and self.q_proj.qtype == IQ2_XXS:
+ # # in this path self.v_proj uses q4_0
+ # hidden_states = hidden_states.view(1, -1)
+ # cache_k = past_key_value.key_cache[self.layer_idx]
+ # cache_v = past_key_value.value_cache[self.layer_idx]
+ # kv_seq_len = cache_k.shape[-2]
+ # import linear_q4_0
+ # query_states, key_states = linear_q4_0.forward_qk(hidden_states,
+ # self.q_proj.weight,
+ # self.k_proj.weight,
+ # position_ids,
+ # cache_k,
+ # self.q_proj.weight.qtype,
+ # kv_seq_len,
+ # self.head_dim,
+ # 10000)
+ # kv_seq_len += 1
+ # # update past_key_value's seen_tokens and kv caches.
+ # if self.layer_idx == 0:
+ # past_key_value.seen_tokens = kv_seq_len
+ # # update value_states
+ # value_states = self.v_proj(hidden_states)
+ # value_states = value_states.view(bsz, q_len,
+ # self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ # new_size = (cache_v.size(0),
+ # cache_v.size(1),
+ # cache_v.size(2) + value_states.size(2),
+ # cache_v.size(3))
+ # new_cache_v = cache_v.as_strided(new_size, cache_v.stride(), storage_offset=0)
+ # new_cache_v[:, :, cache_v.size(2):cache_v.size(2)+value_states.size(2), :] = value_states
+ # past_key_value.key_cache[self.layer_idx] = key_states
+ # past_key_value.value_cache[self.layer_idx] = new_cache_v
else:
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
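
The disabled IQ2_XXS branch above appends the new token's value states to the cache through an as_strided view instead of torch.cat. Below is a self-contained illustration of that in-place append; the shapes and the headroom constant are illustrative, and nothing here is taken verbatim from the commit.

```python
import torch

bsz, n_kv_heads, head_dim = 1, 8, 128
cur_len, headroom = 16, 256          # headroom plays the role of KV_CACHE_ALLOC_BLOCK_LENGTH

# Pre-allocate storage with spare room along the sequence dimension, then expose
# only the filled part as the current value cache.
storage = torch.zeros(bsz, n_kv_heads, cur_len + headroom, head_dim)
cache_v = storage.as_strided((bsz, n_kv_heads, cur_len, head_dim), storage.stride())

value_states = torch.randn(bsz, n_kv_heads, 1, head_dim)   # one newly decoded token

# Widen the view by one position and write the new token in place (no torch.cat copy).
new_size = (cache_v.size(0), cache_v.size(1),
            cache_v.size(2) + value_states.size(2), cache_v.size(3))
new_cache_v = cache_v.as_strided(new_size, cache_v.stride(), storage_offset=0)
new_cache_v[:, :, cache_v.size(2):cache_v.size(2) + value_states.size(2), :] = value_states

assert new_cache_v.shape == (bsz, n_kv_heads, cur_len + 1, head_dim)
assert new_cache_v.data_ptr() == storage.data_ptr()        # same storage, nothing copied
```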