LLM: support quantized kv cache for Mistral in transformers >=4.36.0 (#10326)

* support quantized kv cache for Mistral in transformers 4.36

* update mistral support.

* fix style.
Cengguang Zhang 2024-03-05 16:23:50 +08:00 committed by GitHub
parent 566e9bbb36
commit 30d009bca7
3 changed files with 266 additions and 12 deletions
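
The sketch below (not part of the commit) shows how the new path would typically be exercised: load a low-bit Mistral model through bigdl-llm's transformers-style API and turn on the quantized (FP8) KV cache. The BIGDL_QUANTIZE_KV_CACHE switch and the optimize_model/load_in_4bit arguments are assumptions about the installed release and should be checked against its documentation; the "xpu" device assumes an Intel GPU setup.

# Hedged usage sketch; adjust model path, device and flags for your environment.
import os
os.environ["BIGDL_QUANTIZE_KV_CACHE"] = "1"  # assumed switch for the quantized KV cache

import torch
from transformers import AutoTokenizer
from bigdl.llm.transformers import AutoModelForCausalLM

model_path = "mistralai/Mistral-7B-Instruct-v0.1"
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(model_path,
                                             load_in_4bit=True,    # low-bit weights
                                             optimize_model=True)  # applies the patched forwards
model = model.to("xpu")

inputs = tokenizer("What is AI?", return_tensors="pt").to("xpu")
with torch.inference_mode():
    output = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(output[0], skip_special_tokens=True))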

@@ -1092,10 +1092,15 @@ def _optimize_post(model, lightweight_bmm=False):
        modeling_module_name = model.__class__.__module__
        module = importlib.import_module(modeling_module_name)
        from bigdl.llm.transformers.models.mistral import mistral_attention_forward_4_36
        from bigdl.llm.transformers.models.mistral import mistral_model_forward_4_36
        convert_forward(model,
                        module.MistralAttention,
                        mistral_attention_forward_4_36
                        )
        convert_forward(model,
                        module.MistralModel,
                        mistral_model_forward_4_36
                        )
        convert_forward(model,
                        module.MistralRMSNorm,
                        llama_rms_norm_forward)
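
For reference, the convert_forward helper used above is the usual forward-monkey-patching pattern: walk the module tree and rebind forward on every submodule of the target class. A minimal sketch of that pattern (illustrative; the in-repo helper may differ in details):

import torch

def convert_forward(m: torch.nn.Module, target_m: type, new_forward) -> None:
    # Recursively replace `forward` on every submodule that is an instance of target_m.
    for _, sub_m in m.named_children():
        if isinstance(sub_m, target_m):
            bound_method = new_forward.__get__(sub_m, sub_m.__class__)
            setattr(sub_m, "forward", bound_method)
        convert_forward(sub_m, target_m, new_forward)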

@@ -53,6 +53,10 @@ from transformers.models.llama.modeling_llama import LlamaModel
from bigdl.llm.transformers.low_bit_linear import SYM_INT4, FP8E5, IQ2_XXS
from bigdl.llm.ggml.quantize import ggml_tensor_qtype
from bigdl.llm.utils.common import invalidInputError
try:
    from transformers.cache_utils import Cache
except ImportError:
    Cache = Tuple[torch.Tensor]


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:

@@ -934,11 +938,11 @@ def llama_attention_forward_4_36(
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Cache] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
    **kwargs
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
    if use_quantize_kv_cache(self.q_proj, hidden_states):
        forward_function = llama_attention_forward_4_36_quantized
    else:

@@ -960,11 +964,11 @@ def llama_attention_forward_4_36_quantized(
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Cache] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
    **kwargs
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
    if "padding_mask" in kwargs:
        warnings.warn(
            "Passing `padding_mask` is deprecated and will be removed in v4.37. "

@@ -999,8 +1003,10 @@ def llama_attention_forward_4_36_quantized(
                                        position_ids,
                                        tmp_cache_k, tmp_cache_v,
                                        self.q_proj.weight.qtype,
                                        self.v_proj.weight.qtype,
                                        0,
                                        self.head_dim,
                                        self.rotary_emb.base,)
    else:
        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)

@@ -1140,11 +1146,11 @@ def llama_attention_forward_4_36_original(
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Cache] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
    **kwargs
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
    if "padding_mask" in kwargs:
        warnings.warn(
            "Passing `padding_mask` is deprecated and will be removed in v4.37. "

@@ -36,11 +36,13 @@
# limitations under the License.
""" PyTorch Mistral model."""
import math
from typing import List, Optional, Tuple, Union
import torch
from torch import nn
import torch.nn.functional as F
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.models.mistral.modeling_mistral import MistralModel
from bigdl.llm.utils.common import invalidInputError
from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
from bigdl.llm.transformers.models.utils import init_fp8_kv_cache, append_fp8_kv_cache, \

@@ -51,7 +53,10 @@ from bigdl.llm.transformers.models.utils import is_enough_kv_cache_room_4_31, \
    is_enough_kv_cache_room_4_36
from bigdl.llm.transformers.low_bit_linear import SYM_INT4, FP8E5, IQ2_XXS
from bigdl.llm.transformers.models.utils import use_flash_attention, use_esimd_sdp
try:
    from transformers.cache_utils import Cache
except ImportError:
    Cache = Tuple[torch.Tensor]

KV_CACHE_ALLOC_BLOCK_LENGTH = 256

@@ -121,6 +126,37 @@ def compute_attn_outputs_weights(query_states, key_states, value_states, bsz, q_
    return attn_output, attn_weights


def mistral_model_forward_4_36(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
    from bigdl.llm.transformers.kv import DynamicFp8Cache
    use_cache = use_cache if use_cache is not None else self.config.use_cache
    if use_cache and use_quantize_kv_cache(self.layers[0].mlp.up_proj, input_ids):
        if not isinstance(past_key_values, DynamicFp8Cache):
            past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
    return MistralModel.forward(
        self=self,
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )
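
Aside (not part of the diff): DynamicFp8Cache is bigdl-llm's drop-in replacement for transformers' DynamicCache; the wrapper above only swaps the cache object, and the attention forwards below read and write it through the standard Cache.update() protocol. A rough, hypothetical illustration of that protocol follows; the class name and dtype here are placeholders, not the repo's implementation:

import torch
from typing import Any, Dict, Optional, Tuple
from transformers.cache_utils import DynamicCache

class ToyLowPrecisionCache(DynamicCache):
    # Illustrative only: stores k/v in float16 as a stand-in for FP8 quantization.
    def update(self, key_states: torch.Tensor, value_states: torch.Tensor,
               layer_idx: int, cache_kwargs: Optional[Dict[str, Any]] = None
               ) -> Tuple[torch.Tensor, torch.Tensor]:
        key_states = key_states.to(torch.float16)
        value_states = value_states.to(torch.float16)
        return super().update(key_states, value_states, layer_idx, cache_kwargs)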

def mistral_attention_forward(
    self,
    hidden_states: torch.Tensor,

@@ -480,11 +516,218 @@ def mistral_attention_forward_4_36(
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor]=None,
    position_ids: Optional[torch.LongTensor]=None,
    past_key_value: Optional[Cache]=None,
    output_attentions: bool=False,
    use_cache: bool=False,
    **kwargs
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
    if use_quantize_kv_cache(self.q_proj, hidden_states):
        forward_function = mistral_attention_forward_4_36_quantized
    else:
        forward_function = mistral_attention_forward_4_36_original
    return forward_function(
        self=self,
        hidden_states=hidden_states,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_value=past_key_value,
        output_attentions=output_attentions,
        use_cache=use_cache,
        kwargs=kwargs
    )


def mistral_attention_forward_4_36_quantized(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor]=None,
    position_ids: Optional[torch.LongTensor]=None,
    past_key_value: Optional[Cache]=None,
    output_attentions: bool=False,
    use_cache: bool=False,
    **kwargs
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
    bsz, q_len, hidden_size = hidden_states.size()
    device = hidden_states.device
    # for flash attention
    original_dtype = hidden_states.dtype
    use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
    enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx, seq_len=q_len)
    decoding_fast_path = use_decoding_fast_path(self.q_proj.qtype,
                                                use_fuse_rope,
                                                enough_kv_room,
                                                bsz * q_len)
    if decoding_fast_path:
        hidden_states = hidden_states.view(1, -1)
        tmp_cache_k, tmp_cache_v = init_kv_cache(
            bsz,
            self.num_key_value_heads,
            self.head_dim,
            0,
            1,
            dtype=hidden_states.dtype,
            device=device
        )
        import linear_q4_0
        query_states, key_states, value_states = linear_q4_0.forward_qkv(hidden_states,
                                                                         self.q_proj.weight,
                                                                         self.k_proj.weight,
                                                                         self.v_proj.weight,
                                                                         position_ids,
                                                                         tmp_cache_k, tmp_cache_v,
                                                                         self.q_proj.weight.qtype,
                                                                         self.v_proj.weight.qtype,
                                                                         0,
                                                                         self.head_dim)
    else:
        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len,
                                     self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len,
                                         self.num_key_value_heads, self.head_dim).transpose(1, 2)
    kv_seq_len = key_states.shape[-2]
    if past_key_value is not None:
        if self.layer_idx is None:
            invalidInputError(
                False,
                "The cache structure has changed since version v4.36. "
                f"If you are using {self.__class__.__name__} "
                "for auto-regressive decoding with k/v caching, "
                "please make sure to initialize the attention class "
                "with a layer index."
            )
        kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)

    if use_fuse_rope:
        query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
                                                                     key_states,
                                                                     position_ids,
                                                                     "mistral")
    else:
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
                                                        cos, sin, position_ids, "mistral")
    if not self.training and not hidden_states.requires_grad:
        fsdp_flag = use_flash_attention(query_states, key_states)
    else:
        fsdp_flag = False
    if fsdp_flag:
        attention_dtype = torch.float16  # use fp16 for flash attention
    else:
        attention_dtype = original_dtype

    # repeat k/v heads if n_kv_heads < n_heads
    key_states = repeat_kv(key_states, self.num_key_value_groups).to(device,
                                                                     dtype=attention_dtype)
    value_states = repeat_kv(value_states, self.num_key_value_groups).to(device,
                                                                         dtype=attention_dtype)
    kv_seq_len = key_states.shape[-2]

    if len(past_key_value.key_cache) <= self.layer_idx:
        attn_weights = torch.matmul(query_states.to(key_states.dtype),
                                    key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
            invalidInputError(
                False,
                f"Attention weights should be of size "
                f"{(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
                f" {attn_weights.size()}"
            )

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                invalidInputError(
                    False,
                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)},"
                    f" but is {attention_mask.size()}"
                )
            attn_weights = attn_weights + attention_mask

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights, dim=-1,
                                             dtype=torch.float32).to(query_states.dtype)
        attn_output = torch.matmul(attn_weights, value_states)
        if use_cache:
            cache_kwargs = None
            key_states, value_states = past_key_value.update(key_states, value_states,
                                                             self.layer_idx, cache_kwargs)
    else:
        cache_kwargs = None  # Specific to RoPE models
        key_states, value_states = past_key_value.update(key_states, value_states,
                                                         self.layer_idx, cache_kwargs)
        kv_seq_len = key_states.shape[-2]
        if query_states.size(2) != 1 or query_states.device.type != 'xpu':
            key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
                                                            query_states.dtype)
            attn_weights = torch.matmul(query_states, key_states.transpose(2, 3))
        else:
            import linear_q4_0
            attn_weights = linear_q4_0.query_key_fp8_matmul(query_states, key_states)
        attn_weights = attn_weights / math.sqrt(self.head_dim)

        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
            invalidInputError(
                False,
                f"Attention weights should be of size "
                f"{(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
                f" {attn_weights.size()}"
            )

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                invalidInputError(
                    False,
                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)},"
                    f" but is {attention_mask.size()}"
                )
            attn_weights = attn_weights + attention_mask

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights, dim=-1,
                                             dtype=torch.float32).to(query_states.dtype)
        if query_states.size(2) != 1 or query_states.device.type != 'xpu':
            attn_output = torch.matmul(attn_weights, value_states)
        else:
            import linear_q4_0
            attn_output = linear_q4_0.attn_value_fp8_matmul(attn_weights,
                                                            value_states.transpose(-1, -2))

    attn_output_size = (bsz, self.num_heads, q_len, self.head_dim)
    if attn_output.size() != attn_output_size:
        invalidInputError(False,
                          f"`attn_output` should be of size {attn_output_size},"
                          f" but is {attn_output.size()}")

    attn_output = attn_output.transpose(1, 2).contiguous()
    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
    attn_output = self.o_proj(attn_output)

    if not output_attentions:
        attn_weights = None

    return attn_output.to(original_dtype), attn_weights, past_key_value


def mistral_attention_forward_4_36_original(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor]=None,
    position_ids: Optional[torch.LongTensor]=None,
    past_key_value: Optional[Cache]=None,
    output_attentions: bool=False,
    use_cache: bool=False,
    **kwargs
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
    bsz, q_len, hidden_size = hidden_states.size()
    device = hidden_states.device
    # for flash attention