diff --git a/python/llm/src/bigdl/llm/transformers/models/mistral.py b/python/llm/src/bigdl/llm/transformers/models/mistral.py
index 36445906..847f43b9 100644
--- a/python/llm/src/bigdl/llm/transformers/models/mistral.py
+++ b/python/llm/src/bigdl/llm/transformers/models/mistral.py
@@ -41,10 +41,14 @@ from typing import Optional, Tuple
 import torch
 from torch import nn
 from bigdl.llm.utils.common import invalidInputError
+from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
 from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb,\
     apply_rotary_pos_emb_no_cache_xpu
 
 
+KV_CACHE_ALLOC_BLOCK_LENGTH = 256
+
+
 def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
     """
     This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep).
@@ -70,6 +74,7 @@ def mistral_attention_forward(
     padding_mask: Optional[torch.Tensor]=None,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
     bsz, q_len, _ = hidden_states.size()
+    device = hidden_states.device
 
     query_states = self.q_proj(hidden_states)
     key_states = self.k_proj(hidden_states)
@@ -84,6 +89,7 @@ def mistral_attention_forward(
     kv_seq_len = key_states.shape[-2]
     if past_key_value is not None:
         kv_seq_len += past_key_value[0].shape[-2]
+
     if query_states.device.type == "xpu" and not (self.training and query_states.requires_grad):
         query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
                                                                      key_states,
@@ -96,8 +102,37 @@ def mistral_attention_forward(
 
     if past_key_value is not None:
         # reuse k, v, self_attention
-        key_states = torch.cat([past_key_value[0], key_states], dim=2)
-        value_states = torch.cat([past_key_value[1], value_states], dim=2)
+        cache_k = past_key_value[0]
+        cache_v = past_key_value[1]
+        if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
+            # allocate new
+            new_cache_k, new_cache_v = extend_kv_cache(bsz,
+                                                       self.num_key_value_heads,  # Support GQA
+                                                       self.head_dim,
+                                                       cache_k.size(2),
+                                                       kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH,
+                                                       dtype=cache_k.dtype,
+                                                       device=device)
+            new_cache_k[:] = cache_k
+            new_cache_v[:] = cache_v
+            cache_k = new_cache_k
+            cache_v = new_cache_v
+
+        key_states, value_states = append_kv_cache(cache_k, cache_v, key_states, value_states)
+
+    elif use_cache:
+        max_cache_length = kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH
+        new_key_states, new_value_states = init_kv_cache(bsz,
+                                                         self.num_key_value_heads,
+                                                         self.head_dim,
+                                                         kv_seq_len,
+                                                         max_cache_length,
+                                                         dtype=key_states.dtype,
+                                                         device=device)
+        new_key_states[:] = key_states
+        new_value_states[:] = value_states
+        key_states = new_key_states
+        value_states = new_value_states
 
     past_key_value = (key_states, value_states) if use_cache else None
 
diff --git a/python/llm/src/bigdl/llm/transformers/models/utils.py b/python/llm/src/bigdl/llm/transformers/models/utils.py
index a0367c4e..b5888319 100644
--- a/python/llm/src/bigdl/llm/transformers/models/utils.py
+++ b/python/llm/src/bigdl/llm/transformers/models/utils.py
@@ -71,7 +71,7 @@ def rotate_every_two(x):
 
 
 def apply_rotary_pos_emb(q, k, cos, sin, position_ids, model_family):
-    if model_family in ["llama", "baichuan", "internlm", "aquila", "gpt_neox"]:
+    if model_family in ["llama", "baichuan", "internlm", "aquila", "gpt_neox", "mistral"]:
         # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
         cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
         sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
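
Note: the cache helpers imported above (init_kv_cache, extend_kv_cache, append_kv_cache) live in bigdl/llm/transformers/models/utils.py, and their bodies are not part of this diff. A minimal plain-torch sketch of the semantics their call sites imply -- over-allocate the KV buffer up to max_length, expose only the occupied current_length positions through an as_strided view, and grow that view in place instead of calling torch.cat -- not the actual implementation:

import torch

def init_kv_cache(batch_size, num_heads, head_dim, current_length, max_length, dtype, device):
    # Over-allocate storage for max_length tokens, but return views that only
    # expose the first current_length positions along the sequence dimension.
    k_storage = torch.empty(batch_size, num_heads, max_length, head_dim,
                            dtype=dtype, device=device)
    v_storage = torch.empty(batch_size, num_heads, max_length, head_dim,
                            dtype=dtype, device=device)
    shape = (batch_size, num_heads, current_length, head_dim)
    return (k_storage.as_strided(shape, k_storage.stride(), storage_offset=0),
            v_storage.as_strided(shape, v_storage.stride(), storage_offset=0))

def append_kv_cache(cache_k, cache_v, key_states, value_states):
    # Widen the views by the number of new tokens and copy the new K/V into
    # the slack at the end of the pre-allocated storage; no reallocation.
    new_len = cache_k.size(2) + key_states.size(2)
    new_size = (cache_k.size(0), cache_k.size(1), new_len, cache_k.size(3))
    new_k = cache_k.as_strided(new_size, cache_k.stride(), storage_offset=0)
    new_k[:, :, cache_k.size(2):new_len, :] = key_states
    new_v = cache_v.as_strided(new_size, cache_v.stride(), storage_offset=0)
    new_v[:, :, cache_v.size(2):new_len, :] = value_states
    return new_k, new_v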
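
Under that layout, the guard cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3) reads as a fullness test: stride(1) is max_length * head_dim for the backing storage, while size(2) * size(3) is current_length * head_dim for the view, so the inequality holds exactly when the view has consumed all pre-allocated slack and the cache must be extended before the next append. A quick check continuing from the sketch above, with hypothetical sizes:

# Cache holding 4 tokens, pre-allocated for 8.
k, v = init_kv_cache(1, 2, head_dim=16, current_length=4, max_length=8,
                     dtype=torch.float16, device="cpu")
print(k.stride()[1], k.size(2) * k.size(3))      # 128 vs 64 -> slack remains
new_tokens = torch.zeros(1, 2, 4, 16, dtype=torch.float16)
k, v = append_kv_cache(k, v, new_tokens, new_tokens)
print(k.stride()[1] <= k.size(2) * k.size(3))    # True -> extend before appending again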