Optimize StableLM (#10619)
* Initial commit for stablelm optimizations * Small style fix * add dependency * Add mlp optimizations * Small fix * add attention forward * Remove quantize kv for now as head_dim=80 * Add merged qkv * fix lisence * Python style fix --------- Co-authored-by: qiuxin2012 <qiuxin2012cs@gmail.com>
This commit is contained in:
		
							parent
							
								
									27be448920
								
							
						
					
					
						commit
						fd384ddfb8
					
				
					 3 changed files with 261 additions and 2 deletions
				
			
		| 
						 | 
				
			
			@ -632,6 +632,10 @@ def _optimize_pre(model):
 | 
			
		|||
                module.rope_base = rope_base
 | 
			
		||||
                del module.c_attn
 | 
			
		||||
        model.apply(split_qkv_proj_func)
 | 
			
		||||
    if model.config.model_type == "stablelm":
 | 
			
		||||
        from ipex_llm.transformers.models.stablelm import merge_qkv
 | 
			
		||||
        model.apply(merge_qkv)
 | 
			
		||||
 | 
			
		||||
    return model
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -1336,5 +1340,16 @@ def _optimize_post(model, lightweight_bmm=False):
 | 
			
		|||
        convert_forward(model,
 | 
			
		||||
                        module.BertEncoder,
 | 
			
		||||
                        encoder_forward)
 | 
			
		||||
    elif model.config.model_type == 'stablelm':
 | 
			
		||||
        modeling_module_name = model.__class__.__module__
 | 
			
		||||
        module = importlib.import_module(modeling_module_name)
 | 
			
		||||
        from ipex_llm.transformers.models.stablelm import stablelm_attention_forward
 | 
			
		||||
        convert_forward(model,
 | 
			
		||||
                        module.StableLmAttention,
 | 
			
		||||
                        stablelm_attention_forward
 | 
			
		||||
                        )
 | 
			
		||||
        convert_forward(model,
 | 
			
		||||
                        module.StableLmMLP,
 | 
			
		||||
                        llama_mlp_forward)
 | 
			
		||||
 | 
			
		||||
    return model
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
							
								
								
									
										244
									
								
								python/llm/src/ipex_llm/transformers/models/stablelm.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										244
									
								
								python/llm/src/ipex_llm/transformers/models/stablelm.py
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1,244 @@
 | 
			
		|||
#
 | 
			
		||||
# Copyright 2016 The BigDL Authors.
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
# you may not use this file except in compliance with the License.
 | 
			
		||||
# You may obtain a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
#     http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
# Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
# distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
# See the License for the specific language governing permissions and
 | 
			
		||||
# limitations under the License.
 | 
			
		||||
#
 | 
			
		||||
# Some parts of this file is adapted from
 | 
			
		||||
# https://github.com/huggingface/transformers/blob/v4.38.0/src/transformers/models/stablelm/modeling_stablelm.py
 | 
			
		||||
# which is licensed under Apache License 2.0:
 | 
			
		||||
#
 | 
			
		||||
# Copyright 2024 EleutherAI and the HuggingFace Inc. team. All rights reserved.
 | 
			
		||||
#
 | 
			
		||||
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
 | 
			
		||||
# and OPT implementations in this library. It has been modified from its
 | 
			
		||||
# original forms to accommodate minor architectural differences compared
 | 
			
		||||
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
# you may not use this file except in compliance with the License.
 | 
			
		||||
# You may obtain a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
#     http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
# Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
# distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
# See the License for the specific language governing permissions and
 | 
			
		||||
# limitations under the License.
 | 
			
		||||
#
 | 
			
		||||
 | 
			
		||||
import math
 | 
			
		||||
from typing import Optional, Tuple
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
from torch import nn
 | 
			
		||||
import torch.nn.functional as F
 | 
			
		||||
from transformers.models.stablelm.modeling_stablelm import StableLmAttention
 | 
			
		||||
 | 
			
		||||
from ipex_llm.utils.common import invalidInputError
 | 
			
		||||
from ipex_llm.transformers.models.utils import extend_kv_cache, append_kv_cache
 | 
			
		||||
from ipex_llm.transformers.models.utils import apply_rotary_pos_emb, \
 | 
			
		||||
    apply_rotary_pos_emb_no_cache_xpu
 | 
			
		||||
from ipex_llm.transformers.models.utils import is_enough_kv_cache_room_4_36
 | 
			
		||||
from ipex_llm.transformers.models.utils import use_flash_attention, use_esimd_sdp
 | 
			
		||||
from ipex_llm.transformers.models.mistral import should_use_fuse_rope, repeat_kv
 | 
			
		||||
try:
 | 
			
		||||
    from transformers.cache_utils import Cache
 | 
			
		||||
except ImportError:
 | 
			
		||||
    Cache = Tuple[torch.Tensor]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
KV_CACHE_ALLOC_BLOCK_LENGTH = 256
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def merge_qkv(module: torch.nn.Module):
 | 
			
		||||
    if isinstance(module, StableLmAttention):
 | 
			
		||||
        new_weight = torch.cat([
 | 
			
		||||
            module.q_proj.weight.data,
 | 
			
		||||
            module.k_proj.weight.data,
 | 
			
		||||
            module.v_proj.weight.data,
 | 
			
		||||
        ], dim=0)
 | 
			
		||||
 | 
			
		||||
        qkv_proj = torch.nn.Linear(0, 0, bias=False)
 | 
			
		||||
        qkv_proj.weight = torch.nn.Parameter(new_weight, requires_grad=False)
 | 
			
		||||
        qkv_proj.in_features = new_weight.size(1)
 | 
			
		||||
        qkv_proj.out_features = new_weight.size(0)
 | 
			
		||||
        module.qkv_proj = qkv_proj
 | 
			
		||||
 | 
			
		||||
        del module.q_proj, module.k_proj, module.v_proj
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def stablelm_attention_forward(
 | 
			
		||||
        self,
 | 
			
		||||
        hidden_states: torch.Tensor,
 | 
			
		||||
        attention_mask: Optional[torch.Tensor]=None,
 | 
			
		||||
        position_ids: Optional[torch.LongTensor]=None,
 | 
			
		||||
        past_key_value: Optional[Cache]=None,
 | 
			
		||||
        output_attentions: bool=False,
 | 
			
		||||
        use_cache: bool=False,
 | 
			
		||||
        **kwargs
 | 
			
		||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
 | 
			
		||||
 | 
			
		||||
    bsz, q_len, _ = hidden_states.size()
 | 
			
		||||
    device = hidden_states.device
 | 
			
		||||
    # for flash attention
 | 
			
		||||
    original_dtype = hidden_states.dtype
 | 
			
		||||
 | 
			
		||||
    use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
 | 
			
		||||
    enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx)
 | 
			
		||||
 | 
			
		||||
    qkv = self.qkv_proj(hidden_states)
 | 
			
		||||
    qkv = qkv.view(bsz, q_len, self.num_heads + 2 * self.num_key_value_heads, self.head_dim)
 | 
			
		||||
    qkv = qkv.transpose(1, 2)
 | 
			
		||||
    query_states, key_states, value_states = qkv.split([self.num_heads,
 | 
			
		||||
                                                        self.num_heads,
 | 
			
		||||
                                                        self.num_heads], dim=1)
 | 
			
		||||
 | 
			
		||||
    kv_seq_len = key_states.shape[-2]
 | 
			
		||||
 | 
			
		||||
    if past_key_value is not None:
 | 
			
		||||
        if self.layer_idx is None:
 | 
			
		||||
            invalidInputError(False,
 | 
			
		||||
                              "The cache structure has changed since version v4.36. "
 | 
			
		||||
                              f"If you are using {self.__class__.__name__} for "
 | 
			
		||||
                              "auto-regressive decodingwith k/v caching, please make sure "
 | 
			
		||||
                              "to initialize the attention class with a layer index.")
 | 
			
		||||
        kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
 | 
			
		||||
 | 
			
		||||
    # Partial rotary embedding
 | 
			
		||||
    query_rot, query_pass = (
 | 
			
		||||
        query_states[..., : self.rotary_emb.dim],
 | 
			
		||||
        query_states[..., self.rotary_emb.dim:],
 | 
			
		||||
    )
 | 
			
		||||
    key_rot, key_pass = (
 | 
			
		||||
        key_states[..., : self.rotary_emb.dim],
 | 
			
		||||
        key_states[..., self.rotary_emb.dim:],
 | 
			
		||||
    )
 | 
			
		||||
    if use_fuse_rope:
 | 
			
		||||
        query_rot, key_rot = apply_rotary_pos_emb_no_cache_xpu(query_rot,
 | 
			
		||||
                                                               key_rot,
 | 
			
		||||
                                                               position_ids,
 | 
			
		||||
                                                               "stablelm")
 | 
			
		||||
    else:
 | 
			
		||||
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
 | 
			
		||||
        # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor]
 | 
			
		||||
        query_rot, key_rot = apply_rotary_pos_emb(query_rot,
 | 
			
		||||
                                                  key_rot,
 | 
			
		||||
                                                  cos,
 | 
			
		||||
                                                  sin,
 | 
			
		||||
                                                  position_ids,
 | 
			
		||||
                                                  "stablelm")
 | 
			
		||||
 | 
			
		||||
    # [batch_size, seq_length, num_heads, head_dim]
 | 
			
		||||
    query_states = torch.cat((query_rot, query_pass), dim=-1)
 | 
			
		||||
    key_states = torch.cat((key_rot, key_pass), dim=-1)
 | 
			
		||||
 | 
			
		||||
    if past_key_value is not None:
 | 
			
		||||
        # update the number of seen tokens
 | 
			
		||||
        if self.layer_idx == 0:
 | 
			
		||||
            past_key_value.seen_tokens += key_states.shape[-2]
 | 
			
		||||
 | 
			
		||||
        # reuse k, v, self_attention
 | 
			
		||||
        # update `past_key_value` with `key_states` and `value_states` for layer `layer_idx`
 | 
			
		||||
        if len(past_key_value.key_cache) <= self.layer_idx:
 | 
			
		||||
            past_key_value.key_cache.append(key_states)
 | 
			
		||||
            past_key_value.value_cache.append(value_states)
 | 
			
		||||
        else:
 | 
			
		||||
            cache_k = past_key_value.key_cache[self.layer_idx]
 | 
			
		||||
            cache_v = past_key_value.value_cache[self.layer_idx]
 | 
			
		||||
 | 
			
		||||
            if not enough_kv_room:
 | 
			
		||||
                # allocate new
 | 
			
		||||
                new_c_k, new_c_v = extend_kv_cache(bsz,
 | 
			
		||||
                                                   self.num_key_value_heads,  # Support GQA
 | 
			
		||||
                                                   self.head_dim,
 | 
			
		||||
                                                   cache_k.size(2),
 | 
			
		||||
                                                   kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH,
 | 
			
		||||
                                                   dtype=cache_k.dtype,
 | 
			
		||||
                                                   device=device)
 | 
			
		||||
 | 
			
		||||
                new_c_k[:] = cache_k
 | 
			
		||||
                new_c_v[:] = cache_v
 | 
			
		||||
                cache_k = new_c_k
 | 
			
		||||
                cache_v = new_c_v
 | 
			
		||||
 | 
			
		||||
            key_states, value_states = append_kv_cache(cache_k, cache_v,
 | 
			
		||||
                                                       key_states, value_states)
 | 
			
		||||
 | 
			
		||||
            # update past_key_value
 | 
			
		||||
            past_key_value.key_cache[self.layer_idx] = key_states
 | 
			
		||||
            past_key_value.value_cache[self.layer_idx] = value_states
 | 
			
		||||
 | 
			
		||||
    # repeat k/v heads if n_kv_heads < n_heads
 | 
			
		||||
    key_states = repeat_kv(key_states, self.num_key_value_groups)
 | 
			
		||||
    value_states = repeat_kv(value_states, self.num_key_value_groups)
 | 
			
		||||
 | 
			
		||||
    if not self.training and not hidden_states.requires_grad and \
 | 
			
		||||
            use_flash_attention(query_states, key_states, attention_mask):
 | 
			
		||||
        attn_output = F.scaled_dot_product_attention(query_states.to(device, dtype=torch.float16),
 | 
			
		||||
                                                     key_states.to(device, dtype=torch.float16),
 | 
			
		||||
                                                     value_states.to(device, dtype=torch.float16),
 | 
			
		||||
                                                     is_causal=True)
 | 
			
		||||
        attn_weights = None
 | 
			
		||||
    elif not self.training and not hidden_states.requires_grad and \
 | 
			
		||||
            use_esimd_sdp(q_len, key_states.shape[2], self.head_dim, query_states, attention_mask):
 | 
			
		||||
        import linear_fp16_esimd
 | 
			
		||||
        attn_output = linear_fp16_esimd.sdp_forward(query_states,
 | 
			
		||||
                                                    key_states,
 | 
			
		||||
                                                    value_states)
 | 
			
		||||
        attn_output = attn_output.view(query_states.shape)
 | 
			
		||||
        attn_weights = None
 | 
			
		||||
    else:
 | 
			
		||||
        attn_weights = torch.matmul(
 | 
			
		||||
            query_states,
 | 
			
		||||
            key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
 | 
			
		||||
 | 
			
		||||
        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
 | 
			
		||||
            invalidInputError(
 | 
			
		||||
                False,
 | 
			
		||||
                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)},"
 | 
			
		||||
                f" but is {attn_weights.size()}"
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        if attention_mask is not None:
 | 
			
		||||
            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
 | 
			
		||||
                invalidInputError(
 | 
			
		||||
                    False,
 | 
			
		||||
                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)},"
 | 
			
		||||
                    f" but is {attention_mask.size()}"
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
            attn_weights = attn_weights + attention_mask
 | 
			
		||||
 | 
			
		||||
        # upcast attention to fp32
 | 
			
		||||
        attn_weights = \
 | 
			
		||||
            nn.functional.softmax(attn_weights, dtype=torch.float32, dim=-1).to(query_states.dtype)
 | 
			
		||||
        attn_weights = self.attention_dropout(attn_weights)
 | 
			
		||||
 | 
			
		||||
        attn_output = torch.matmul(attn_weights, value_states)
 | 
			
		||||
 | 
			
		||||
        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
 | 
			
		||||
            invalidInputError(
 | 
			
		||||
                False,
 | 
			
		||||
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)},"
 | 
			
		||||
                f" but is {attn_output.size()}"
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
    attn_output = attn_output.transpose(1, 2).contiguous()
 | 
			
		||||
    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
 | 
			
		||||
    attn_output = self.o_proj(attn_output)
 | 
			
		||||
 | 
			
		||||
    if not output_attentions:
 | 
			
		||||
        attn_weights = None
 | 
			
		||||
 | 
			
		||||
    return attn_output.to(original_dtype), attn_weights, past_key_value
 | 
			
		||||
| 
						 | 
				
			
			@ -168,7 +168,7 @@ def rotate_every_two(x):
 | 
			
		|||
 | 
			
		||||
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, model_family):
 | 
			
		||||
    if model_family in ["llama", "baichuan", "internlm", "aquila", "gpt_neox", "mistral",
 | 
			
		||||
                        "mixtral", "qwen2", "yuan"]:
 | 
			
		||||
                        "mixtral", "qwen2", "yuan", "stablelm"]:
 | 
			
		||||
        # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
 | 
			
		||||
        cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
 | 
			
		||||
        sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
 | 
			
		||||
| 
						 | 
				
			
			@ -207,7 +207,7 @@ def apply_rotary_pos_emb_no_cache_xpu(q, k, position_ids, model_family, rope_the
 | 
			
		|||
    q_embed = torch.empty(q.shape, dtype=q.dtype, device=q.device)
 | 
			
		||||
    k_embed = torch.empty(k.shape, dtype=k.dtype, device=k.device)
 | 
			
		||||
    if model_family in ["llama", "baichuan", "internlm", "aquila", "gpt_neox", "mistral",
 | 
			
		||||
                        "mixtral"]:
 | 
			
		||||
                        "mixtral", "stablelm"]:
 | 
			
		||||
        linear_q4_0.apply_rotary_embedding_half_q_and_k(q, k, position_ids,
 | 
			
		||||
                                                        q_embed, k_embed, rope_theta)
 | 
			
		||||
        return q_embed, k_embed
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in a new issue