Support mixtral attention optimization on transformers-v4.36.0 (#9674)

* add example code to support mistral/mixtral attention on transformers v4.36.0

* update

* style fix

* add update for seen-tokens

* support mixtral

* rm mistral change

* small fix

* add more comments and remove use_cache part

---------

Co-authored-by: plusbang <binbin1.deng@intel.com>
This commit is contained in:
SONG Ge 2023-12-15 14:30:23 +08:00 committed by GitHub
parent adbef56001
commit d5b81af7bd
3 changed files with 153 additions and 3 deletions

View file

@ -620,7 +620,11 @@ def _optimize_post(model, lightweight_bmm=False):
"to run Mixtral models.")
modeling_module_name = model.__class__.__module__
module = importlib.import_module(modeling_module_name)
from bigdl.llm.transformers.models.mixtral import mixtral_moeblock_forward
from bigdl.llm.transformers.models.mixtral import mixtral_moeblock_forward, \
mixtral_attention_forward
convert_forward(model,
module.MixtralAttention,
mixtral_attention_forward)
convert_forward(model,
module.MixtralRMSNorm,
llama_rms_norm_forward)

View file

@ -44,6 +44,26 @@ import torch
from torch import nn
import torch.nn.functional as F
from bigdl.llm.utils.common import invalidInputError
from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb,\
apply_rotary_pos_emb_no_cache_xpu
KV_CACHE_ALLOC_BLOCK_LENGTH = 256
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep).
The hidden states go from (batch, num_key_value_heads, seqlen, head_dim)
to (batch, num_attention_heads, seqlen, head_dim)
"""
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
if n_rep == 1:
return hidden_states
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads,
n_rep, slen, head_dim)
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
def mixtral_moeblock_forward(self,
@ -106,3 +126,127 @@ def mixtral_moeblock_forward(self,
final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
return final_hidden_states, router_logits
def mixtral_attention_forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor]=None,
position_ids: Optional[torch.LongTensor]=None,
past_key_value: Optional[Tuple[torch.Tensor]]=None,
output_attentions: bool=False,
use_cache: bool=False,
padding_mask: Optional[torch.Tensor]=None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()
device = hidden_states.device
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len,
self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len,
self.num_key_value_heads, self.head_dim).transpose(1, 2)
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
if self.layer_idx is None:
invalidInputError(False, "The cache structure has changed since version v4.36. "
f"If you are using {self.__class__.__name__} for "
"auto-regressive decodingwith k/v caching, please make sure "
"to initialize the attention class with a layer index.")
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
if query_states.device.type == "xpu" and not (self.training and query_states.requires_grad):
query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
key_states,
position_ids,
"mixtral")
else:
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
cos, sin, position_ids, "mixtral")
if past_key_value is not None:
# update the number of seen tokens
if self.layer_idx == 0:
past_key_value.seen_tokens += key_states.shape[-2]
# reuse k, v, self_attention
# update `past_key_value` with `key_states` and `value_states` for layer `layer_idx`
if len(past_key_value.key_cache) <= self.layer_idx:
past_key_value.key_cache.append(key_states)
past_key_value.value_cache.append(value_states)
else:
cache_k = past_key_value.key_cache[self.layer_idx]
cache_v = past_key_value.value_cache[self.layer_idx]
if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
# allocate new
new_cache_k, new_cache_v = extend_kv_cache(bsz,
self.num_key_value_heads, # Support GQA
self.head_dim,
cache_k.size(2),
kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH,
dtype=cache_k.dtype,
device=device)
new_cache_k[:] = cache_k
new_cache_v[:] = cache_v
cache_k = new_cache_k
cache_v = new_cache_v
key_states, value_states = append_kv_cache(cache_k, cache_v, key_states, value_states)
# update past_key_value
past_key_value.key_cache[self.layer_idx] = key_states
past_key_value.value_cache[self.layer_idx] = value_states
# repeat k/v heads if n_kv_heads < n_heads
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
attn_weights = torch.matmul(
query_states,
key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
invalidInputError(
False,
f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)},"
f" but is {attn_weights.size()}"
)
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
invalidInputError(
False,
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)},"
f" but is {attention_mask.size()}"
)
attn_weights = attn_weights + attention_mask
# upcast attention to fp32
attn_weights = nn.functional.\
softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_output = torch.matmul(attn_weights, value_states)
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
invalidInputError(
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)},"
f" but is {attn_output.size()}"
)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
attn_output = self.o_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value

View file

@ -71,7 +71,8 @@ def rotate_every_two(x):
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, model_family):
if model_family in ["llama", "baichuan", "internlm", "aquila", "gpt_neox", "mistral"]:
if model_family in ["llama", "baichuan", "internlm", "aquila", "gpt_neox", "mistral",
"mixtral"]:
# The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
@ -98,7 +99,8 @@ def apply_rotary_pos_emb_no_cache_xpu(q, k, position_ids, model_family):
import linear_q4_0
q_embed = torch.empty(q.shape, dtype=q.dtype, device=q.device)
k_embed = torch.empty(k.shape, dtype=k.dtype, device=k.device)
if model_family in ["llama", "baichuan", "internlm", "aquila", "gpt_neox", "mistral"]:
if model_family in ["llama", "baichuan", "internlm", "aquila", "gpt_neox", "mistral",
"mixtral"]:
linear_q4_0.apply_rotary_embedding_half_qk(q, k, position_ids, q_embed, k_embed)
return q_embed, k_embed
else: