From d5b81af7bdba42452f982292f704c266df4f03d0 Mon Sep 17 00:00:00 2001 From: SONG Ge <38711238+sgwhat@users.noreply.github.com> Date: Fri, 15 Dec 2023 14:30:23 +0800 Subject: [PATCH] Support mixtral attention optimization on transformers-v4.36.0 (#9674) * add example code to support mistral/mixtral attention on transformers v4.36.0 * update * style fix * add update for seen-tokens * support mixtral * rm mistral change * small fix * add more comments and remove use_cache part --------- Co-authored-by: plusbang --- .../llm/src/bigdl/llm/transformers/convert.py | 6 +- .../bigdl/llm/transformers/models/mixtral.py | 144 ++++++++++++++++++ .../bigdl/llm/transformers/models/utils.py | 6 +- 3 files changed, 153 insertions(+), 3 deletions(-) diff --git a/python/llm/src/bigdl/llm/transformers/convert.py b/python/llm/src/bigdl/llm/transformers/convert.py index 7201cc32..fbbd280b 100644 --- a/python/llm/src/bigdl/llm/transformers/convert.py +++ b/python/llm/src/bigdl/llm/transformers/convert.py @@ -620,7 +620,11 @@ def _optimize_post(model, lightweight_bmm=False): "to run Mixtral models.") modeling_module_name = model.__class__.__module__ module = importlib.import_module(modeling_module_name) - from bigdl.llm.transformers.models.mixtral import mixtral_moeblock_forward + from bigdl.llm.transformers.models.mixtral import mixtral_moeblock_forward, \ + mixtral_attention_forward + convert_forward(model, + module.MixtralAttention, + mixtral_attention_forward) convert_forward(model, module.MixtralRMSNorm, llama_rms_norm_forward) diff --git a/python/llm/src/bigdl/llm/transformers/models/mixtral.py b/python/llm/src/bigdl/llm/transformers/models/mixtral.py index 2ef29261..fda47df5 100644 --- a/python/llm/src/bigdl/llm/transformers/models/mixtral.py +++ b/python/llm/src/bigdl/llm/transformers/models/mixtral.py @@ -44,6 +44,26 @@ import torch from torch import nn import torch.nn.functional as F from bigdl.llm.utils.common import invalidInputError +from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache +from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb,\ + apply_rotary_pos_emb_no_cache_xpu + + +KV_CACHE_ALLOC_BLOCK_LENGTH = 256 + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
+    The hidden states go from (batch, num_key_value_heads, seqlen, head_dim)
+    to (batch, num_attention_heads, seqlen, head_dim)
+    """
+    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+    if n_rep == 1:
+        return hidden_states
+    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads,
+                                                           n_rep, slen, head_dim)
+    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
 
 
 def mixtral_moeblock_forward(self,
@@ -106,3 +126,127 @@ def mixtral_moeblock_forward(self,
 
     final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
     return final_hidden_states, router_logits
+
+
+def mixtral_attention_forward(
+    self,
+    hidden_states: torch.Tensor,
+    attention_mask: Optional[torch.Tensor]=None,
+    position_ids: Optional[torch.LongTensor]=None,
+    past_key_value: Optional[Tuple[torch.Tensor]]=None,
+    output_attentions: bool=False,
+    use_cache: bool=False,
+    padding_mask: Optional[torch.Tensor]=None,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+    bsz, q_len, _ = hidden_states.size()
+    device = hidden_states.device
+
+    query_states = self.q_proj(hidden_states)
+    key_states = self.k_proj(hidden_states)
+    value_states = self.v_proj(hidden_states)
+
+    query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+    key_states = key_states.view(bsz, q_len,
+                                 self.num_key_value_heads, self.head_dim).transpose(1, 2)
+    value_states = value_states.view(bsz, q_len,
+                                     self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+    kv_seq_len = key_states.shape[-2]
+    if past_key_value is not None:
+        if self.layer_idx is None:
+            invalidInputError(False, "The cache structure has changed since version v4.36. "
+                              f"If you are using {self.__class__.__name__} for "
+                              "auto-regressive decoding with k/v caching, please make sure "
+                              "to initialize the attention class with a layer index.")
+        kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+
+    if query_states.device.type == "xpu" and not (self.training and query_states.requires_grad):
+        query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
+                                                                     key_states,
+                                                                     position_ids,
+                                                                     "mixtral")
+    else:
+        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
+                                                        cos, sin, position_ids, "mixtral")
+
+    if past_key_value is not None:
+        # update the number of seen tokens
+        if self.layer_idx == 0:
+            past_key_value.seen_tokens += key_states.shape[-2]
+
+        # reuse k, v, self_attention
+        # update `past_key_value` with `key_states` and `value_states` for layer `layer_idx`
+        if len(past_key_value.key_cache) <= self.layer_idx:
+            past_key_value.key_cache.append(key_states)
+            past_key_value.value_cache.append(value_states)
+        else:
+            cache_k = past_key_value.key_cache[self.layer_idx]
+            cache_v = past_key_value.value_cache[self.layer_idx]
+
+            if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
+                # allocate new
+                new_cache_k, new_cache_v = extend_kv_cache(bsz,
+                                                           self.num_key_value_heads,  # Support GQA
+                                                           self.head_dim,
+                                                           cache_k.size(2),
+                                                           kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH,
+                                                           dtype=cache_k.dtype,
+                                                           device=device)
+
+                new_cache_k[:] = cache_k
+                new_cache_v[:] = cache_v
+                cache_k = new_cache_k
+                cache_v = new_cache_v
+
+            key_states, value_states = append_kv_cache(cache_k, cache_v, key_states, value_states)
+
+            # update past_key_value
+            past_key_value.key_cache[self.layer_idx] = key_states
+            past_key_value.value_cache[self.layer_idx] = value_states
+
+    # repeat k/v heads if n_kv_heads < n_heads
+    key_states = repeat_kv(key_states, self.num_key_value_groups)
+    value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+    attn_weights = torch.matmul(
+        query_states,
+        key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+
+    if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
+        invalidInputError(
+            False,
+            f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)},"
+            f" but is {attn_weights.size()}"
+        )
+
+    if attention_mask is not None:
+        if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
+            invalidInputError(
+                False,
+                f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)},"
+                f" but is {attention_mask.size()}"
+            )
+
+        attn_weights = attn_weights + attention_mask
+
+    # upcast attention to fp32
+    attn_weights = nn.functional.\
+        softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+    attn_output = torch.matmul(attn_weights, value_states)
+
+    if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
+        invalidInputError(
+            False, f"`attn_output` should be of size "
+            f"{(bsz, self.num_heads, q_len, self.head_dim)}, but is {attn_output.size()}"
+        )
+
+    attn_output = attn_output.transpose(1, 2).contiguous()
+    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+
+    attn_output = self.o_proj(attn_output)
+
+    if not output_attentions:
+        attn_weights = None
+
+    return attn_output, attn_weights, past_key_value
diff --git a/python/llm/src/bigdl/llm/transformers/models/utils.py b/python/llm/src/bigdl/llm/transformers/models/utils.py
index b5888319..539972e8 100644
--- a/python/llm/src/bigdl/llm/transformers/models/utils.py
+++ b/python/llm/src/bigdl/llm/transformers/models/utils.py
@@ -71,7 +71,8 @@ def rotate_every_two(x):
 
 
 def apply_rotary_pos_emb(q, k, cos, sin, position_ids, model_family):
-    if model_family in ["llama", "baichuan", "internlm", "aquila", "gpt_neox", "mistral"]:
+    if model_family in ["llama", "baichuan", "internlm", "aquila", "gpt_neox", "mistral",
+                        "mixtral"]:
         # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
         cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
         sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
@@ -98,7 +99,8 @@ def apply_rotary_pos_emb_no_cache_xpu(q, k, position_ids, model_family):
     import linear_q4_0
     q_embed = torch.empty(q.shape, dtype=q.dtype, device=q.device)
    k_embed = torch.empty(k.shape, dtype=k.dtype, device=k.device)
-    if model_family in ["llama", "baichuan", "internlm", "aquila", "gpt_neox", "mistral"]:
+    if model_family in ["llama", "baichuan", "internlm", "aquila", "gpt_neox", "mistral",
+                        "mixtral"]:
         linear_q4_0.apply_rotary_embedding_half_qk(q, k, position_ids, q_embed, k_embed)
         return q_embed, k_embed
     else:
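
The repeat_kv helper added to models/mixtral.py mirrors the grouped-query-attention expansion used by transformers. A quick standalone check of the expand-and-reshape trick (illustration only, not part of the patch; the shapes below are arbitrary):

import torch

# (batch=2, num_key_value_heads=2, seqlen=4, head_dim=8), expanded to 6 attention heads
x = torch.randn(2, 2, 4, 8)
n_rep = 3
expanded = x[:, :, None, :, :].expand(2, 2, n_rep, 4, 8).reshape(2, 2 * n_rep, 4, 8)
assert torch.equal(expanded, torch.repeat_interleave(x, repeats=n_rep, dim=1))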
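
The extend_kv_cache/append_kv_cache calls in mixtral_attention_forward grow the KV cache in blocks of KV_CACHE_ALLOC_BLOCK_LENGTH rather than reallocating on every generated token. A minimal sketch of that strategy in plain PyTorch (illustration only; it does not use the bigdl helpers and the sizes are made up):

import torch

ALLOC_BLOCK = 256                                             # mirrors KV_CACHE_ALLOC_BLOCK_LENGTH
bsz, n_kv_heads, head_dim = 1, 8, 128

buf_k = torch.empty(bsz, n_kv_heads, ALLOC_BLOCK, head_dim)   # pre-allocated backing buffer
cur_len = 0

for step in range(5):                                         # decode 5 tokens
    new_k = torch.randn(bsz, n_kv_heads, 1, head_dim)         # this step's key states
    if cur_len + 1 > buf_k.size(2):                           # buffer exhausted: extend by one block
        bigger = torch.empty(bsz, n_kv_heads, buf_k.size(2) + ALLOC_BLOCK, head_dim)
        bigger[:, :, :cur_len, :] = buf_k[:, :, :cur_len, :]
        buf_k = bigger
    buf_k[:, :, cur_len:cur_len + 1, :] = new_k               # append in place, no full copy
    cur_len += 1

cache_k = buf_k[:, :, :cur_len, :]                            # narrowed view handed to attention
print(cache_k.shape)                                          # torch.Size([1, 8, 5, 128])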
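
Nothing else is needed to pick up the new attention path: loading a Mixtral checkpoint through bigdl-llm runs _optimize_post(), which installs mixtral_attention_forward on every MixtralAttention module. A minimal sketch, assuming transformers>=4.36.0 and a locally available Mixtral checkpoint (model id and prompt are placeholders):

import torch
from transformers import AutoTokenizer
from bigdl.llm.transformers import AutoModelForCausalLM    # low-bit loading triggers the convert.py hooks

model_path = "mistralai/Mixtral-8x7B-Instruct-v0.1"        # placeholder checkpoint id
model = AutoModelForCausalLM.from_pretrained(model_path, load_in_4bit=True)
tokenizer = AutoTokenizer.from_pretrained(model_path)

# On Intel GPUs (device type "xpu") the apply_rotary_pos_emb_no_cache_xpu branch is taken;
# elsewhere the generic apply_rotary_pos_emb path with model_family="mixtral" is used.
with torch.inference_mode():
    input_ids = tokenizer("What is AI?", return_tensors="pt").input_ids
    output = model.generate(input_ids, max_new_tokens=32)
    print(tokenizer.decode(output[0], skip_special_tokens=True))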