add quantize kv cache support for qwen2 (#10134)
parent 3f79128ed7
commit d848efe17c

3 changed files with 230 additions and 22 deletions

@@ -893,10 +893,14 @@ def _optimize_post(model, lightweight_bmm=False):
        # for Qwen1.5-7B
        modeling_module_name = model.__class__.__module__
        module = importlib.import_module(modeling_module_name)
        from bigdl.llm.transformers.models.qwen2 import qwen2_model_forward
        from bigdl.llm.transformers.models.qwen2 import qwen2_attention_forward
        # TODO: add these optimizations back;
        # RMSNorm and rotary embedding are disabled for now,
        # as they lead to an obvious performance drop for Qwen 1.5
        convert_forward(model,
                        module.Qwen2Model,
                        qwen2_model_forward)
        convert_forward(model,
                        module.Qwen2Attention,
                        qwen2_attention_forward)
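
The two convert_forward calls above rebind the forward methods of the stock Qwen2Model and Qwen2Attention modules to the BigDL implementations added below. As a rough sketch of what such a helper does (the real implementation lives in bigdl.llm.transformers.convert and may differ in details):

# Rough sketch only, not the actual BigDL helper: walk the module tree and
# rebind forward() on every instance of the target class.
import types

def convert_forward_sketch(model, target_class, new_forward):
    for _, sub_module in model.named_modules():
        if isinstance(sub_module, target_class):
            # Bind the replacement function as this instance's forward method.
            sub_module.forward = types.MethodType(new_forward, sub_module)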

python/llm/src/bigdl/llm/transformers/kv.py  (new file, 56 lines)
@@ -0,0 +1,56 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#


import torch

from .models.utils import init_fp8_kv_cache, append_fp8_kv_cache
from typing import Optional, Dict, Tuple, Any
from transformers.cache_utils import DynamicCache


class DynamicFp8Cache(DynamicCache):
    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:

        batch_size, num_heads, seq_len, head_dim = key_states.shape

        if layer_idx == 0:
            self.seen_tokens += seq_len

        # Update the cache
        if len(self.key_cache) <= layer_idx:
            k_cache, v_cache = init_fp8_kv_cache(
                batch_size, num_heads, seq_len, head_dim,
                device=key_states.device,
            )
            k_cache, v_cache = append_fp8_kv_cache(k_cache, v_cache, key_states, value_states)

            self.key_cache.append(k_cache)
            self.value_cache.append(v_cache)
        else:
            k_cache = self.key_cache[layer_idx]
            v_cache = self.value_cache[layer_idx]
            k_cache, v_cache = append_fp8_kv_cache(k_cache, v_cache, key_states, value_states)
            self.key_cache[layer_idx] = k_cache
            self.value_cache[layer_idx] = v_cache

        return self.key_cache[layer_idx], self.value_cache[layer_idx]
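
DynamicFp8Cache keeps the DynamicCache interface from transformers, so attention layers call update() exactly as they would with the stock cache; the only difference is that the stored tensors live in the fp8 layout produced by init_fp8_kv_cache/append_fp8_kv_cache from models/utils.py. A minimal illustration of the per-layer call pattern (the helper and shapes below are made up, and the fp8 kernels need a matching BigDL build, so this is not meant to run verbatim):

# Illustration only: how an attention layer drives the fp8 cache.
past_key_values = DynamicFp8Cache()
for layer_idx in range(num_layers):                  # num_layers assumed given
    # key/value for this step: (batch, num_kv_heads, seq_len, head_dim)
    k, v = compute_kv_for_layer(layer_idx)           # hypothetical helper
    # update() quantizes/appends into the fp8 cache and returns the full cached K/V
    k_cache, v_cache = past_key_values.update(k, v, layer_idx)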

@@ -46,9 +46,11 @@ import torch.nn as nn

from bigdl.llm.transformers.models.llama import repeat_kv
from bigdl.llm.transformers.models.utils import extend_kv_cache, append_kv_cache
from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb, \
    apply_rotary_pos_emb_no_cache_xpu, is_enough_kv_cache_room_4_36
from bigdl.llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache
from bigdl.llm.transformers.models.utils import is_enough_kv_cache_room_4_36
from bigdl.llm.transformers.kv import DynamicFp8Cache
from bigdl.llm.utils.common import invalidInputError
from transformers.models.qwen2.modeling_qwen2 import Qwen2Model, apply_rotary_pos_emb


KV_CACHE_ALLOC_BLOCK_LENGTH = 256

@@ -61,6 +63,36 @@ def should_use_fuse_rope(self, query_states, position_ids):
    return use_fuse_rope


def qwen2_model_forward(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
):
    use_cache = use_cache if use_cache is not None else self.config.use_cache
    if use_cache and use_quantize_kv_cache(self.layers[0].mlp.up_proj, input_ids):
        if not isinstance(past_key_values, DynamicFp8Cache):
            past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
    return Qwen2Model.forward(
        self=self,
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )


def qwen2_attention_forward(
    self,
    hidden_states: torch.Tensor,
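
qwen2_model_forward above only swaps the cache container before delegating to the stock Qwen2Model.forward. DynamicCache.from_legacy_cache (transformers >= 4.36) replays each layer of a legacy tuple cache through update(), so with DynamicFp8Cache the replay re-quantizes any existing fp16 K/V into fp8. Roughly, paraphrased from memory of transformers 4.36 (check the installed version for the exact code):

# Paraphrase of DynamicCache.from_legacy_cache, written as a free function for
# illustration; because DynamicFp8Cache overrides update(), every replayed layer
# lands in the fp8 cache.
def from_legacy_cache_sketch(cache_cls, past_key_values=None):
    cache = cache_cls()                      # e.g. DynamicFp8Cache
    if past_key_values is not None:
        for layer_idx in range(len(past_key_values)):
            key_states, value_states = past_key_values[layer_idx]
            cache.update(key_states, value_states, layer_idx)
    return cache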

@@ -71,6 +103,128 @@ def qwen2_attention_forward(
    use_cache: bool = False,
    **kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    if use_quantize_kv_cache(self.q_proj, hidden_states):
        forward_function = qwen2_attention_forward_quantized
    else:
        forward_function = qwen2_attention_forward_origin
    return forward_function(
        self=self,
        hidden_states=hidden_states,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_value=past_key_value,
        output_attentions=output_attentions,
        use_cache=use_cache,
        **kwargs,
    )


def qwen2_attention_forward_quantized(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[DynamicFp8Cache] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
    **kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    if "padding_mask" in kwargs:
        warnings.warn(
            "Passing `padding_mask` is deprecated and will be removed in v4.37. "
            "Please make sure use `attention_mask` instead.`"
        )
    bsz, q_len, _ = hidden_states.size()

    query_states = self.q_proj(hidden_states)
    key_states = self.k_proj(hidden_states)
    value_states = self.v_proj(hidden_states)

    query_states = query_states.view(bsz, q_len,
                                     self.num_heads, self.head_dim).transpose(1, 2)
    key_states = key_states.view(bsz, q_len,
                                 self.num_key_value_heads, self.head_dim).transpose(1, 2)
    value_states = value_states.view(bsz, q_len,
                                     self.num_key_value_heads, self.head_dim).transpose(1, 2)

    kv_seq_len = key_states.shape[-2]
    if past_key_value is not None:
        invalidInputError(self.layer_idx is not None,
                          "The cache structure has changed since version v4.36. "
                          f"If you are using {self.__class__.__name__} "
                          "for auto-regressive decoding with k/v caching, "
                          "please make sure to initialize the attention class "
                          "with a layer index.")
        kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
    query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
                                                    cos, sin, position_ids)

    if past_key_value is not None:
        cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
        key_states, value_states = past_key_value.update(key_states, value_states,
                                                         self.layer_idx, cache_kwargs)

    if q_len != 1:
        key, value = restore_fp8_kv_cache(key_states, value_states, query_states.dtype)
        attn_weights = torch.matmul(query_states, key.transpose(2, 3))
    else:
        import linear_q4_0
        attn_weights = linear_q4_0.query_key_fp8_matmul(query_states, key_states)

    attn_weights = attn_weights / math.sqrt(self.head_dim)

    invalidInputError(attn_weights.size() == (bsz, self.num_heads, q_len, kv_seq_len),
                      ("Attention weights should be of size "
                       f"{(bsz, self.num_heads, q_len, kv_seq_len)},"
                       f" but is {attn_weights.size()}"))

    if attention_mask is not None:
        invalidInputError(attention_mask.size() == (bsz, 1, q_len, kv_seq_len),
                          (f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}"
                           f" but is {attention_mask.size()}"))

        attn_weights = attn_weights + attention_mask

    # upcast attention to fp32
    attn_weights = nn.functional.softmax(attn_weights, dim=-1,
                                         dtype=torch.float32).to(query_states.dtype)
    attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout,
                                         training=self.training)

    if q_len != 1:
        attn_output = torch.matmul(attn_weights, value)
    else:
        import linear_q4_0
        attn_output = linear_q4_0.attn_value_fp8_matmul(attn_weights,
                                                        value_states.transpose(-1, -2))

    invalidInputError(attn_output.size() == (bsz, self.num_heads, q_len, self.head_dim),
                      "`attn_output` should be of size "
                      f"{(bsz, self.num_heads, q_len, self.head_dim)},"
                      f" but is {attn_output.size()}")

    attn_output = attn_output.transpose(1, 2).contiguous()
    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

    attn_output = self.o_proj(attn_output)

    if not output_attentions:
        attn_weights = None

    return attn_output, attn_weights, past_key_value


def qwen2_attention_forward_origin(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Tuple[torch.Tensor]] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
    **kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:

    use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)

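
In the quantized path above, prefill (q_len != 1) dequantizes the whole cache with restore_fp8_kv_cache and falls back to plain matmuls, while single-token decode goes through the fused fp8 kernels in linear_q4_0. A hypothetical sanity check one could drop into the decode branch while debugging (assumes an XPU build that ships linear_q4_0; the tolerances are guesses, since some fp8 rounding error is expected):

# Hypothetical debugging check, not part of the commit: the fused decode kernel
# should roughly match dequantize-then-matmul on the same fp8 cache.
import torch
import linear_q4_0                       # only present in BigDL XPU builds
from bigdl.llm.transformers.models.utils import restore_fp8_kv_cache

key_fp16, value_fp16 = restore_fp8_kv_cache(key_states, value_states, query_states.dtype)
ref = torch.matmul(query_states, key_fp16.transpose(2, 3))
fused = linear_q4_0.query_key_fp8_matmul(query_states, key_states)
assert torch.allclose(ref, fused, atol=1e-2, rtol=1e-2)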

@@ -106,13 +260,14 @@ def qwen2_attention_forward(
            )
        kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
    query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
                                                    cos, sin, position_ids)

    if past_key_value is not None:
        # update the number of seen tokens
        if self.layer_idx == 0:
            past_key_value.seen_tokens += key_states.shape[-2]

        if len(past_key_value.key_cache) <= self.layer_idx:
            past_key_value.key_cache.append(key_states)
            past_key_value.value_cache.append(value_states)

@@ -150,20 +305,15 @@ def qwen2_attention_forward(

    attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

    if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
        invalidInputError(
            False,
            f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, "
            f"but is {attn_weights.size()}"
        )
    invalidInputError(attn_weights.size() == (bsz, self.num_heads, q_len, kv_seq_len),
                      ("Attention weights should be of size "
                       f"{(bsz, self.num_heads, q_len, kv_seq_len)},"
                       f" but is {attn_weights.size()}"))

    if attention_mask is not None:
        if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
            invalidInputError(
                False,
                f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, "
                f"but is {attention_mask.size()}"
            )
        invalidInputError(attention_mask.size() == (bsz, 1, q_len, kv_seq_len),
                          (f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}"
                           f" but is {attention_mask.size()}"))

        attn_weights = attn_weights + attention_mask


@@ -175,12 +325,10 @@ def qwen2_attention_forward(
                                         training=self.training)
    attn_output = torch.matmul(attn_weights, value_states)

    if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
        invalidInputError(
            False,
            f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
            f" {attn_output.size()}"
        )
    invalidInputError(attn_output.size() == (bsz, self.num_heads, q_len, self.head_dim),
                      "`attn_output` should be of size "
                      f"{(bsz, self.num_heads, q_len, self.head_dim)},"
                      f" but is {attn_output.size()}")

    attn_output = attn_output.transpose(1, 2).contiguous()
    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
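
With these hooks in place, nothing changes for callers: loading a Qwen1.5 checkpoint through bigdl-llm's transformers-style API and generating as usual picks up the fp8 KV cache whenever use_quantize_kv_cache decides to enable it (its exact conditions live in models/utils.py and are not part of this diff). A sketch of such a run; the model id and generation arguments are only examples:

# Example only; the model id and arguments are illustrative, and whether the fp8
# KV cache is actually used is decided by use_quantize_kv_cache() at runtime.
from bigdl.llm.transformers import AutoModelForCausalLM
from transformers import AutoTokenizer

model_path = "Qwen/Qwen1.5-7B-Chat"   # any Qwen2-architecture checkpoint
model = AutoModelForCausalLM.from_pretrained(model_path, load_in_4bit=True,
                                             trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

inputs = tokenizer("What is AI?", return_tensors="pt")
output = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(output[0], skip_special_tokens=True))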