[LLM] Optimize kv_cache for mistral model family (#9189)

* add kv_cache optimization for mistral model

* kv_cache optimize for mistral

* update stylr

* update
This commit is contained in:
SONG Ge 2023-10-18 15:13:37 +08:00 committed by GitHub
parent 3555ebc148
commit 0765f94770
2 changed files with 34 additions and 3 deletions

View file

@ -41,10 +41,14 @@ from typing import Optional, Tuple
import torch
from torch import nn
from bigdl.llm.utils.common import invalidInputError
from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb,\
apply_rotary_pos_emb_no_cache_xpu
KV_CACHE_ALLOC_BLOCK_LENGTH = 256
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep).
@ -70,6 +74,7 @@ def mistral_attention_forward(
padding_mask: Optional[torch.Tensor]=None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()
device = hidden_states.device
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
@ -84,6 +89,7 @@ def mistral_attention_forward(
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
if query_states.device.type == "xpu" and not (self.training and query_states.requires_grad):
query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
key_states,
@ -96,8 +102,33 @@ def mistral_attention_forward(
if past_key_value is not None:
# reuse k, v, self_attention
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
cache_k = past_key_value[0]
cache_v = past_key_value[1]
if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
# allocate new
new_cache_k, new_cache_v = extend_kv_cache(bsz,
self.num_key_value_heads, # Support GQA
self.head_dim,
cache_k.size(2),
kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH,
dtype=cache_k.dtype,
device=device)
key_states, value_states = append_kv_cache(cache_k, cache_v, key_states, value_states)
elif use_cache:
max_cache_length = kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH
new_key_states, new_value_states = init_kv_cache(bsz,
self.num_key_value_heads,
self.head_dim,
kv_seq_len,
max_cache_length,
dtype=key_states.dtype,
device=device)
new_key_states[:] = key_states
new_value_states[:] = value_states
key_states = new_key_states
value_states = new_value_states
past_key_value = (key_states, value_states) if use_cache else None

View file

@ -71,7 +71,7 @@ def rotate_every_two(x):
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, model_family):
if model_family in ["llama", "baichuan", "internlm", "aquila", "gpt_neox"]:
if model_family in ["llama", "baichuan", "internlm", "aquila", "gpt_neox", "mistral"]:
# The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]