From 6c3e698bf1867f33443ba00de47d00e37f3f2161 Mon Sep 17 00:00:00 2001
From: Xin Qiu
Date: Thu, 21 Dec 2023 10:11:37 +0800
Subject: [PATCH] mistral decoding_fast_path and fused mlp (#9714)

* mistral decoding_fast_path and fused mlp

* meet code review
---
 .../llm/src/bigdl/llm/transformers/convert.py |   3 +
 .../bigdl/llm/transformers/models/mistral.py  | 142 +++++++++++-------
 2 files changed, 92 insertions(+), 53 deletions(-)

diff --git a/python/llm/src/bigdl/llm/transformers/convert.py b/python/llm/src/bigdl/llm/transformers/convert.py
index 62b74ae3..bc53d4ee 100644
--- a/python/llm/src/bigdl/llm/transformers/convert.py
+++ b/python/llm/src/bigdl/llm/transformers/convert.py
@@ -662,6 +662,9 @@ def _optimize_post(model, lightweight_bmm=False):
             convert_forward(model,
                             module.MistralRMSNorm,
                             llama_rms_norm_forward)
+            convert_forward(model,
+                            module.MistralMLP,
+                            llama_mlp_forward)
     elif model.config.model_type == "Yi":
         modeling_module_name = model.__class__.__module__
         module = importlib.import_module(modeling_module_name)
diff --git a/python/llm/src/bigdl/llm/transformers/models/mistral.py b/python/llm/src/bigdl/llm/transformers/models/mistral.py
index 9a9618bf..977f7817 100644
--- a/python/llm/src/bigdl/llm/transformers/models/mistral.py
+++ b/python/llm/src/bigdl/llm/transformers/models/mistral.py
@@ -44,7 +44,8 @@ from bigdl.llm.utils.common import invalidInputError
 from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
 from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb,\
     apply_rotary_pos_emb_no_cache_xpu
-
+from bigdl.llm.transformers.models.llama import is_enough_kv_cache_room
+from bigdl.llm.transformers.low_bit_linear import SYM_INT4

 KV_CACHE_ALLOC_BLOCK_LENGTH = 256

@@ -63,6 +64,17 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
     return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


+def should_use_fuse_rope(self, hidden_states, position_ids):
+    use_fuse_rope = hidden_states.device.type == "xpu"
+    use_fuse_rope = use_fuse_rope and not (self.training and hidden_states.requires_grad)
+    use_fuse_rope = use_fuse_rope and position_ids is not None
+    return use_fuse_rope
+
+
+def use_decoding_fast_path(q_type, use_fuse_rope, enough_kv_room, bs):
+    return q_type == SYM_INT4 and use_fuse_rope and enough_kv_room and bs == 1
+
+
 def mistral_attention_forward(
     self,
     hidden_states: torch.Tensor,
@@ -76,64 +88,88 @@ def mistral_attention_forward(
     bsz, q_len, _ = hidden_states.size()
     device = hidden_states.device
-    query_states = self.q_proj(hidden_states)
-    key_states = self.k_proj(hidden_states)
-    value_states = self.v_proj(hidden_states)
+    use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
+    enough_kv_room = is_enough_kv_cache_room(past_key_value)
+    decoding_fast_path = use_decoding_fast_path(self.q_proj.qtype,
+                                                use_fuse_rope,
+                                                enough_kv_room,
+                                                bsz * q_len)

-    query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-    key_states = key_states.view(bsz, q_len,
-                                 self.num_key_value_heads, self.head_dim).transpose(1, 2)
-    value_states = value_states.view(bsz, q_len,
-                                     self.num_key_value_heads, self.head_dim).transpose(1, 2)
-
-    kv_seq_len = key_states.shape[-2]
-    if past_key_value is not None:
-        kv_seq_len += past_key_value[0].shape[-2]
-
-    if query_states.device.type == "xpu" and not (self.training and query_states.requires_grad):
-        query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
-                                                                     key_states,
-                                                                     position_ids,
-                                                                     "mistral")
-    else:
-        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
-        query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
-                                                        cos, sin, position_ids, "mistral")
-
-    if past_key_value is not None:
-        # reuse k, v, self_attention
+    if decoding_fast_path:
+        hidden_states = hidden_states.view(1, -1)
+        kv_seq_len = past_key_value[0].shape[-2]
         cache_k = past_key_value[0]
         cache_v = past_key_value[1]
-        if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
-            # allocate new
-            new_cache_k, new_cache_v = extend_kv_cache(bsz,
-                                                       self.num_key_value_heads,  # Support GQA
-                                                       self.head_dim,
-                                                       cache_k.size(2),
-                                                       kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH,
-                                                       dtype=cache_k.dtype,
-                                                       device=device)
+        import linear_q4_0
+        query_states, key_states, value_states = linear_q4_0.forward_qkv(hidden_states,
+                                                                         self.q_proj.weight,
+                                                                         self.k_proj.weight,
+                                                                         self.v_proj.weight,
+                                                                         position_ids,
+                                                                         cache_k, cache_v,
+                                                                         self.q_proj.weight.qtype,
+                                                                         kv_seq_len,
+                                                                         self.head_dim)
+        kv_seq_len += 1
+    else:
+        query_states = self.q_proj(hidden_states)
+        key_states = self.k_proj(hidden_states)
+        value_states = self.v_proj(hidden_states)

-            new_cache_k[:] = cache_k
-            new_cache_v[:] = cache_v
-            cache_k = new_cache_k
-            cache_v = new_cache_v
+        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+        key_states = key_states.view(bsz, q_len,
+                                     self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        value_states = value_states.view(bsz, q_len,
+                                         self.num_key_value_heads, self.head_dim).transpose(1, 2)

-        key_states, value_states = append_kv_cache(cache_k, cache_v, key_states, value_states)
+        kv_seq_len = key_states.shape[-2]
+        if past_key_value is not None:
+            kv_seq_len += past_key_value[0].shape[-2]

-    elif use_cache:
-        max_cache_length = kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH
-        new_key_states, new_value_states = init_kv_cache(bsz,
-                                                         self.num_key_value_heads,
-                                                         self.head_dim,
-                                                         kv_seq_len,
-                                                         max_cache_length,
-                                                         dtype=key_states.dtype,
-                                                         device=device)
-        new_key_states[:] = key_states
-        new_value_states[:] = value_states
-        key_states = new_key_states
-        value_states = new_value_states
+        if use_fuse_rope:
+            query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
+                                                                         key_states,
+                                                                         position_ids,
+                                                                         "mistral")
+        else:
+            cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+            query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
+                                                            cos, sin, position_ids, "mistral")
+
+        if past_key_value is not None:
+            # reuse k, v, self_attention
+            cache_k = past_key_value[0]
+            cache_v = past_key_value[1]
+            if not enough_kv_room:
+                # allocate new
+                new_cache_k, new_cache_v = extend_kv_cache(bsz,
+                                                           self.num_key_value_heads,  # Support GQA
+                                                           self.head_dim,
+                                                           cache_k.size(2),
+                                                           kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH,
+                                                           dtype=cache_k.dtype,
+                                                           device=device)
+
+                new_cache_k[:] = cache_k
+                new_cache_v[:] = cache_v
+                cache_k = new_cache_k
+                cache_v = new_cache_v
+
+            key_states, value_states = append_kv_cache(cache_k, cache_v, key_states, value_states)
+
+        elif use_cache:
+            max_cache_length = kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH
+            new_key_states, new_value_states = init_kv_cache(bsz,
+                                                             self.num_key_value_heads,
+                                                             self.head_dim,
+                                                             kv_seq_len,
+                                                             max_cache_length,
+                                                             dtype=key_states.dtype,
+                                                             device=device)
+            new_key_states[:] = key_states
+            new_value_states[:] = value_states
+            key_states = new_key_states
+            value_states = new_value_states

     past_key_value = (key_states, value_states) if use_cache else None
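
Usage note (not part of the patch): a minimal sketch of how the decoding fast path added above would be exercised. The model id, prompt, and generation settings below are placeholder assumptions; the fast path in mistral_attention_forward only triggers when the projection weights are SYM_INT4 (4-bit), fused rope is usable (XPU device, not training, position_ids provided), the pre-allocated KV cache still has room, and the step processes a single token (bsz * q_len == 1), i.e. batch-size-1 decoding after the prefill step.

    # Sketch: load a Mistral model with BigDL-LLM 4-bit optimizations on an Intel GPU
    # so that decoding_fast_path can trigger inside generate(). Model id is a placeholder.
    import torch
    import intel_extension_for_pytorch as ipex  # registers the "xpu" device
    from transformers import AutoTokenizer
    from bigdl.llm.transformers import AutoModelForCausalLM

    model_path = "mistralai/Mistral-7B-v0.1"  # placeholder model id
    model = AutoModelForCausalLM.from_pretrained(model_path,
                                                 load_in_4bit=True,  # weights quantized to SYM_INT4
                                                 trust_remote_code=True)
    model = model.to("xpu")  # fused rope / fast path require an XPU device
    tokenizer = AutoTokenizer.from_pretrained(model_path)

    inputs = tokenizer("What is AI?", return_tensors="pt").to("xpu")
    with torch.inference_mode():
        # batch size 1: each decoding step after prefill has bsz * q_len == 1
        output = model.generate(**inputs, max_new_tokens=32)
    print(tokenizer.decode(output[0], skip_special_tokens=True))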