From d848efe17cbf2c9345209fac007aa2bd110fdb03 Mon Sep 17 00:00:00 2001 From: Yishuo Wang Date: Thu, 8 Feb 2024 16:17:21 +0800 Subject: [PATCH] add quantize kv cache support for qwen2 (#10134) --- .../llm/src/bigdl/llm/transformers/convert.py | 4 + python/llm/src/bigdl/llm/transformers/kv.py | 56 +++++ .../bigdl/llm/transformers/models/qwen2.py | 192 ++++++++++++++++-- 3 files changed, 230 insertions(+), 22 deletions(-) create mode 100644 python/llm/src/bigdl/llm/transformers/kv.py diff --git a/python/llm/src/bigdl/llm/transformers/convert.py b/python/llm/src/bigdl/llm/transformers/convert.py index d2dfa651..f299cb78 100644 --- a/python/llm/src/bigdl/llm/transformers/convert.py +++ b/python/llm/src/bigdl/llm/transformers/convert.py @@ -893,10 +893,14 @@ def _optimize_post(model, lightweight_bmm=False): # for Qwen1.5-7B modeling_module_name = model.__class__.__module__ module = importlib.import_module(modeling_module_name) + from bigdl.llm.transformers.models.qwen2 import qwen2_model_forward from bigdl.llm.transformers.models.qwen2 import qwen2_attention_forward # TODO: add these optimization back # RMSNorm and rotray embedding are disabled for now # as they lead to obvious performance drop for Qwen 1.5 + convert_forward(model, + module.Qwen2Model, + qwen2_model_forward) convert_forward(model, module.Qwen2Attention, qwen2_attention_forward diff --git a/python/llm/src/bigdl/llm/transformers/kv.py b/python/llm/src/bigdl/llm/transformers/kv.py new file mode 100644 index 00000000..0d3ad897 --- /dev/null +++ b/python/llm/src/bigdl/llm/transformers/kv.py @@ -0,0 +1,56 @@ +# +# Copyright 2016 The BigDL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + + +import torch + +from .models.utils import init_fp8_kv_cache, append_fp8_kv_cache +from typing import Optional, Dict, Tuple, Any +from transformers.cache_utils import DynamicCache + + +class DynamicFp8Cache(DynamicCache): + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[Dict[str, Any]]=None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + + batch_size, num_heads, seq_len, head_dim = key_states.shape + + if layer_idx == 0: + self.seen_tokens += seq_len + + # Update the cache + if len(self.key_cache) <= layer_idx: + k_cache, v_cache = init_fp8_kv_cache( + batch_size, num_heads, seq_len, head_dim, + device=key_states.device, + ) + k_cache, v_cache = append_fp8_kv_cache(k_cache, v_cache, key_states, value_states) + + self.key_cache.append(k_cache) + self.value_cache.append(v_cache) + else: + k_cache = self.key_cache[layer_idx] + v_cache = self.value_cache[layer_idx] + k_cache, v_cache = append_fp8_kv_cache(k_cache, v_cache, key_states, value_states) + self.key_cache[layer_idx] = k_cache + self.value_cache[layer_idx] = v_cache + + return self.key_cache[layer_idx], self.value_cache[layer_idx] diff --git a/python/llm/src/bigdl/llm/transformers/models/qwen2.py b/python/llm/src/bigdl/llm/transformers/models/qwen2.py index 9d9a872e..de9ccb61 100644 --- a/python/llm/src/bigdl/llm/transformers/models/qwen2.py +++ b/python/llm/src/bigdl/llm/transformers/models/qwen2.py @@ -46,9 +46,11 @@ import torch.nn as nn from bigdl.llm.transformers.models.llama import repeat_kv from bigdl.llm.transformers.models.utils import extend_kv_cache, append_kv_cache -from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb, \ - apply_rotary_pos_emb_no_cache_xpu, is_enough_kv_cache_room_4_36 +from bigdl.llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache +from bigdl.llm.transformers.models.utils import is_enough_kv_cache_room_4_36 +from bigdl.llm.transformers.kv import DynamicFp8Cache from bigdl.llm.utils.common import invalidInputError +from transformers.models.qwen2.modeling_qwen2 import Qwen2Model, apply_rotary_pos_emb KV_CACHE_ALLOC_BLOCK_LENGTH = 256 @@ -61,6 +63,36 @@ def should_use_fuse_rope(self, query_states, position_ids): return use_fuse_rope +def qwen2_model_forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, +): + use_cache = use_cache if use_cache is not None else self.config.use_cache + if use_cache and use_quantize_kv_cache(self.layers[0].mlp.up_proj, input_ids): + if not isinstance(past_key_values, DynamicFp8Cache): + past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values) + return Qwen2Model.forward( + self=self, + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + def qwen2_attention_forward( self, hidden_states: torch.Tensor, @@ -71,6 +103,128 @@ def qwen2_attention_forward( use_cache: bool = False, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if use_quantize_kv_cache(self.q_proj, hidden_states): + forward_function = qwen2_attention_forward_quantized + else: + forward_function = qwen2_attention_forward_origin + return forward_function( + self=self, + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + **kwargs, + ) + + +def qwen2_attention_forward_quantized( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[DynamicFp8Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. " + "Please make sure use `attention_mask` instead.`" + ) + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, + self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, + self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, + self.num_key_value_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + invalidInputError(self.layer_idx is not None, + "The cache structure has changed since version v4.36. " + f"If you are using {self.__class__.__name__} " + "for auto-regressive decoding with k/v caching, " + "please make sure to initialize the attention class " + "with a layer index.") + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, + cos, sin, position_ids) + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, + self.layer_idx, cache_kwargs) + + if q_len != 1: + key, value = restore_fp8_kv_cache(key_states, value_states, query_states.dtype) + attn_weights = torch.matmul(query_states, key.transpose(2, 3)) + else: + import linear_q4_0 + attn_weights = linear_q4_0.query_key_fp8_matmul(query_states, key_states) + + attn_weights = attn_weights / math.sqrt(self.head_dim) + + invalidInputError(attn_weights.size() == (bsz, self.num_heads, q_len, kv_seq_len), + ("Attention weights should be of size " + f"{(bsz, self.num_heads, q_len, kv_seq_len)}," + "but is {attn_weights.size()}")) + + if attention_mask is not None: + invalidInputError(attention_mask.size() == (bsz, 1, q_len, kv_seq_len), + (f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}" + f" but is {attention_mask.size()}")) + + attn_weights = attn_weights + attention_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, + dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, + training=self.training) + + if q_len != 1: + attn_output = torch.matmul(attn_weights, value) + else: + import linear_q4_0 + attn_output = linear_q4_0.attn_value_fp8_matmul(attn_weights, + value_states.transpose(-1, -2)) + + invalidInputError(attn_output.size() == (bsz, self.num_heads, q_len, self.head_dim), + "`attn_output` should be of size " + f"{(bsz, self.num_heads, q_len, self.head_dim)}," + f" but is {attn_output.size()}") + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +def qwen2_attention_forward_origin( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids) @@ -106,13 +260,14 @@ def qwen2_attention_forward( ) kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, + cos, sin, position_ids) if past_key_value is not None: # update the number of seen tokens if self.layer_idx == 0: past_key_value.seen_tokens += key_states.shape[-2] - + if len(past_key_value.key_cache) <= self.layer_idx: past_key_value.key_cache.append(key_states) past_key_value.value_cache.append(value_states) @@ -150,20 +305,15 @@ def qwen2_attention_forward( attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) - if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): - invalidInputError( - False, - f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, " - f"but is {attn_weights.size()}" - ) + invalidInputError(attn_weights.size() == (bsz, self.num_heads, q_len, kv_seq_len), + ("Attention weights should be of size " + f"{(bsz, self.num_heads, q_len, kv_seq_len)}," + "but is {attn_weights.size()}")) if attention_mask is not None: - if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): - invalidInputError( - False, - f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, " - f"but is {attention_mask.size()}" - ) + invalidInputError(attention_mask.size() == (bsz, 1, q_len, kv_seq_len), + (f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}" + f" but is {attention_mask.size()}")) attn_weights = attn_weights + attention_mask @@ -175,12 +325,10 @@ def qwen2_attention_forward( training=self.training) attn_output = torch.matmul(attn_weights, value_states) - if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): - invalidInputError( - False, - f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) + invalidInputError(attn_output.size() == (bsz, self.num_heads, q_len, self.head_dim), + "`attn_output` should be of size " + f"{(bsz, self.num_heads, q_len, self.head_dim)}," + f" but is {attn_output.size()}") attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)