LLM: add fuse optimization for Mistral. (#9184)
* add fuse optimization for mistral. * fix. * fix * fix style. * fix. * fix error. * fix style. * fix style.
This commit is contained in:
		
							parent
							
								
									49e1381c7f
								
							
						
					
					
						commit
						5ca8a851e9
					
				
					 3 changed files with 160 additions and 1 deletions
				
			
		| 
						 | 
				
			
			@ -348,4 +348,15 @@ def optimize(model):
 | 
			
		|||
        convert_forward(model,
 | 
			
		||||
                        module.AquilaRMSNorm,
 | 
			
		||||
                        llama_rms_norm_forward)
 | 
			
		||||
    elif model.config.model_type == "mistral":
 | 
			
		||||
        modeling_module_name = model.__class__.__module__
 | 
			
		||||
        module = importlib.import_module(modeling_module_name)
 | 
			
		||||
        from bigdl.llm.transformers.models.mistral import mistral_attention_forward
 | 
			
		||||
        convert_forward(model,
 | 
			
		||||
                        module.MistralAttention,
 | 
			
		||||
                        mistral_attention_forward
 | 
			
		||||
                        )
 | 
			
		||||
        convert_forward(model,
 | 
			
		||||
                        module.MistralRMSNorm,
 | 
			
		||||
                        llama_rms_norm_forward)
 | 
			
		||||
    return model
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
							
								
								
									
										148
									
								
								python/llm/src/bigdl/llm/transformers/models/mistral.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										148
									
								
								python/llm/src/bigdl/llm/transformers/models/mistral.py
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1,148 @@
 | 
			
		|||
#
 | 
			
		||||
# Copyright 2016 The BigDL Authors.
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
# you may not use this file except in compliance with the License.
 | 
			
		||||
# You may obtain a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
#     http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
# Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
# distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
# See the License for the specific language governing permissions and
 | 
			
		||||
# limitations under the License.
 | 
			
		||||
#
 | 
			
		||||
# Some parts of this file is adapted from
 | 
			
		||||
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/mistral/modeling_mistral.py
 | 
			
		||||
#
 | 
			
		||||
# Copyright 2023 Mistral AI and the HuggingFace Inc. team. All rights reserved.
 | 
			
		||||
#
 | 
			
		||||
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
 | 
			
		||||
# and OPT implementations in this library. It has been modified from its
 | 
			
		||||
# original forms to accommodate minor architectural differences compared
 | 
			
		||||
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
# you may not use this file except in compliance with the License.
 | 
			
		||||
# You may obtain a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
#     http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
# Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
# distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
# See the License for the specific language governing permissions and
 | 
			
		||||
# limitations under the License.
 | 
			
		||||
""" PyTorch Mistral model."""
 | 
			
		||||
import math
 | 
			
		||||
from typing import Optional, Tuple
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
from torch import nn
 | 
			
		||||
from bigdl.llm.utils.common import invalidInputError
 | 
			
		||||
from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb,\
 | 
			
		||||
    apply_rotary_pos_emb_no_cache_xpu
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
 | 
			
		||||
    """
 | 
			
		||||
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep).
 | 
			
		||||
    The hidden states go from (batch, num_key_value_heads, seqlen, head_dim)
 | 
			
		||||
    to (batch, num_attention_heads, seqlen, head_dim)
 | 
			
		||||
    """
 | 
			
		||||
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
 | 
			
		||||
    if n_rep == 1:
 | 
			
		||||
        return hidden_states
 | 
			
		||||
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads,
 | 
			
		||||
                                                           n_rep, slen, head_dim)
 | 
			
		||||
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def mistral_attention_forward(
 | 
			
		||||
    self,
 | 
			
		||||
    hidden_states: torch.Tensor,
 | 
			
		||||
    attention_mask: Optional[torch.Tensor]=None,
 | 
			
		||||
    position_ids: Optional[torch.LongTensor]=None,
 | 
			
		||||
    past_key_value: Optional[Tuple[torch.Tensor]]=None,
 | 
			
		||||
    output_attentions: bool=False,
 | 
			
		||||
    use_cache: bool=False,
 | 
			
		||||
    padding_mask: Optional[torch.Tensor]=None,
 | 
			
		||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
 | 
			
		||||
    bsz, q_len, _ = hidden_states.size()
 | 
			
		||||
 | 
			
		||||
    query_states = self.q_proj(hidden_states)
 | 
			
		||||
    key_states = self.k_proj(hidden_states)
 | 
			
		||||
    value_states = self.v_proj(hidden_states)
 | 
			
		||||
 | 
			
		||||
    query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
 | 
			
		||||
    key_states = key_states.view(bsz, q_len,
 | 
			
		||||
                                 self.num_key_value_heads, self.head_dim).transpose(1, 2)
 | 
			
		||||
    value_states = value_states.view(bsz, q_len,
 | 
			
		||||
                                     self.num_key_value_heads, self.head_dim).transpose(1, 2)
 | 
			
		||||
 | 
			
		||||
    kv_seq_len = key_states.shape[-2]
 | 
			
		||||
    if past_key_value is not None:
 | 
			
		||||
        kv_seq_len += past_key_value[0].shape[-2]
 | 
			
		||||
    if query_states.device.type == "xpu" and not (self.training and query_states.requires_grad):
 | 
			
		||||
        query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
 | 
			
		||||
                                                                     key_states,
 | 
			
		||||
                                                                     position_ids,
 | 
			
		||||
                                                                     "mistral")
 | 
			
		||||
    else:
 | 
			
		||||
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
 | 
			
		||||
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
 | 
			
		||||
                                                        cos, sin, position_ids, "mistral")
 | 
			
		||||
 | 
			
		||||
    if past_key_value is not None:
 | 
			
		||||
        # reuse k, v, self_attention
 | 
			
		||||
        key_states = torch.cat([past_key_value[0], key_states], dim=2)
 | 
			
		||||
        value_states = torch.cat([past_key_value[1], value_states], dim=2)
 | 
			
		||||
 | 
			
		||||
    past_key_value = (key_states, value_states) if use_cache else None
 | 
			
		||||
 | 
			
		||||
    # repeat k/v heads if n_kv_heads < n_heads
 | 
			
		||||
    key_states = repeat_kv(key_states, self.num_key_value_groups)
 | 
			
		||||
    value_states = repeat_kv(value_states, self.num_key_value_groups)
 | 
			
		||||
 | 
			
		||||
    attn_weights = torch.matmul(
 | 
			
		||||
        query_states,
 | 
			
		||||
        key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
 | 
			
		||||
 | 
			
		||||
    if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
 | 
			
		||||
        invalidInputError(
 | 
			
		||||
            False,
 | 
			
		||||
            f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)},"
 | 
			
		||||
            f" but is {attn_weights.size()}"
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    if attention_mask is not None:
 | 
			
		||||
        if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
 | 
			
		||||
            invalidInputError(
 | 
			
		||||
                False,
 | 
			
		||||
                f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)},"
 | 
			
		||||
                f" but is {attention_mask.size()}"
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        attn_weights = attn_weights + attention_mask
 | 
			
		||||
 | 
			
		||||
    # upcast attention to fp32
 | 
			
		||||
    attn_weights = nn.functional.\
 | 
			
		||||
        softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
 | 
			
		||||
    attn_output = torch.matmul(attn_weights, value_states)
 | 
			
		||||
 | 
			
		||||
    if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
 | 
			
		||||
        invalidInputError(
 | 
			
		||||
            f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)},"
 | 
			
		||||
            f" but is {attn_output.size()}"
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    attn_output = attn_output.transpose(1, 2).contiguous()
 | 
			
		||||
    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
 | 
			
		||||
 | 
			
		||||
    attn_output = self.o_proj(attn_output)
 | 
			
		||||
 | 
			
		||||
    if not output_attentions:
 | 
			
		||||
        attn_weights = None
 | 
			
		||||
 | 
			
		||||
    return attn_output, attn_weights, past_key_value
 | 
			
		||||
| 
						 | 
				
			
			@ -98,7 +98,7 @@ def apply_rotary_pos_emb_no_cache_xpu(q, k, position_ids, model_family):
 | 
			
		|||
    import linear_q4_0
 | 
			
		||||
    q_embed = torch.empty(q.shape, dtype=q.dtype, device=q.device)
 | 
			
		||||
    k_embed = torch.empty(k.shape, dtype=k.dtype, device=k.device)
 | 
			
		||||
    if model_family in ["llama", "baichuan", "internlm", "aquila", "gpt_neox"]:
 | 
			
		||||
    if model_family in ["llama", "baichuan", "internlm", "aquila", "gpt_neox", "mistral"]:
 | 
			
		||||
        linear_q4_0.apply_rotary_embedding_half_qk(q, k, position_ids, q_embed, k_embed)
 | 
			
		||||
        return q_embed, k_embed
 | 
			
		||||
    else:
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in a new issue