LLM: Update qkv fusion for GGUF-IQ2 (#10271)
* first commit
* update mistral
* fix transformers==4.36.0
* fix
* disable qk for mixtral now
* fix style
This commit is contained in:
parent 6fb65bb9d2
commit a9fd20b6ba
4 changed files with 54 additions and 10 deletions
@@ -53,6 +53,7 @@ if __name__ == '__main__':
     # https://huggingface.co/datasets/ikawrakow/imatrix-from-wiki-train/tree/main.
     model = AutoModelForCausalLM.from_pretrained(model_path,
                                                  load_in_low_bit='gguf_iq2_xxs',
+                                                 torch_dtype=torch.float16,
                                                  trust_remote_code=True,
                                                  imatrix='llama-v2-7b.imatrix').to("xpu")
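For context, a minimal end-to-end sketch of how the example touched by the hunk above is meant to be run after this change, assuming the bigdl.llm transformers-style AutoModelForCausalLM wrapper, an Intel XPU device, and a local llama-v2-7b.imatrix file from the dataset linked above; the model path and prompt are placeholders:

import torch
from transformers import AutoTokenizer
from bigdl.llm.transformers import AutoModelForCausalLM

model_path = "meta-llama/Llama-2-7b-chat-hf"  # placeholder; any Llama-2-7B checkpoint path

# Load with GGUF-IQ2 (IQ2_XXS) quantization; torch_dtype=torch.float16 is the
# argument this hunk adds, and the imatrix file guides the 2-bit quantization.
model = AutoModelForCausalLM.from_pretrained(model_path,
                                             load_in_low_bit='gguf_iq2_xxs',
                                             torch_dtype=torch.float16,
                                             trust_remote_code=True,
                                             imatrix='llama-v2-7b.imatrix').to("xpu")

tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
with torch.inference_mode():
    input_ids = tokenizer.encode("What is AI?", return_tensors="pt").to("xpu")
    output = model.generate(input_ids, max_new_tokens=32)
    print(tokenizer.decode(output[0], skip_special_tokens=True))
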
@@ -48,7 +48,7 @@ from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
 from bigdl.llm.transformers.models.utils import use_flash_attention, use_esimd_sdp
 from bigdl.llm.transformers.models.utils import mlp_fusion_check, fp16_fusion_check
 from transformers.modeling_outputs import BaseModelOutputWithPast
-from bigdl.llm.transformers.low_bit_linear import SYM_INT4, FP8E5
+from bigdl.llm.transformers.low_bit_linear import SYM_INT4, FP8E5, IQ2_XXS
 from bigdl.llm.ggml.quantize import ggml_tensor_qtype
 from bigdl.llm.utils.common import invalidInputError
@@ -292,7 +292,7 @@ def llama_attention_forward_4_31_quantized(
     use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
     enough_kv_room = is_enough_kv_cache_room_4_31(past_key_value, seq_len=q_len)
     qtype = getattr(self.q_proj, "qtype", None)
-    qtype_check = qtype in [SYM_INT4, FP8E5]
+    qtype_check = qtype in [SYM_INT4, FP8E5, IQ2_XXS]
     no_tp = not self.config.pretraining_tp > 1
     decoding_fast_path = (no_tp and qtype_check and use_fuse_rope
                           and enough_kv_room and bsz * q_len == 1)
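In effect, the single-token decoding fast path is now taken for GGUF-IQ2 (IQ2_XXS) weights as well as SYM_INT4 and FP8E5. A standalone restatement of the gate, for illustration only (this helper is not part of the library; the constants come from bigdl.llm.transformers.low_bit_linear, as the import hunk above shows):

from bigdl.llm.transformers.low_bit_linear import SYM_INT4, FP8E5, IQ2_XXS

def decoding_fast_path_enabled(q_proj, config, use_fuse_rope, enough_kv_room, bsz, q_len):
    # Mirrors the boolean logic in the hunk above: a single new token (bsz * q_len == 1),
    # no tensor parallelism, fused RoPE available, room left in the KV cache, and a
    # supported low-bit qtype -- which now includes IQ2_XXS.
    qtype = getattr(q_proj, "qtype", None)
    qtype_check = qtype in [SYM_INT4, FP8E5, IQ2_XXS]
    no_tp = not config.pretraining_tp > 1
    return no_tp and qtype_check and use_fuse_rope and enough_kv_room and bsz * q_len == 1
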
@@ -320,6 +320,7 @@ def llama_attention_forward_4_31_quantized(
                 position_ids,
                 tmp_cache_k, tmp_cache_v,
                 self.q_proj.weight.qtype,
+                self.v_proj.weight.qtype,
                 0,
                 self.head_dim)
     else:
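The other recurring change in this commit is that the fused qkv kernel call now receives self.v_proj.weight.qtype alongside self.q_proj.weight.qtype. With gguf_iq2_xxs, q/k and v are not guaranteed to share a quantization format (the commented-out Mixtral path at the end of this diff notes that v_proj uses q4_0), so the kernel can no longer assume a single qtype. A self-contained toy illustrating the shape of the call-site change; the constants and helper below are stand-ins, not bigdl-llm's real values or API:

from dataclasses import dataclass

SYM_INT4, IQ2_XXS = 2, 23  # placeholder ids for illustration only


@dataclass
class Proj:
    qtype: int


def fused_qkv_args(q_proj: Proj, v_proj: Proj, kv_seq_len: int, head_dim: int):
    # Before this commit only q_proj's qtype was forwarded; passing v_proj's qtype too
    # lets a fused kernel dequantize q/k and v with different low-bit formats.
    return (q_proj.qtype, v_proj.qtype, kv_seq_len, head_dim)


print(fused_qkv_args(Proj(qtype=IQ2_XXS), Proj(qtype=SYM_INT4), kv_seq_len=17, head_dim=128))
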
@@ -484,7 +485,7 @@ def llama_attention_forward_4_31_original(
     use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
     enough_kv_room = is_enough_kv_cache_room_4_31(past_key_value, seq_len=q_len)
     qtype = getattr(self.q_proj, "qtype", None)
-    qtype_check = qtype in [SYM_INT4, FP8E5]
+    qtype_check = qtype in [SYM_INT4, FP8E5, IQ2_XXS]
     no_tp = not self.config.pretraining_tp > 1
     decoding_fast_path = (no_tp and qtype_check and use_fuse_rope and
                           enough_kv_room and bsz * q_len == 1)
@@ -507,6 +508,7 @@ def llama_attention_forward_4_31_original(
                 position_ids,
                 cache_k, cache_v,
                 self.q_proj.weight.qtype,
+                self.v_proj.weight.qtype,
                 kv_seq_len,
                 self.head_dim)
         kv_seq_len += 1
@@ -719,9 +721,10 @@ def llama_attention_selective_batching_forward_4_31(
     # TODO: decoding fast path
     use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
     enough_kv_room = past_key_value is not None and is_enough_kv_cache_room_4_31(past_key_value[0])
-    is_q4_0 = self.q_proj.qtype == SYM_INT4
+    qtype = getattr(self.q_proj, "qtype", None)
+    qtype_check = qtype in [SYM_INT4, FP8E5, IQ2_XXS]
     no_tp = not self.config.pretraining_tp > 1
-    decoding_fast_path = (no_tp and is_q4_0 and use_fuse_rope and
+    decoding_fast_path = (no_tp and qtype_check and use_fuse_rope and
                           bsz * q_len == 1)
     decoding_fast_path = decoding_fast_path and not self.q_proj.enable_xetla
@@ -756,6 +759,7 @@ def llama_attention_selective_batching_forward_4_31(
                 position_ids,
                 past_k, past_v,
                 self.q_proj.weight.qtype,
+                self.v_proj.weight.qtype,
                 kv_seq_len,
                 self.head_dim)
             kv_seq_len += 1
@@ -912,9 +916,9 @@ def llama_attention_forward_4_36(
     use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
     enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx, seq_len=q_len)
     qtype = getattr(self.q_proj, "qtype", None)
-    is_q4_0 = qtype == SYM_INT4
+    qtype_check = qtype in [SYM_INT4, FP8E5, IQ2_XXS]
     no_tp = not self.config.pretraining_tp > 1
-    decoding_fast_path = (no_tp and is_q4_0 and use_fuse_rope and
+    decoding_fast_path = (no_tp and qtype_check and use_fuse_rope and
                           enough_kv_room and bsz * q_len == 1)
     decoding_fast_path = decoding_fast_path and not self.q_proj.enable_xetla
@@ -935,6 +939,7 @@ def llama_attention_forward_4_36(
                 position_ids,
                 cache_k, cache_v,
                 self.q_proj.weight.qtype,
+                self.v_proj.weight.qtype,
                 kv_seq_len,
                 self.head_dim)
         kv_seq_len += 1

@@ -49,7 +49,7 @@ from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb, \
     apply_rotary_pos_emb_no_cache_xpu
 from bigdl.llm.transformers.models.utils import is_enough_kv_cache_room_4_31, \
     is_enough_kv_cache_room_4_36
-from bigdl.llm.transformers.low_bit_linear import SYM_INT4, FP8E5
+from bigdl.llm.transformers.low_bit_linear import SYM_INT4, FP8E5, IQ2_XXS
 from bigdl.llm.transformers.models.utils import use_flash_attention, use_esimd_sdp

 KV_CACHE_ALLOC_BLOCK_LENGTH = 256
@@ -77,7 +77,7 @@ def should_use_fuse_rope(self, hidden_states, position_ids):


 def use_decoding_fast_path(q_type, use_fuse_rope, enough_kv_room, bs):
-    return q_type in [SYM_INT4, FP8E5] and \
+    return q_type in [SYM_INT4, FP8E5, IQ2_XXS] and \
         use_fuse_rope and enough_kv_room and bs == 1

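A quick usage sketch of the shared helper above, assuming bigdl-llm is installed; the mistral module exports it and the Mixtral code imports it from there, as the next file's import hunk shows:

from bigdl.llm.transformers.low_bit_linear import SYM_INT4, FP8E5, IQ2_XXS
from bigdl.llm.transformers.models.mistral import use_decoding_fast_path

# Single-token decode with IQ2_XXS-quantized projections now takes the fast path.
assert use_decoding_fast_path(IQ2_XXS, use_fuse_rope=True, enough_kv_room=True, bs=1)
# Prefill or batched calls (bs != 1) still fall back to the regular path.
assert not use_decoding_fast_path(IQ2_XXS, use_fuse_rope=True, enough_kv_room=True, bs=4)
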
@@ -188,6 +188,7 @@ def mistral_attention_forward_quantized(
                 position_ids,
                 tmp_cache_k, tmp_cache_v,
                 self.q_proj.weight.qtype,
+                self.v_proj.weight.qtype,
                 0,
                 self.head_dim)
     else:
@@ -360,6 +361,7 @@ def mistral_attention_forward_original(
                 position_ids,
                 cache_k, cache_v,
                 self.q_proj.weight.qtype,
+                self.v_proj.weight.qtype,
                 kv_seq_len,
                 self.head_dim)
         kv_seq_len += 1
@@ -512,6 +514,7 @@ def mistral_attention_forward_4_36(
                 position_ids,
                 cache_k, cache_v,
                 self.q_proj.weight.qtype,
+                self.v_proj.weight.qtype,
                 kv_seq_len,
                 self.head_dim)
         kv_seq_len += 1

@@ -51,6 +51,7 @@ from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb,\
 from bigdl.llm.transformers.models.mistral import should_use_fuse_rope, use_decoding_fast_path
 from bigdl.llm.transformers.models.utils import use_flash_attention, use_esimd_sdp
 from bigdl.llm.transformers.models.utils import mlp_fusion_check
+from bigdl.llm.transformers.low_bit_linear import IQ2_XXS


 KV_CACHE_ALLOC_BLOCK_LENGTH = 256
@@ -155,7 +156,7 @@ def mixtral_attention_forward(
                 bsz * q_len)
     decoding_fast_path = decoding_fast_path and not self.q_proj.enable_xetla

-    if decoding_fast_path:
+    if decoding_fast_path and self.q_proj.qtype != IQ2_XXS:
         hidden_states = hidden_states.view(1, -1)
         cache_k = past_key_value.key_cache[self.layer_idx]
         cache_v = past_key_value.value_cache[self.layer_idx]
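For Mixtral, the fused decoding path is explicitly skipped when q_proj is quantized as IQ2_XXS: the dedicated IQ2 qk fusion is left commented out below because it changed outputs for an unknown reason, so IQ2-quantized Mixtral falls back to the regular q/k/v projections even when the generic fast-path check passes. A condensed, illustrative restatement of that branch (not the module's actual control flow verbatim):

from bigdl.llm.transformers.low_bit_linear import IQ2_XXS

def mixtral_qkv_path(decoding_fast_path: bool, q_proj_qtype: int) -> str:
    if decoding_fast_path and q_proj_qtype != IQ2_XXS:
        return "fused qkv kernel"  # single-token decode, non-IQ2 weights
    # The IQ2_XXS qk fusion is disabled for now (see the commented block below),
    # so IQ2 models run the ordinary q/k/v projections instead.
    return "regular q/k/v projections"
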
@@ -168,6 +169,7 @@ def mixtral_attention_forward(
                 position_ids,
                 cache_k, cache_v,
                 self.q_proj.weight.qtype,
+                self.v_proj.weight.qtype,
                 kv_seq_len,
                 self.head_dim)
         kv_seq_len += 1
@@ -176,7 +178,40 @@ def mixtral_attention_forward(
             past_key_value.seen_tokens = kv_seq_len
         past_key_value.key_cache[self.layer_idx] = key_states
         past_key_value.value_cache[self.layer_idx] = value_states
+    # diasble it for now as it will cause output change for unknown reason
+    # elif decoding_fast_path and self.q_proj.qtype == IQ2_XXS:
+    #     # this path self.v_proj use q4_0
+    #     hidden_states = hidden_states.view(1, -1)
+    #     cache_k = past_key_value.key_cache[self.layer_idx]
+    #     cache_v = past_key_value.value_cache[self.layer_idx]
+    #     kv_seq_len = cache_k.shape[-2]
+    #     import linear_q4_0
+    #     query_states, key_states = linear_q4_0.forward_qk(hidden_states,
+    #                                                       self.q_proj.weight,
+    #                                                       self.k_proj.weight,
+    #                                                       position_ids,
+    #                                                       cache_k,
+    #                                                       self.q_proj.weight.qtype,
+    #                                                       kv_seq_len,
+    #                                                       self.head_dim,
+    #                                                       10000)
+    #     kv_seq_len += 1
+    #     # update past_key_value's seem_tokens and kv caches.
+    #     if self.layer_idx == 0:
+    #         past_key_value.seen_tokens = kv_seq_len
+    #     # update value_states
+    #     value_states = self.v_proj(hidden_states)
+    #     value_states = value_states.view(bsz, q_len,
+    #                                      self.num_key_value_heads, self.head_dim).transpose(1, 2)
+    #     new_size = (cache_v.size(0),
+    #                 cache_v.size(1),
+    #                 cache_v.size(2) + value_states.size(2),
+    #                 cache_v.size(3))
+    #     new_cache_v = cache_v.as_strided(new_size, cache_v.stride(), storage_offset=0)
+    #     new_cache_v[:, :, cache_v.size(2):cache_v.size(2)+value_states.size(2), :] = value_states
+
+    #     past_key_value.key_cache[self.layer_idx] = key_states
+    #     past_key_value.value_cache[self.layer_idx] = new_cache_v
     else:
         query_states = self.q_proj(hidden_states)
         key_states = self.k_proj(hidden_states)