From ba0b93957934ebb22c75d6872bf1c09664e1d865 Mon Sep 17 00:00:00 2001 From: SONG Ge <38711238+sgwhat@users.noreply.github.com> Date: Fri, 22 Dec 2023 09:59:27 +0800 Subject: [PATCH] [LLM] Support transformers-v4.36.0 on mistral model (#9744) * add support transformers-v4.36.0 on mistral model * python/llm/src/bigdl/llm/transformers/models/mistral.py * make the redundant implementation as utils * fix code style * fix * fix style * update with utils enough_kv_room --- .../llm/src/bigdl/llm/transformers/convert.py | 41 ++-- .../bigdl/llm/transformers/models/mistral.py | 212 +++++++++++++++--- 2 files changed, 205 insertions(+), 48 deletions(-) diff --git a/python/llm/src/bigdl/llm/transformers/convert.py b/python/llm/src/bigdl/llm/transformers/convert.py index bc53d4ee..1c1f0b84 100644 --- a/python/llm/src/bigdl/llm/transformers/convert.py +++ b/python/llm/src/bigdl/llm/transformers/convert.py @@ -652,19 +652,34 @@ def _optimize_post(model, lightweight_bmm=False): module.MistralRMSNorm, llama_rms_norm_forward) else: - modeling_module_name = model.__class__.__module__ - module = importlib.import_module(modeling_module_name) - from bigdl.llm.transformers.models.mistral import mistral_attention_forward - convert_forward(model, - module.MistralAttention, - mistral_attention_forward - ) - convert_forward(model, - module.MistralRMSNorm, - llama_rms_norm_forward) - convert_forward(model, - module.MistralMLP, - llama_mlp_forward) + if version.parse(trans_version) >= version.parse("4.36.0"): + modeling_module_name = model.__class__.__module__ + module = importlib.import_module(modeling_module_name) + from bigdl.llm.transformers.models.mistral import mistral_attention_forward_4_36 + convert_forward(model, + module.MistralAttention, + mistral_attention_forward_4_36 + ) + convert_forward(model, + module.MistralRMSNorm, + llama_rms_norm_forward) + convert_forward(model, + module.MistralMLP, + llama_mlp_forward) + else: + modeling_module_name = model.__class__.__module__ + module = importlib.import_module(modeling_module_name) + from bigdl.llm.transformers.models.mistral import mistral_attention_forward + convert_forward(model, + module.MistralAttention, + mistral_attention_forward + ) + convert_forward(model, + module.MistralRMSNorm, + llama_rms_norm_forward) + convert_forward(model, + module.MistralMLP, + llama_mlp_forward) elif model.config.model_type == "Yi": modeling_module_name = model.__class__.__module__ module = importlib.import_module(modeling_module_name) diff --git a/python/llm/src/bigdl/llm/transformers/models/mistral.py b/python/llm/src/bigdl/llm/transformers/models/mistral.py index c0758dab..840d2dfb 100644 --- a/python/llm/src/bigdl/llm/transformers/models/mistral.py +++ b/python/llm/src/bigdl/llm/transformers/models/mistral.py @@ -44,7 +44,8 @@ from bigdl.llm.utils.common import invalidInputError from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb,\ apply_rotary_pos_emb_no_cache_xpu -from bigdl.llm.transformers.models.utils import is_enough_kv_cache_room_4_31 +from bigdl.llm.transformers.models.utils import is_enough_kv_cache_room_4_31,\ + is_enough_kv_cache_room_4_36 from bigdl.llm.transformers.low_bit_linear import SYM_INT4 KV_CACHE_ALLOC_BLOCK_LENGTH = 256 @@ -75,6 +76,46 @@ def use_decoding_fast_path(q_type, use_fuse_rope, enough_kv_room, bs): return q_type == SYM_INT4 and use_fuse_rope and enough_kv_room and bs == 1 +def compute_attn_outputs_weights(query_states, key_states, value_states, bsz, q_len, kv_seq_len, + num_heads, head_dim, hidden_size, attention_mask): + attn_weights = torch.matmul( + query_states, + key_states.transpose(2, 3)) / math.sqrt(head_dim) + + if attn_weights.size() != (bsz, num_heads, q_len, kv_seq_len): + invalidInputError( + False, + f"Attention weights should be of size {(bsz, num_heads, q_len, kv_seq_len)}," + f" but is {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + invalidInputError( + False, + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}," + f" but is {attention_mask.size()}" + ) + + attn_weights = attn_weights + attention_mask + + # upcast attention to fp32 + attn_weights = nn.functional.\ + softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, num_heads, q_len, head_dim): + invalidInputError( + f"`attn_output` should be of size {(bsz, num_heads, q_len, head_dim)}," + f" but is {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, hidden_size) + + return attn_output, attn_weights + + def mistral_attention_forward( self, hidden_states: torch.Tensor, @@ -177,40 +218,141 @@ def mistral_attention_forward( key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) - attn_weights = torch.matmul( - query_states, - key_states.transpose(2, 3)) / math.sqrt(self.head_dim) - - if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): - invalidInputError( - False, - f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}," - f" but is {attn_weights.size()}" - ) - - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): - invalidInputError( - False, - f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}," - f" but is {attention_mask.size()}" - ) - - attn_weights = attn_weights + attention_mask - - # upcast attention to fp32 - attn_weights = nn.functional.\ - softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_output = torch.matmul(attn_weights, value_states) - - if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): - invalidInputError( - f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}," - f" but is {attn_output.size()}" - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + attn_output, attn_weights = compute_attn_outputs_weights(query_states, key_states, value_states, + bsz, q_len, kv_seq_len, + self.num_heads, self.head_dim, + self.hidden_size, attention_mask) + + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +def mistral_attention_forward_4_36( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor]=None, + position_ids: Optional[torch.LongTensor]=None, + past_key_value: Optional[Tuple[torch.Tensor]]=None, + output_attentions: bool=False, + use_cache: bool=False, + padding_mask: Optional[torch.Tensor]=None, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + device = hidden_states.device + + use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids) + enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx) + decoding_fast_path = use_decoding_fast_path(self.q_proj.qtype, + use_fuse_rope, + enough_kv_room, + bsz * q_len) + + if decoding_fast_path: + hidden_states = hidden_states.view(1, -1) + + cache_k = past_key_value.key_cache[self.layer_idx] + cache_v = past_key_value.value_cache[self.layer_idx] + + kv_seq_len = cache_k.shape[-2] + + import linear_q4_0 + query_states, key_states, value_states = linear_q4_0.forward_qkv(hidden_states, + self.q_proj.weight, + self.k_proj.weight, + self.v_proj.weight, + position_ids, + cache_k, cache_v, + self.q_proj.weight.qtype, + kv_seq_len, + self.head_dim) + kv_seq_len += 1 + + # update past_key_value's seem_tokens and kv caches. + if self.layer_idx == 0: + past_key_value.seen_tokens = kv_seq_len + past_key_value.key_cache[self.layer_idx] = key_states + past_key_value.value_cache[self.layer_idx] = value_states + + else: + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, + self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, + self.num_key_value_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + + if past_key_value is not None: + if self.layer_idx is None: + invalidInputError(False, + "The cache structure has changed since version v4.36. " + f"If you are using {self.__class__.__name__} for " + "auto-regressive decodingwith k/v caching, please make sure " + "to initialize the attention class with a layer index.") + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + + if use_fuse_rope: + query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states, + key_states, + position_ids, + "mistral") + else: + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, + cos, sin, position_ids, "mistral") + + if past_key_value is not None: + # update the number of seen tokens + if self.layer_idx == 0: + past_key_value.seen_tokens += key_states.shape[-2] + + # reuse k, v, self_attention + # update `past_key_value` with `key_states` and `value_states` for layer `layer_idx` + if len(past_key_value.key_cache) <= self.layer_idx: + past_key_value.key_cache.append(key_states) + past_key_value.value_cache.append(value_states) + else: + cache_k = past_key_value.key_cache[self.layer_idx] + cache_v = past_key_value.value_cache[self.layer_idx] + + if not enough_kv_room: + # allocate new + new_c_k, new_c_v = extend_kv_cache(bsz, + self.num_key_value_heads, # Support GQA + self.head_dim, + cache_k.size(2), + kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH, + dtype=cache_k.dtype, + device=device) + + new_c_k[:] = cache_k + new_c_v[:] = cache_v + cache_k = new_c_k + cache_v = new_c_v + + key_states, value_states = append_kv_cache(cache_k, cache_v, + key_states, value_states) + + # update past_key_value + past_key_value.key_cache[self.layer_idx] = key_states + past_key_value.value_cache[self.layer_idx] = value_states + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_output, attn_weights = compute_attn_outputs_weights(query_states, key_states, value_states, + bsz, q_len, kv_seq_len, + self.num_heads, self.head_dim, + self.hidden_size, attention_mask) attn_output = self.o_proj(attn_output)