From 30d009bca7a3b9f2ba25b5113d4a13c12269782c Mon Sep 17 00:00:00 2001 From: Cengguang Zhang Date: Tue, 5 Mar 2024 16:23:50 +0800 Subject: [PATCH] LLM: support quantized kv cache for Mistral in transformers >=4.36.0 (#10326) * support quantize kv for mistral in transformers 4.36 * update mistral support. * fix style. --- .../llm/src/bigdl/llm/transformers/convert.py | 5 + .../bigdl/llm/transformers/models/llama.py | 20 +- .../bigdl/llm/transformers/models/mistral.py | 253 +++++++++++++++++- 3 files changed, 266 insertions(+), 12 deletions(-) diff --git a/python/llm/src/bigdl/llm/transformers/convert.py b/python/llm/src/bigdl/llm/transformers/convert.py index 7e99e007..d99f46ec 100644 --- a/python/llm/src/bigdl/llm/transformers/convert.py +++ b/python/llm/src/bigdl/llm/transformers/convert.py @@ -1092,10 +1092,15 @@ def _optimize_post(model, lightweight_bmm=False): modeling_module_name = model.__class__.__module__ module = importlib.import_module(modeling_module_name) from bigdl.llm.transformers.models.mistral import mistral_attention_forward_4_36 + from bigdl.llm.transformers.models.mistral import mistral_model_forward_4_36 convert_forward(model, module.MistralAttention, mistral_attention_forward_4_36 ) + convert_forward(model, + module.MistralModel, + mistral_model_forward_4_36 + ) convert_forward(model, module.MistralRMSNorm, llama_rms_norm_forward) diff --git a/python/llm/src/bigdl/llm/transformers/models/llama.py b/python/llm/src/bigdl/llm/transformers/models/llama.py index 29d6fe35..e8b5e8e4 100644 --- a/python/llm/src/bigdl/llm/transformers/models/llama.py +++ b/python/llm/src/bigdl/llm/transformers/models/llama.py @@ -53,6 +53,10 @@ from transformers.models.llama.modeling_llama import LlamaModel from bigdl.llm.transformers.low_bit_linear import SYM_INT4, FP8E5, IQ2_XXS from bigdl.llm.ggml.quantize import ggml_tensor_qtype from bigdl.llm.utils.common import invalidInputError +try: + from transformers.cache_utils import Cache +except ImportError: + Cache = Tuple[torch.Tensor] def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: @@ -934,11 +938,11 @@ def llama_attention_forward_4_36( hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, **kwargs -) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]: if use_quantize_kv_cache(self.q_proj, hidden_states): forward_function = llama_attention_forward_4_36_quantized else: @@ -960,11 +964,11 @@ def llama_attention_forward_4_36_quantized( hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, **kwargs -) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]: if "padding_mask" in kwargs: warnings.warn( "Passing `padding_mask` is deprecated and will be removed in v4.37. " @@ -999,8 +1003,10 @@ def llama_attention_forward_4_36_quantized( position_ids, tmp_cache_k, tmp_cache_v, self.q_proj.weight.qtype, + self.v_proj.weight.qtype, 0, - self.head_dim) + self.head_dim, + self.rotary_emb.base,) else: query_states = self.q_proj(hidden_states) key_states = self.k_proj(hidden_states) @@ -1140,11 +1146,11 @@ def llama_attention_forward_4_36_original( hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, **kwargs -) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]: if "padding_mask" in kwargs: warnings.warn( "Passing `padding_mask` is deprecated and will be removed in v4.37. " diff --git a/python/llm/src/bigdl/llm/transformers/models/mistral.py b/python/llm/src/bigdl/llm/transformers/models/mistral.py index b79d053a..ce8847a2 100644 --- a/python/llm/src/bigdl/llm/transformers/models/mistral.py +++ b/python/llm/src/bigdl/llm/transformers/models/mistral.py @@ -36,11 +36,13 @@ # limitations under the License. """ PyTorch Mistral model.""" import math -from typing import Optional, Tuple +from typing import List, Optional, Tuple, Union import torch from torch import nn import torch.nn.functional as F +from transformers.modeling_outputs import BaseModelOutputWithPast +from transformers.models.mistral.modeling_mistral import MistralModel from bigdl.llm.utils.common import invalidInputError from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache from bigdl.llm.transformers.models.utils import init_fp8_kv_cache, append_fp8_kv_cache, \ @@ -51,7 +53,10 @@ from bigdl.llm.transformers.models.utils import is_enough_kv_cache_room_4_31, \ is_enough_kv_cache_room_4_36 from bigdl.llm.transformers.low_bit_linear import SYM_INT4, FP8E5, IQ2_XXS from bigdl.llm.transformers.models.utils import use_flash_attention, use_esimd_sdp - +try: + from transformers.cache_utils import Cache +except ImportError: + Cache = Tuple[torch.Tensor] KV_CACHE_ALLOC_BLOCK_LENGTH = 256 @@ -121,6 +126,37 @@ def compute_attn_outputs_weights(query_states, key_states, value_states, bsz, q_ return attn_output, attn_weights +def mistral_model_forward_4_36( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, +) -> Union[Tuple, BaseModelOutputWithPast]: + from bigdl.llm.transformers.kv import DynamicFp8Cache + use_cache = use_cache if use_cache is not None else self.config.use_cache + if use_cache and use_quantize_kv_cache(self.layers[0].mlp.up_proj, input_ids): + if not isinstance(past_key_values, DynamicFp8Cache): + past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values) + return MistralModel.forward( + self=self, + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + def mistral_attention_forward( self, hidden_states: torch.Tensor, @@ -480,11 +516,218 @@ def mistral_attention_forward_4_36( hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor]=None, position_ids: Optional[torch.LongTensor]=None, - past_key_value: Optional[Tuple[torch.Tensor]]=None, + past_key_value: Optional[Cache]=None, output_attentions: bool=False, use_cache: bool=False, - padding_mask: Optional[torch.Tensor]=None, -) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + **kwargs +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]: + if use_quantize_kv_cache(self.q_proj, hidden_states): + forward_function = mistral_attention_forward_4_36_quantized + else: + forward_function = mistral_attention_forward_4_36_original + return forward_function( + self=self, + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + kwargs=kwargs + ) + + +def mistral_attention_forward_4_36_quantized( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor]=None, + position_ids: Optional[torch.LongTensor]=None, + past_key_value: Optional[Cache]=None, + output_attentions: bool=False, + use_cache: bool=False, + **kwargs +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]: + bsz, q_len, hidden_size = hidden_states.size() + device = hidden_states.device + # for flash attention + original_dtype = hidden_states.dtype + + use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids) + enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx, seq_len=q_len) + decoding_fast_path = use_decoding_fast_path(self.q_proj.qtype, + use_fuse_rope, + enough_kv_room, + bsz * q_len) + + if decoding_fast_path: + hidden_states = hidden_states.view(1, -1) + tmp_cache_k, tmp_cache_v = init_kv_cache( + bsz, + self.num_key_value_heads, + self.head_dim, + 0, + 1, + dtype=hidden_states.dtype, + device=device + ) + import linear_q4_0 + query_states, key_states, value_states = linear_q4_0.forward_qkv(hidden_states, + self.q_proj.weight, + self.k_proj.weight, + self.v_proj.weight, + position_ids, + tmp_cache_k, tmp_cache_v, + self.q_proj.weight.qtype, + self.v_proj.weight.qtype, + 0, + self.head_dim) + else: + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, + self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, + self.num_key_value_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + if self.layer_idx is None: + invalidInputError( + False, + f"The cache structure has changed since version v4.36. " + "If you are using {self.__class__.__name__} " + "for auto-regressive decoding with k/v caching, " + "please make sure to initialize the attention class " + "with a layer index." + ) + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + + if use_fuse_rope: + query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states, + key_states, + position_ids, + "mistral") + else: + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, + cos, sin, position_ids, "mistral") + + if not self.training and not hidden_states.requires_grad: + fsdp_flag = use_flash_attention(query_states, key_states) + else: + fsdp_flag = False + if fsdp_flag: + attention_dtype = torch.float16 # use fp16 for flash attention + else: + attention_dtype = original_dtype + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups).to(device, + dtype=attention_dtype) + value_states = repeat_kv(value_states, self.num_key_value_groups).to(device, + dtype=attention_dtype) + kv_seq_len = key_states.shape[-2] + if len(past_key_value.key_cache) <= self.layer_idx: + attn_weights = torch.matmul(query_states.to(key_states.dtype), + key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): + invalidInputError( + False, + f"Attention weights should be of size " + f"{(bsz, self.num_heads, q_len, kv_seq_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + invalidInputError( + False, + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}," + f" but is {attention_mask.size()}" + ) + attn_weights = attn_weights + attention_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, + dtype=torch.float32).to(query_states.dtype) + attn_output = torch.matmul(attn_weights, value_states) + if use_cache: + cache_kwargs = None + key_states, value_states = past_key_value.update(key_states, value_states, + self.layer_idx, cache_kwargs) + else: + cache_kwargs = None # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, + self.layer_idx, cache_kwargs) + kv_seq_len = key_states.shape[-2] + if query_states.size(2) != 1 or query_states.device.type != 'xpu': + key_states, value_states = restore_fp8_kv_cache(key_states, value_states, + query_states.dtype) + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) + else: + import linear_q4_0 + attn_weights = linear_q4_0.query_key_fp8_matmul(query_states, key_states) + + attn_weights = attn_weights / math.sqrt(self.head_dim) + + if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): + invalidInputError( + False, + f"Attention weights should be of size " + f"{(bsz, self.num_heads, q_len, kv_seq_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + invalidInputError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}," + f" but is {attention_mask.size()}" + ) + attn_weights = attn_weights + attention_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, + dtype=torch.float32).to(query_states.dtype) + + if query_states.size(2) != 1 or query_states.device.type != 'xpu': + attn_output = torch.matmul(attn_weights, value_states) + else: + import linear_q4_0 + attn_output = linear_q4_0.attn_value_fp8_matmul(attn_weights, + value_states.transpose(-1, -2)) + + attn_output_size = (bsz, self.num_heads, q_len, self.head_dim) + if attn_output.size() != attn_output_size: + invalidInputError(False, + f"`attn_output` should be of size {attn_output_size}," + f" but is {attn_output.size()}") + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output.to(original_dtype), attn_weights, past_key_value + + +def mistral_attention_forward_4_36_original( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor]=None, + position_ids: Optional[torch.LongTensor]=None, + past_key_value: Optional[Cache]=None, + output_attentions: bool=False, + use_cache: bool=False, + **kwargs +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]: bsz, q_len, hidden_size = hidden_states.size() device = hidden_states.device # for flash attention