add quantize kv cache support for qwen2 (#10134)
This commit is contained in:
parent
3f79128ed7
commit
d848efe17c
3 changed files with 230 additions and 22 deletions
|
|
@ -893,10 +893,14 @@ def _optimize_post(model, lightweight_bmm=False):
|
||||||
# for Qwen1.5-7B
|
# for Qwen1.5-7B
|
||||||
modeling_module_name = model.__class__.__module__
|
modeling_module_name = model.__class__.__module__
|
||||||
module = importlib.import_module(modeling_module_name)
|
module = importlib.import_module(modeling_module_name)
|
||||||
|
from bigdl.llm.transformers.models.qwen2 import qwen2_model_forward
|
||||||
from bigdl.llm.transformers.models.qwen2 import qwen2_attention_forward
|
from bigdl.llm.transformers.models.qwen2 import qwen2_attention_forward
|
||||||
# TODO: add these optimization back
|
# TODO: add these optimization back
|
||||||
# RMSNorm and rotray embedding are disabled for now
|
# RMSNorm and rotray embedding are disabled for now
|
||||||
# as they lead to obvious performance drop for Qwen 1.5
|
# as they lead to obvious performance drop for Qwen 1.5
|
||||||
|
convert_forward(model,
|
||||||
|
module.Qwen2Model,
|
||||||
|
qwen2_model_forward)
|
||||||
convert_forward(model,
|
convert_forward(model,
|
||||||
module.Qwen2Attention,
|
module.Qwen2Attention,
|
||||||
qwen2_attention_forward
|
qwen2_attention_forward
|
||||||
|
|
|
||||||
56
python/llm/src/bigdl/llm/transformers/kv.py
Normal file
56
python/llm/src/bigdl/llm/transformers/kv.py
Normal file
|
|
@ -0,0 +1,56 @@
|
||||||
|
#
|
||||||
|
# Copyright 2016 The BigDL Authors.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
#
|
||||||
|
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from .models.utils import init_fp8_kv_cache, append_fp8_kv_cache
|
||||||
|
from typing import Optional, Dict, Tuple, Any
|
||||||
|
from transformers.cache_utils import DynamicCache
|
||||||
|
|
||||||
|
|
||||||
|
class DynamicFp8Cache(DynamicCache):
|
||||||
|
def update(
|
||||||
|
self,
|
||||||
|
key_states: torch.Tensor,
|
||||||
|
value_states: torch.Tensor,
|
||||||
|
layer_idx: int,
|
||||||
|
cache_kwargs: Optional[Dict[str, Any]]=None,
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
|
||||||
|
batch_size, num_heads, seq_len, head_dim = key_states.shape
|
||||||
|
|
||||||
|
if layer_idx == 0:
|
||||||
|
self.seen_tokens += seq_len
|
||||||
|
|
||||||
|
# Update the cache
|
||||||
|
if len(self.key_cache) <= layer_idx:
|
||||||
|
k_cache, v_cache = init_fp8_kv_cache(
|
||||||
|
batch_size, num_heads, seq_len, head_dim,
|
||||||
|
device=key_states.device,
|
||||||
|
)
|
||||||
|
k_cache, v_cache = append_fp8_kv_cache(k_cache, v_cache, key_states, value_states)
|
||||||
|
|
||||||
|
self.key_cache.append(k_cache)
|
||||||
|
self.value_cache.append(v_cache)
|
||||||
|
else:
|
||||||
|
k_cache = self.key_cache[layer_idx]
|
||||||
|
v_cache = self.value_cache[layer_idx]
|
||||||
|
k_cache, v_cache = append_fp8_kv_cache(k_cache, v_cache, key_states, value_states)
|
||||||
|
self.key_cache[layer_idx] = k_cache
|
||||||
|
self.value_cache[layer_idx] = v_cache
|
||||||
|
|
||||||
|
return self.key_cache[layer_idx], self.value_cache[layer_idx]
|
||||||
|
|
@ -46,9 +46,11 @@ import torch.nn as nn
|
||||||
|
|
||||||
from bigdl.llm.transformers.models.llama import repeat_kv
|
from bigdl.llm.transformers.models.llama import repeat_kv
|
||||||
from bigdl.llm.transformers.models.utils import extend_kv_cache, append_kv_cache
|
from bigdl.llm.transformers.models.utils import extend_kv_cache, append_kv_cache
|
||||||
from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb, \
|
from bigdl.llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache
|
||||||
apply_rotary_pos_emb_no_cache_xpu, is_enough_kv_cache_room_4_36
|
from bigdl.llm.transformers.models.utils import is_enough_kv_cache_room_4_36
|
||||||
|
from bigdl.llm.transformers.kv import DynamicFp8Cache
|
||||||
from bigdl.llm.utils.common import invalidInputError
|
from bigdl.llm.utils.common import invalidInputError
|
||||||
|
from transformers.models.qwen2.modeling_qwen2 import Qwen2Model, apply_rotary_pos_emb
|
||||||
|
|
||||||
|
|
||||||
KV_CACHE_ALLOC_BLOCK_LENGTH = 256
|
KV_CACHE_ALLOC_BLOCK_LENGTH = 256
|
||||||
|
|
@ -61,6 +63,36 @@ def should_use_fuse_rope(self, query_states, position_ids):
|
||||||
return use_fuse_rope
|
return use_fuse_rope
|
||||||
|
|
||||||
|
|
||||||
|
def qwen2_model_forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.LongTensor = None,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||||
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
|
use_cache: Optional[bool] = None,
|
||||||
|
output_attentions: Optional[bool] = None,
|
||||||
|
output_hidden_states: Optional[bool] = None,
|
||||||
|
return_dict: Optional[bool] = None,
|
||||||
|
):
|
||||||
|
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||||
|
if use_cache and use_quantize_kv_cache(self.layers[0].mlp.up_proj, input_ids):
|
||||||
|
if not isinstance(past_key_values, DynamicFp8Cache):
|
||||||
|
past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
|
||||||
|
return Qwen2Model.forward(
|
||||||
|
self=self,
|
||||||
|
input_ids=input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
position_ids=position_ids,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
use_cache=use_cache,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
return_dict=return_dict,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def qwen2_attention_forward(
|
def qwen2_attention_forward(
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
|
|
@ -71,6 +103,128 @@ def qwen2_attention_forward(
|
||||||
use_cache: bool = False,
|
use_cache: bool = False,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||||
|
if use_quantize_kv_cache(self.q_proj, hidden_states):
|
||||||
|
forward_function = qwen2_attention_forward_quantized
|
||||||
|
else:
|
||||||
|
forward_function = qwen2_attention_forward_origin
|
||||||
|
return forward_function(
|
||||||
|
self=self,
|
||||||
|
hidden_states=hidden_states,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
position_ids=position_ids,
|
||||||
|
past_key_value=past_key_value,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
use_cache=use_cache,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def qwen2_attention_forward_quantized(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
past_key_value: Optional[DynamicFp8Cache] = None,
|
||||||
|
output_attentions: bool = False,
|
||||||
|
use_cache: bool = False,
|
||||||
|
**kwargs,
|
||||||
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||||
|
if "padding_mask" in kwargs:
|
||||||
|
warnings.warn(
|
||||||
|
"Passing `padding_mask` is deprecated and will be removed in v4.37. "
|
||||||
|
"Please make sure use `attention_mask` instead.`"
|
||||||
|
)
|
||||||
|
bsz, q_len, _ = hidden_states.size()
|
||||||
|
|
||||||
|
query_states = self.q_proj(hidden_states)
|
||||||
|
key_states = self.k_proj(hidden_states)
|
||||||
|
value_states = self.v_proj(hidden_states)
|
||||||
|
|
||||||
|
query_states = query_states.view(bsz, q_len,
|
||||||
|
self.num_heads, self.head_dim).transpose(1, 2)
|
||||||
|
key_states = key_states.view(bsz, q_len,
|
||||||
|
self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||||
|
value_states = value_states.view(bsz, q_len,
|
||||||
|
self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||||
|
|
||||||
|
kv_seq_len = key_states.shape[-2]
|
||||||
|
if past_key_value is not None:
|
||||||
|
invalidInputError(self.layer_idx is not None,
|
||||||
|
"The cache structure has changed since version v4.36. "
|
||||||
|
f"If you are using {self.__class__.__name__} "
|
||||||
|
"for auto-regressive decoding with k/v caching, "
|
||||||
|
"please make sure to initialize the attention class "
|
||||||
|
"with a layer index.")
|
||||||
|
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
|
||||||
|
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||||
|
query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
|
||||||
|
cos, sin, position_ids)
|
||||||
|
|
||||||
|
if past_key_value is not None:
|
||||||
|
cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
|
||||||
|
key_states, value_states = past_key_value.update(key_states, value_states,
|
||||||
|
self.layer_idx, cache_kwargs)
|
||||||
|
|
||||||
|
if q_len != 1:
|
||||||
|
key, value = restore_fp8_kv_cache(key_states, value_states, query_states.dtype)
|
||||||
|
attn_weights = torch.matmul(query_states, key.transpose(2, 3))
|
||||||
|
else:
|
||||||
|
import linear_q4_0
|
||||||
|
attn_weights = linear_q4_0.query_key_fp8_matmul(query_states, key_states)
|
||||||
|
|
||||||
|
attn_weights = attn_weights / math.sqrt(self.head_dim)
|
||||||
|
|
||||||
|
invalidInputError(attn_weights.size() == (bsz, self.num_heads, q_len, kv_seq_len),
|
||||||
|
("Attention weights should be of size "
|
||||||
|
f"{(bsz, self.num_heads, q_len, kv_seq_len)},"
|
||||||
|
"but is {attn_weights.size()}"))
|
||||||
|
|
||||||
|
if attention_mask is not None:
|
||||||
|
invalidInputError(attention_mask.size() == (bsz, 1, q_len, kv_seq_len),
|
||||||
|
(f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}"
|
||||||
|
f" but is {attention_mask.size()}"))
|
||||||
|
|
||||||
|
attn_weights = attn_weights + attention_mask
|
||||||
|
|
||||||
|
# upcast attention to fp32
|
||||||
|
attn_weights = nn.functional.softmax(attn_weights, dim=-1,
|
||||||
|
dtype=torch.float32).to(query_states.dtype)
|
||||||
|
attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout,
|
||||||
|
training=self.training)
|
||||||
|
|
||||||
|
if q_len != 1:
|
||||||
|
attn_output = torch.matmul(attn_weights, value)
|
||||||
|
else:
|
||||||
|
import linear_q4_0
|
||||||
|
attn_output = linear_q4_0.attn_value_fp8_matmul(attn_weights,
|
||||||
|
value_states.transpose(-1, -2))
|
||||||
|
|
||||||
|
invalidInputError(attn_output.size() == (bsz, self.num_heads, q_len, self.head_dim),
|
||||||
|
"`attn_output` should be of size "
|
||||||
|
f"{(bsz, self.num_heads, q_len, self.head_dim)},"
|
||||||
|
f" but is {attn_output.size()}")
|
||||||
|
|
||||||
|
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||||
|
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
||||||
|
|
||||||
|
attn_output = self.o_proj(attn_output)
|
||||||
|
|
||||||
|
if not output_attentions:
|
||||||
|
attn_weights = None
|
||||||
|
|
||||||
|
return attn_output, attn_weights, past_key_value
|
||||||
|
|
||||||
|
|
||||||
|
def qwen2_attention_forward_origin(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||||
|
output_attentions: bool = False,
|
||||||
|
use_cache: bool = False,
|
||||||
|
**kwargs,
|
||||||
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||||
|
|
||||||
use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
|
use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
|
||||||
|
|
||||||
|
|
@ -106,7 +260,8 @@ def qwen2_attention_forward(
|
||||||
)
|
)
|
||||||
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
|
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
|
||||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
|
||||||
|
cos, sin, position_ids)
|
||||||
|
|
||||||
if past_key_value is not None:
|
if past_key_value is not None:
|
||||||
# update the number of seen tokens
|
# update the number of seen tokens
|
||||||
|
|
@ -150,20 +305,15 @@ def qwen2_attention_forward(
|
||||||
|
|
||||||
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
|
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
|
||||||
|
|
||||||
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
|
invalidInputError(attn_weights.size() == (bsz, self.num_heads, q_len, kv_seq_len),
|
||||||
invalidInputError(
|
("Attention weights should be of size "
|
||||||
False,
|
f"{(bsz, self.num_heads, q_len, kv_seq_len)},"
|
||||||
f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, "
|
"but is {attn_weights.size()}"))
|
||||||
f"but is {attn_weights.size()}"
|
|
||||||
)
|
|
||||||
|
|
||||||
if attention_mask is not None:
|
if attention_mask is not None:
|
||||||
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
|
invalidInputError(attention_mask.size() == (bsz, 1, q_len, kv_seq_len),
|
||||||
invalidInputError(
|
(f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}"
|
||||||
False,
|
f" but is {attention_mask.size()}"))
|
||||||
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, "
|
|
||||||
f"but is {attention_mask.size()}"
|
|
||||||
)
|
|
||||||
|
|
||||||
attn_weights = attn_weights + attention_mask
|
attn_weights = attn_weights + attention_mask
|
||||||
|
|
||||||
|
|
@ -175,12 +325,10 @@ def qwen2_attention_forward(
|
||||||
training=self.training)
|
training=self.training)
|
||||||
attn_output = torch.matmul(attn_weights, value_states)
|
attn_output = torch.matmul(attn_weights, value_states)
|
||||||
|
|
||||||
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
|
invalidInputError(attn_output.size() == (bsz, self.num_heads, q_len, self.head_dim),
|
||||||
invalidInputError(
|
"`attn_output` should be of size "
|
||||||
False,
|
f"{(bsz, self.num_heads, q_len, self.head_dim)},"
|
||||||
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
|
f" but is {attn_output.size()}")
|
||||||
f" {attn_output.size()}"
|
|
||||||
)
|
|
||||||
|
|
||||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||||
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue