From a9fd20b6ba885f20b4cc170d65b6771df059623f Mon Sep 17 00:00:00 2001
From: Ruonan Wang
Date: Thu, 29 Feb 2024 12:49:53 +0800
Subject: [PATCH] LLM: Update qkv fusion for GGUF-IQ2 (#10271)

* first commit

* update mistral

* fix transformers==4.36.0

* fix

* disable qk for mixtral now

* fix style
---
 .../GGUF-IQ2/generate.py                      |  1 +
 .../bigdl/llm/transformers/models/llama.py    | 19 ++++++----
 .../bigdl/llm/transformers/models/mistral.py  |  7 +++-
 .../bigdl/llm/transformers/models/mixtral.py  | 37 ++++++++++++++++++-
 4 files changed, 54 insertions(+), 10 deletions(-)

diff --git a/python/llm/example/GPU/HF-Transformers-AutoModels/Advanced-Quantizations/GGUF-IQ2/generate.py b/python/llm/example/GPU/HF-Transformers-AutoModels/Advanced-Quantizations/GGUF-IQ2/generate.py
index c85a8473..f6ca2511 100644
--- a/python/llm/example/GPU/HF-Transformers-AutoModels/Advanced-Quantizations/GGUF-IQ2/generate.py
+++ b/python/llm/example/GPU/HF-Transformers-AutoModels/Advanced-Quantizations/GGUF-IQ2/generate.py
@@ -53,6 +53,7 @@ if __name__ == '__main__':
     #   https://huggingface.co/datasets/ikawrakow/imatrix-from-wiki-train/tree/main.
     model = AutoModelForCausalLM.from_pretrained(model_path,
                                                  load_in_low_bit='gguf_iq2_xxs',
+                                                 torch_dtype=torch.float16,
                                                  trust_remote_code=True,
                                                  imatrix='llama-v2-7b.imatrix').to("xpu")

diff --git a/python/llm/src/bigdl/llm/transformers/models/llama.py b/python/llm/src/bigdl/llm/transformers/models/llama.py
index 7dabab3a..b58c23eb 100644
--- a/python/llm/src/bigdl/llm/transformers/models/llama.py
+++ b/python/llm/src/bigdl/llm/transformers/models/llama.py
@@ -48,7 +48,7 @@ from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xp
 from bigdl.llm.transformers.models.utils import use_flash_attention, use_esimd_sdp
 from bigdl.llm.transformers.models.utils import mlp_fusion_check, fp16_fusion_check
 from transformers.modeling_outputs import BaseModelOutputWithPast
-from bigdl.llm.transformers.low_bit_linear import SYM_INT4, FP8E5
+from bigdl.llm.transformers.low_bit_linear import SYM_INT4, FP8E5, IQ2_XXS
 from bigdl.llm.ggml.quantize import ggml_tensor_qtype
 from bigdl.llm.utils.common import invalidInputError

@@ -292,7 +292,7 @@ def llama_attention_forward_4_31_quantized(
     use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
     enough_kv_room = is_enough_kv_cache_room_4_31(past_key_value, seq_len=q_len)
     qtype = getattr(self.q_proj, "qtype", None)
-    qtype_check = qtype in [SYM_INT4, FP8E5]
+    qtype_check = qtype in [SYM_INT4, FP8E5, IQ2_XXS]
     no_tp = not self.config.pretraining_tp > 1
     decoding_fast_path = (no_tp and qtype_check and use_fuse_rope and
                           enough_kv_room and bsz * q_len == 1)
@@ -320,6 +320,7 @@
                                                                  position_ids,
                                                                  tmp_cache_k, tmp_cache_v,
                                                                  self.q_proj.weight.qtype,
+                                                                 self.v_proj.weight.qtype,
                                                                  0,
                                                                  self.head_dim)
     else:
@@ -484,7 +485,7 @@ def llama_attention_forward_4_31_original(
     use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
     enough_kv_room = is_enough_kv_cache_room_4_31(past_key_value, seq_len=q_len)
     qtype = getattr(self.q_proj, "qtype", None)
-    qtype_check = qtype in [SYM_INT4, FP8E5]
+    qtype_check = qtype in [SYM_INT4, FP8E5, IQ2_XXS]
     no_tp = not self.config.pretraining_tp > 1
     decoding_fast_path = (no_tp and qtype_check and use_fuse_rope and
                           enough_kv_room and bsz * q_len == 1)
@@ -507,6 +508,7 @@
                                                        position_ids,
                                                        cache_k, cache_v,
                                                        self.q_proj.weight.qtype,
+                                                       self.v_proj.weight.qtype,
                                                        kv_seq_len,
                                                        self.head_dim)
         kv_seq_len += 1
@@ -719,9 +721,10 @@ def llama_attention_selective_batching_forward_4_31(
     # TODO: decoding fast path
     use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
     enough_kv_room = past_key_value is not None and is_enough_kv_cache_room_4_31(past_key_value[0])
-    is_q4_0 = self.q_proj.qtype == SYM_INT4
+    qtype = getattr(self.q_proj, "qtype", None)
+    qtype_check = qtype in [SYM_INT4, FP8E5, IQ2_XXS]
     no_tp = not self.config.pretraining_tp > 1
-    decoding_fast_path = (no_tp and is_q4_0 and use_fuse_rope and
+    decoding_fast_path = (no_tp and qtype_check and use_fuse_rope and
                           bsz * q_len == 1)
     decoding_fast_path = decoding_fast_path and not self.q_proj.enable_xetla

@@ -756,6 +759,7 @@ def llama_attention_selective_batching_forward_4_31(
                                                            position_ids,
                                                            past_k, past_v,
                                                            self.q_proj.weight.qtype,
+                                                           self.v_proj.weight.qtype,
                                                            kv_seq_len,
                                                            self.head_dim)
             kv_seq_len += 1
@@ -912,9 +916,9 @@ def llama_attention_forward_4_36(
     use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
     enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx, seq_len=q_len)
     qtype = getattr(self.q_proj, "qtype", None)
-    is_q4_0 = qtype == SYM_INT4
+    qtype_check = qtype in [SYM_INT4, FP8E5, IQ2_XXS]
     no_tp = not self.config.pretraining_tp > 1
-    decoding_fast_path = (no_tp and is_q4_0 and use_fuse_rope and
+    decoding_fast_path = (no_tp and qtype_check and use_fuse_rope and
                           enough_kv_room and bsz * q_len == 1)
     decoding_fast_path = decoding_fast_path and not self.q_proj.enable_xetla

@@ -935,6 +939,7 @@ def llama_attention_forward_4_36(
                                                        position_ids,
                                                        cache_k, cache_v,
                                                        self.q_proj.weight.qtype,
+                                                       self.v_proj.weight.qtype,
                                                        kv_seq_len,
                                                        self.head_dim)
         kv_seq_len += 1
diff --git a/python/llm/src/bigdl/llm/transformers/models/mistral.py b/python/llm/src/bigdl/llm/transformers/models/mistral.py
index 5c740c19..b79d053a 100644
--- a/python/llm/src/bigdl/llm/transformers/models/mistral.py
+++ b/python/llm/src/bigdl/llm/transformers/models/mistral.py
@@ -49,7 +49,7 @@ from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb, \
     apply_rotary_pos_emb_no_cache_xpu
 from bigdl.llm.transformers.models.utils import is_enough_kv_cache_room_4_31, \
     is_enough_kv_cache_room_4_36
-from bigdl.llm.transformers.low_bit_linear import SYM_INT4, FP8E5
+from bigdl.llm.transformers.low_bit_linear import SYM_INT4, FP8E5, IQ2_XXS
 from bigdl.llm.transformers.models.utils import use_flash_attention, use_esimd_sdp

 KV_CACHE_ALLOC_BLOCK_LENGTH = 256
@@ -77,7 +77,7 @@ def should_use_fuse_rope(self, hidden_states, position_ids):


 def use_decoding_fast_path(q_type, use_fuse_rope, enough_kv_room, bs):
-    return q_type in [SYM_INT4, FP8E5] and \
+    return q_type in [SYM_INT4, FP8E5, IQ2_XXS] and \
         use_fuse_rope and enough_kv_room and bs == 1


@@ -188,6 +188,7 @@ def mistral_attention_forward_quantized(
                                                                  position_ids,
                                                                  tmp_cache_k, tmp_cache_v,
                                                                  self.q_proj.weight.qtype,
+                                                                 self.v_proj.weight.qtype,
                                                                  0,
                                                                  self.head_dim)
     else:
@@ -360,6 +361,7 @@ def mistral_attention_forward_original(
                                                        position_ids,
                                                        cache_k, cache_v,
                                                        self.q_proj.weight.qtype,
+                                                       self.v_proj.weight.qtype,
                                                        kv_seq_len,
                                                        self.head_dim)
         kv_seq_len += 1
@@ -512,6 +514,7 @@ def mistral_attention_forward_4_36(
                                                        position_ids,
                                                        cache_k, cache_v,
                                                        self.q_proj.weight.qtype,
+                                                       self.v_proj.weight.qtype,
                                                        kv_seq_len,
                                                        self.head_dim)
         kv_seq_len += 1
diff --git a/python/llm/src/bigdl/llm/transformers/models/mixtral.py b/python/llm/src/bigdl/llm/transformers/models/mixtral.py
index c12887b5..53b8e114 100644
--- a/python/llm/src/bigdl/llm/transformers/models/mixtral.py
+++ b/python/llm/src/bigdl/llm/transformers/models/mixtral.py
@@ -51,6 +51,7 @@ from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb,\
 from bigdl.llm.transformers.models.mistral import should_use_fuse_rope, use_decoding_fast_path
 from bigdl.llm.transformers.models.utils import use_flash_attention, use_esimd_sdp
 from bigdl.llm.transformers.models.utils import mlp_fusion_check
+from bigdl.llm.transformers.low_bit_linear import IQ2_XXS

 KV_CACHE_ALLOC_BLOCK_LENGTH = 256

@@ -155,7 +156,7 @@ def mixtral_attention_forward(
                                                  bsz * q_len)
     decoding_fast_path = decoding_fast_path and not self.q_proj.enable_xetla

-    if decoding_fast_path:
+    if decoding_fast_path and self.q_proj.qtype != IQ2_XXS:
         hidden_states = hidden_states.view(1, -1)
         cache_k = past_key_value.key_cache[self.layer_idx]
         cache_v = past_key_value.value_cache[self.layer_idx]
@@ -168,6 +169,7 @@ def mixtral_attention_forward(
                                                        position_ids,
                                                        cache_k, cache_v,
                                                        self.q_proj.weight.qtype,
+                                                       self.v_proj.weight.qtype,
                                                        kv_seq_len,
                                                        self.head_dim)
         kv_seq_len += 1
@@ -176,7 +178,40 @@ def mixtral_attention_forward(
             past_key_value.seen_tokens = kv_seq_len
         past_key_value.key_cache[self.layer_idx] = key_states
         past_key_value.value_cache[self.layer_idx] = value_states
+    # disable it for now as it will cause output change for unknown reason
+    # elif decoding_fast_path and self.q_proj.qtype == IQ2_XXS:
+    #     # this path self.v_proj uses q4_0
+    #     hidden_states = hidden_states.view(1, -1)
+    #     cache_k = past_key_value.key_cache[self.layer_idx]
+    #     cache_v = past_key_value.value_cache[self.layer_idx]
+    #     kv_seq_len = cache_k.shape[-2]
+    #     import linear_q4_0
+    #     query_states, key_states = linear_q4_0.forward_qk(hidden_states,
+    #                                                       self.q_proj.weight,
+    #                                                       self.k_proj.weight,
+    #                                                       position_ids,
+    #                                                       cache_k,
+    #                                                       self.q_proj.weight.qtype,
+    #                                                       kv_seq_len,
+    #                                                       self.head_dim,
+    #                                                       10000)
+    #     kv_seq_len += 1
+    #     # update past_key_value's seen_tokens and kv caches.
+    #     if self.layer_idx == 0:
+    #         past_key_value.seen_tokens = kv_seq_len
+    #     # update value_states
+    #     value_states = self.v_proj(hidden_states)
+    #     value_states = value_states.view(bsz, q_len,
+    #                                      self.num_key_value_heads, self.head_dim).transpose(1, 2)
+    #     new_size = (cache_v.size(0),
+    #                 cache_v.size(1),
+    #                 cache_v.size(2) + value_states.size(2),
+    #                 cache_v.size(3))
+    #     new_cache_v = cache_v.as_strided(new_size, cache_v.stride(), storage_offset=0)
+    #     new_cache_v[:, :, cache_v.size(2):cache_v.size(2)+value_states.size(2), :] = value_states
+    #     past_key_value.key_cache[self.layer_idx] = key_states
+    #     past_key_value.value_cache[self.layer_idx] = new_cache_v
     else:
         query_states = self.q_proj(hidden_states)
         key_states = self.k_proj(hidden_states)
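
Note: the generate.py hunk above only adds torch_dtype=torch.float16 to the GGUF-IQ2 load call. For context, a minimal usage sketch of the updated example follows; the model path, tokenizer class, prompt, and generation arguments are illustrative placeholders rather than part of this patch, and the imatrix file still has to be downloaded separately as the example's own comment notes.

    import torch
    from transformers import AutoTokenizer                  # tokenizer class assumed for illustration
    from bigdl.llm.transformers import AutoModelForCausalLM

    model_path = 'meta-llama/Llama-2-7b-chat-hf'             # placeholder model path
    # Load with 2-bit GGUF-IQ2 weights; fp16 activations match the updated example.
    model = AutoModelForCausalLM.from_pretrained(model_path,
                                                 load_in_low_bit='gguf_iq2_xxs',
                                                 torch_dtype=torch.float16,
                                                 trust_remote_code=True,
                                                 imatrix='llama-v2-7b.imatrix').to("xpu")
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

    with torch.inference_mode():
        input_ids = tokenizer.encode("What is AI?", return_tensors="pt").to("xpu")
        output = model.generate(input_ids, max_new_tokens=32)
        print(tokenizer.decode(output[0], skip_special_tokens=True))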
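
Note: functionally, the llama.py and mistral.py changes (a) widen the decoding fast-path qtype gate from SYM_INT4 only (the old is_q4_0 flag) to SYM_INT4, FP8E5 and IQ2_XXS, and (b) pass self.v_proj.weight.qtype into the fused rope/kv-cache call alongside self.q_proj.weight.qtype, presumably so the kernel can handle a v_proj qtype that differs from q/k (the disabled mixtral block notes that self.v_proj uses q4_0 on that path). The sketch below restates the gate as it reads after this patch; the helper name can_use_decoding_fast_path is hypothetical and only for illustration. mixtral.py additionally keeps IQ2_XXS off its fast path for now (self.q_proj.qtype != IQ2_XXS) because the fused q/k variant changed outputs for an unknown reason.

    from bigdl.llm.transformers.low_bit_linear import SYM_INT4, FP8E5, IQ2_XXS

    def can_use_decoding_fast_path(q_proj, config, use_fuse_rope, enough_kv_room, bsz, q_len):
        # Hypothetical restatement of the llama.py gate after this patch:
        # single-token decode, fused rope available, enough pre-allocated KV room,
        # no pretraining tensor parallelism, xetla disabled, and a supported weight qtype.
        qtype = getattr(q_proj, "qtype", None)
        qtype_check = qtype in [SYM_INT4, FP8E5, IQ2_XXS]
        no_tp = not config.pretraining_tp > 1
        fast_path = (no_tp and qtype_check and use_fuse_rope and
                     enough_kv_room and bsz * q_len == 1)
        return fast_path and not q_proj.enable_xetla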