mistral decoding_fast_path and fused mlp (#9714)

* mistral decoding_fast_path and fused mlp

* meet code review
This commit is contained in:
Xin Qiu 2023-12-21 10:11:37 +08:00 committed by GitHub
parent d157f623b6
commit 6c3e698bf1
2 changed files with 92 additions and 53 deletions

View file

@ -662,6 +662,9 @@ def _optimize_post(model, lightweight_bmm=False):
convert_forward(model,
module.MistralRMSNorm,
llama_rms_norm_forward)
convert_forward(model,
module.MistralMLP,
llama_mlp_forward)
elif model.config.model_type == "Yi":
modeling_module_name = model.__class__.__module__
module = importlib.import_module(modeling_module_name)

View file

@ -44,7 +44,8 @@ from bigdl.llm.utils.common import invalidInputError
from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb,\
apply_rotary_pos_emb_no_cache_xpu
from bigdl.llm.transformers.models.llama import is_enough_kv_cache_room
from bigdl.llm.transformers.low_bit_linear import SYM_INT4
KV_CACHE_ALLOC_BLOCK_LENGTH = 256
@ -63,6 +64,17 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
def should_use_fuse_rope(self, hidden_states, position_ids):
use_fuse_rope = hidden_states.device.type == "xpu"
use_fuse_rope = use_fuse_rope and not (self.training and hidden_states.requires_grad)
use_fuse_rope = use_fuse_rope and position_ids is not None
return use_fuse_rope
def use_decoding_fast_path(q_type, use_fuse_rope, enough_kv_room, bs):
return q_type == SYM_INT4 and use_fuse_rope and enough_kv_room and bs == 1
def mistral_attention_forward(
self,
hidden_states: torch.Tensor,
@ -76,6 +88,30 @@ def mistral_attention_forward(
bsz, q_len, _ = hidden_states.size()
device = hidden_states.device
use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
enough_kv_room = is_enough_kv_cache_room(past_key_value)
decoding_fast_path = use_decoding_fast_path(self.q_proj.qtype,
use_fuse_rope,
enough_kv_room,
bsz * q_len)
if decoding_fast_path:
hidden_states = hidden_states.view(1, -1)
kv_seq_len = past_key_value[0].shape[-2]
cache_k = past_key_value[0]
cache_v = past_key_value[1]
import linear_q4_0
query_states, key_states, value_states = linear_q4_0.forward_qkv(hidden_states,
self.q_proj.weight,
self.k_proj.weight,
self.v_proj.weight,
position_ids,
cache_k, cache_v,
self.q_proj.weight.qtype,
kv_seq_len,
self.head_dim)
kv_seq_len += 1
else:
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
@ -90,7 +126,7 @@ def mistral_attention_forward(
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
if query_states.device.type == "xpu" and not (self.training and query_states.requires_grad):
if use_fuse_rope:
query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
key_states,
position_ids,
@ -104,7 +140,7 @@ def mistral_attention_forward(
# reuse k, v, self_attention
cache_k = past_key_value[0]
cache_v = past_key_value[1]
if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
if not enough_kv_room:
# allocate new
new_cache_k, new_cache_v = extend_kv_cache(bsz,
self.num_key_value_heads, # Support GQA