optimize attention part of moonlight-14B-A3B (#12886)
This commit is contained in:
		
							parent
							
								
									dd30d12cb6
								
							
						
					
					
						commit
						ab3fc66eb7
					
				
					 4 changed files with 335 additions and 4 deletions
				
			
		| 
						 | 
					@ -1070,7 +1070,9 @@ def _optimize_pre(model, qtype=None):
 | 
				
			||||||
        model.apply(pre_register_inv_freq)
 | 
					        model.apply(pre_register_inv_freq)
 | 
				
			||||||
    elif model.config.model_type == "multi_modality":
 | 
					    elif model.config.model_type == "multi_modality":
 | 
				
			||||||
        _optimize_pre(model.language_model)
 | 
					        _optimize_pre(model.language_model)
 | 
				
			||||||
 | 
					    elif model.config.model_type == "deepseek_v3" and model.config.hidden_size == 2048:
 | 
				
			||||||
 | 
					        from ipex_llm.transformers.models.deepseek import padding_mla_v_hd
 | 
				
			||||||
 | 
					        model.apply(padding_mla_v_hd)
 | 
				
			||||||
    return model
 | 
					    return model
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -2023,6 +2025,15 @@ def _optimize_post(model):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # llm
 | 
					        # llm
 | 
				
			||||||
        _optimize_post(model.language_model)
 | 
					        _optimize_post(model.language_model)
 | 
				
			||||||
 | 
					    elif model.config.model_type == "deepseek_v3" and model.config.hidden_size == 2048:
 | 
				
			||||||
 | 
					        modeling_module_name = model.__class__.__module__
 | 
				
			||||||
 | 
					        module = importlib.import_module(modeling_module_name)
 | 
				
			||||||
 | 
					        from ipex_llm.transformers.models.common import rms_norm_forward
 | 
				
			||||||
 | 
					        from ipex_llm.transformers.models.deepseek import deepseek_model_forward
 | 
				
			||||||
 | 
					        from ipex_llm.transformers.models.deepseek import deepseek_attention_forward
 | 
				
			||||||
 | 
					        convert_forward(model, module.DeepseekV3RMSNorm, rms_norm_forward)
 | 
				
			||||||
 | 
					        convert_forward(model, module.DeepseekV3Model, deepseek_model_forward)
 | 
				
			||||||
 | 
					        convert_forward(model, module.DeepseekV3Attention, deepseek_attention_forward)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    return model
 | 
					    return model
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -95,6 +95,33 @@ def padding_attention_hd_base(module: torch.nn.Module, attention_class,
 | 
				
			||||||
        module.old_head_dim = old_head_dim
 | 
					        module.old_head_dim = old_head_dim
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def padding_mla_v_hd_base(module: torch.nn.Module, attention_class):
 | 
				
			||||||
 | 
					    if (
 | 
				
			||||||
 | 
					        isinstance(attention_class, str) and module.__class__.__name__ == attention_class
 | 
				
			||||||
 | 
					        or not isinstance(attention_class, str) and isinstance(module, attention_class)
 | 
				
			||||||
 | 
					    ):
 | 
				
			||||||
 | 
					        k_head_dim = module.q_head_dim
 | 
				
			||||||
 | 
					        v_head_dim = module.v_head_dim
 | 
				
			||||||
 | 
					        if v_head_dim < k_head_dim:
 | 
				
			||||||
 | 
					            kv_b_proj = module.kv_b_proj
 | 
				
			||||||
 | 
					            w = kv_b_proj.weight.data.view(module.num_heads,
 | 
				
			||||||
 | 
					                                           module.qk_nope_head_dim + module.v_head_dim,
 | 
				
			||||||
 | 
					                                           module.kv_lora_rank)
 | 
				
			||||||
 | 
					            k_w, v_w = w.split([module.qk_nope_head_dim, module.v_head_dim], dim=1)
 | 
				
			||||||
 | 
					            new_v_w = torch.zeros([module.num_heads, k_head_dim, module.kv_lora_rank],
 | 
				
			||||||
 | 
					                                  dtype=v_w.dtype, device=v_w.device)
 | 
				
			||||||
 | 
					            new_v_w[:, :v_head_dim, :] = v_w
 | 
				
			||||||
 | 
					            new_w = torch.cat([k_w, new_v_w], dim=1).view(-1, module.kv_lora_rank)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            new_kv_b_proj = torch.nn.Linear(0, 0, bias=False,
 | 
				
			||||||
 | 
					                                            dtype=new_w.dtype, device=new_w.device)
 | 
				
			||||||
 | 
					            new_kv_b_proj.in_features = new_w.size(1)
 | 
				
			||||||
 | 
					            new_kv_b_proj.out_features = new_w.size(0)
 | 
				
			||||||
 | 
					            new_kv_b_proj.weight = torch.nn.Parameter(new_w, False)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            module.kv_b_proj = new_kv_b_proj
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def padding_states_hd(states: torch.Tensor, old_head_dim: int, new_head_dim: int):
 | 
					def padding_states_hd(states: torch.Tensor, old_head_dim: int, new_head_dim: int):
 | 
				
			||||||
    bsz, num_heads, seq_len, head_dim = states.size()
 | 
					    bsz, num_heads, seq_len, head_dim = states.size()
 | 
				
			||||||
    if head_dim == old_head_dim and old_head_dim < new_head_dim:
 | 
					    if head_dim == old_head_dim and old_head_dim < new_head_dim:
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
							
								
								
									
										271
									
								
								python/llm/src/ipex_llm/transformers/models/deepseek.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										271
									
								
								python/llm/src/ipex_llm/transformers/models/deepseek.py
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
					@ -0,0 +1,271 @@
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# Copyright 2016 The BigDL Authors.
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
 | 
					# you may not use this file except in compliance with the License.
 | 
				
			||||||
 | 
					# You may obtain a copy of the License at
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					#     http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# Unless required by applicable law or agreed to in writing, software
 | 
				
			||||||
 | 
					# distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||||
 | 
					# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||||
 | 
					# See the License for the specific language governing permissions and
 | 
				
			||||||
 | 
					# limitations under the License.
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# Some parts of this file is adapted from
 | 
				
			||||||
 | 
					# https://huggingface.co/deepseek-ai/DeepSeek-R1/blob/main/modeling_deepseek.py
 | 
				
			||||||
 | 
					# which is licensed under Apache License 2.0:
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# https://github.com/OpenBMB/MiniCPM/blob/main/LICENSE
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import torch
 | 
				
			||||||
 | 
					import warnings
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from typing import Optional, Tuple, List, Union
 | 
				
			||||||
 | 
					from transformers.cache_utils import Cache
 | 
				
			||||||
 | 
					from transformers.modeling_outputs import BaseModelOutputWithPast
 | 
				
			||||||
 | 
					from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from ipex_llm.utils.common.log4Error import invalidInputError
 | 
				
			||||||
 | 
					from ipex_llm.transformers.kv import DynamicNormalCache
 | 
				
			||||||
 | 
					from ipex_llm.transformers.models.common import padding_mla_v_hd_base
 | 
				
			||||||
 | 
					from ipex_llm.transformers.models.common import scaled_dot_product_attention
 | 
				
			||||||
 | 
					from ipex_llm.transformers.models.utils import rotate_half
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def padding_mla_v_hd(module: torch.nn.Module):
 | 
				
			||||||
 | 
					    padding_mla_v_hd_base(module, "DeepseekV3Attention")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def deepseek_model_forward(
 | 
				
			||||||
 | 
					    self,
 | 
				
			||||||
 | 
					    input_ids: torch.LongTensor = None,
 | 
				
			||||||
 | 
					    attention_mask: Optional[torch.Tensor] = None,
 | 
				
			||||||
 | 
					    position_ids: Optional[torch.LongTensor] = None,
 | 
				
			||||||
 | 
					    past_key_values: Optional[List[torch.FloatTensor]] = None,
 | 
				
			||||||
 | 
					    inputs_embeds: Optional[torch.FloatTensor] = None,
 | 
				
			||||||
 | 
					    use_cache: Optional[bool] = None,
 | 
				
			||||||
 | 
					    output_attentions: Optional[bool] = None,
 | 
				
			||||||
 | 
					    output_hidden_states: Optional[bool] = None,
 | 
				
			||||||
 | 
					    return_dict: Optional[bool] = None,
 | 
				
			||||||
 | 
					) -> Union[Tuple, BaseModelOutputWithPast]:
 | 
				
			||||||
 | 
					    output_attentions = (
 | 
				
			||||||
 | 
					        output_attentions if output_attentions is not None
 | 
				
			||||||
 | 
					        else self.config.output_attentions
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    output_hidden_states = (
 | 
				
			||||||
 | 
					        output_hidden_states if output_hidden_states is not None
 | 
				
			||||||
 | 
					        else self.config.output_hidden_states
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    use_cache = use_cache if use_cache is not None else self.config.use_cache
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    return_dict = (
 | 
				
			||||||
 | 
					        return_dict if return_dict is not None else self.config.use_return_dict
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # retrieve input_ids and inputs_embeds
 | 
				
			||||||
 | 
					    invalidInputError((input_ids is None) ^ (inputs_embeds is None),
 | 
				
			||||||
 | 
					                      "You cannot specify both input_ids and inputs_embeds at the same time, "
 | 
				
			||||||
 | 
					                      "and must specify either one")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if inputs_embeds is None:
 | 
				
			||||||
 | 
					        inputs_embeds = self.embed_tokens(input_ids)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    batch_size, seq_length = inputs_embeds.shape[:2]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # IPEX-LLM OPT start: kv cache
 | 
				
			||||||
 | 
					    past_key_values_length = 0
 | 
				
			||||||
 | 
					    use_cache = True if inputs_embeds.device.type == "xpu" else use_cache
 | 
				
			||||||
 | 
					    if use_cache:
 | 
				
			||||||
 | 
					        if not isinstance(past_key_values, DynamicNormalCache):
 | 
				
			||||||
 | 
					            past_key_values = DynamicNormalCache.from_legacy_cache(past_key_values)
 | 
				
			||||||
 | 
					        past_key_values_length = past_key_values.get_usable_length(seq_length)
 | 
				
			||||||
 | 
					    # IPEX-LLM OPT end: kv cache
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if position_ids is None:
 | 
				
			||||||
 | 
					        position_ids = torch.arange(
 | 
				
			||||||
 | 
					            past_key_values_length,
 | 
				
			||||||
 | 
					            seq_length + past_key_values_length,
 | 
				
			||||||
 | 
					            dtype=torch.long,
 | 
				
			||||||
 | 
					            device=inputs_embeds.device,
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        position_ids = position_ids.unsqueeze(0)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # IPEX-LLM OPT start: fuse rope
 | 
				
			||||||
 | 
					    if inputs_embeds.device.type == "xpu" and position_ids is not None:
 | 
				
			||||||
 | 
					        cos, sin = self.layers[0].self_attn.rotary_emb(inputs_embeds,
 | 
				
			||||||
 | 
					                                                       seq_length + past_key_values_length)
 | 
				
			||||||
 | 
					        cos = cos[position_ids[0]].contiguous()
 | 
				
			||||||
 | 
					        sin = sin[position_ids[0]].contiguous()
 | 
				
			||||||
 | 
					        position_embeddings = (cos, sin)
 | 
				
			||||||
 | 
					    else:
 | 
				
			||||||
 | 
					        position_embeddings = None
 | 
				
			||||||
 | 
					    # IPEX-LLM OPT end: fuse rope
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # 4d mask is passed through the layers
 | 
				
			||||||
 | 
					    attention_mask = _prepare_4d_causal_attention_mask(
 | 
				
			||||||
 | 
					        attention_mask,
 | 
				
			||||||
 | 
					        (batch_size, seq_length),
 | 
				
			||||||
 | 
					        inputs_embeds,
 | 
				
			||||||
 | 
					        past_key_values_length,
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # embed positions
 | 
				
			||||||
 | 
					    hidden_states = inputs_embeds
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # decoder layers
 | 
				
			||||||
 | 
					    all_hidden_states = () if output_hidden_states else None
 | 
				
			||||||
 | 
					    all_self_attns = () if output_attentions else None
 | 
				
			||||||
 | 
					    next_decoder_cache = None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    for decoder_layer in self.layers:
 | 
				
			||||||
 | 
					        if output_hidden_states:
 | 
				
			||||||
 | 
					            all_hidden_states += (hidden_states,)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        layer_outputs = decoder_layer(
 | 
				
			||||||
 | 
					            hidden_states,
 | 
				
			||||||
 | 
					            attention_mask=attention_mask,
 | 
				
			||||||
 | 
					            position_ids=position_ids,
 | 
				
			||||||
 | 
					            past_key_value=past_key_values,
 | 
				
			||||||
 | 
					            output_attentions=output_attentions,
 | 
				
			||||||
 | 
					            use_cache=use_cache,
 | 
				
			||||||
 | 
					            position_embeddings=position_embeddings,
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        hidden_states = layer_outputs[0]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if use_cache:
 | 
				
			||||||
 | 
					            next_decoder_cache = layer_outputs[2 if output_attentions else 1]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if output_attentions:
 | 
				
			||||||
 | 
					            all_self_attns += (layer_outputs[1],)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    hidden_states = self.norm(hidden_states)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # add hidden states from the last decoder layer
 | 
				
			||||||
 | 
					    if output_hidden_states:
 | 
				
			||||||
 | 
					        all_hidden_states += (hidden_states,)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    next_cache = next_decoder_cache
 | 
				
			||||||
 | 
					    if not return_dict:
 | 
				
			||||||
 | 
					        return tuple(
 | 
				
			||||||
 | 
					            v
 | 
				
			||||||
 | 
					            for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
 | 
				
			||||||
 | 
					            if v is not None
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					    return BaseModelOutputWithPast(
 | 
				
			||||||
 | 
					        last_hidden_state=hidden_states,
 | 
				
			||||||
 | 
					        past_key_values=next_cache,
 | 
				
			||||||
 | 
					        hidden_states=all_hidden_states,
 | 
				
			||||||
 | 
					        attentions=all_self_attns,
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
 | 
				
			||||||
 | 
					    cos = cos[position_ids].unsqueeze(unsqueeze_dim)
 | 
				
			||||||
 | 
					    sin = sin[position_ids].unsqueeze(unsqueeze_dim)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    b, h, s, d = q.shape
 | 
				
			||||||
 | 
					    q = q.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    b, h, s, d = k.shape
 | 
				
			||||||
 | 
					    k = k.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    q_embed = (q * cos) + (rotate_half(q) * sin)
 | 
				
			||||||
 | 
					    k_embed = (k * cos) + (rotate_half(k) * sin)
 | 
				
			||||||
 | 
					    return q_embed, k_embed
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def deepseek_attention_forward(
 | 
				
			||||||
 | 
					    self,
 | 
				
			||||||
 | 
					    hidden_states: torch.Tensor,
 | 
				
			||||||
 | 
					    attention_mask: Optional[torch.Tensor] = None,
 | 
				
			||||||
 | 
					    position_ids: Optional[torch.LongTensor] = None,
 | 
				
			||||||
 | 
					    past_key_value: Optional[Cache] = None,
 | 
				
			||||||
 | 
					    output_attentions: bool = False,
 | 
				
			||||||
 | 
					    use_cache: bool = False,
 | 
				
			||||||
 | 
					    **kwargs,
 | 
				
			||||||
 | 
					) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
 | 
				
			||||||
 | 
					    if "padding_mask" in kwargs:
 | 
				
			||||||
 | 
					        warnings.warn(
 | 
				
			||||||
 | 
					            "Passing `padding_mask` is deprecated and will be removed in v4.37. "
 | 
				
			||||||
 | 
					            "Please make sure use `attention_mask` instead.`"
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    bsz, q_len, _ = hidden_states.size()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if self.q_lora_rank is None:
 | 
				
			||||||
 | 
					        q = self.q_proj(hidden_states)
 | 
				
			||||||
 | 
					    else:
 | 
				
			||||||
 | 
					        q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
 | 
				
			||||||
 | 
					    q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
 | 
				
			||||||
 | 
					    compressed_kv, k_pe = torch.split(
 | 
				
			||||||
 | 
					        compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2)
 | 
				
			||||||
 | 
					    kv = (
 | 
				
			||||||
 | 
					        self.kv_b_proj(self.kv_a_layernorm(compressed_kv))
 | 
				
			||||||
 | 
					        .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.q_head_dim)
 | 
				
			||||||
 | 
					        .transpose(1, 2)
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    k_nope, value_states = torch.split(
 | 
				
			||||||
 | 
					        kv, [self.qk_nope_head_dim, self.q_head_dim], dim=-1
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    kv_seq_len = value_states.shape[-2]
 | 
				
			||||||
 | 
					    if past_key_value is not None:
 | 
				
			||||||
 | 
					        kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    position_embeddings = kwargs.get("position_embeddings", None)
 | 
				
			||||||
 | 
					    if position_embeddings is not None:
 | 
				
			||||||
 | 
					        query_states = q
 | 
				
			||||||
 | 
					        key_states = torch.cat(
 | 
				
			||||||
 | 
					            [k_nope, k_pe.expand([-1, self.num_heads, -1, -1])],
 | 
				
			||||||
 | 
					            dim=-1
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        import xe_addons
 | 
				
			||||||
 | 
					        cos, sin = position_embeddings
 | 
				
			||||||
 | 
					        xe_addons.rotary_two_with_cache_inplaced(query_states[:, :, :, self.qk_nope_head_dim:],
 | 
				
			||||||
 | 
					                                                 key_states[:, :, :, self.qk_nope_head_dim:],
 | 
				
			||||||
 | 
					                                                 cos, sin, True)
 | 
				
			||||||
 | 
					    else:
 | 
				
			||||||
 | 
					        q_nope, q_pe = torch.split(
 | 
				
			||||||
 | 
					            q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
 | 
				
			||||||
 | 
					        q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
 | 
				
			||||||
 | 
					        query_states[:, :, :, : self.qk_nope_head_dim] = q_nope
 | 
				
			||||||
 | 
					        query_states[:, :, :, self.qk_nope_head_dim:] = q_pe
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
 | 
				
			||||||
 | 
					        key_states[:, :, :, : self.qk_nope_head_dim] = k_nope
 | 
				
			||||||
 | 
					        key_states[:, :, :, self.qk_nope_head_dim:] = k_pe
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if past_key_value is not None:
 | 
				
			||||||
 | 
					        key_states, value_states = past_key_value.update(key_states, value_states,
 | 
				
			||||||
 | 
					                                                         self.layer_idx, None)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    attn_weights = None
 | 
				
			||||||
 | 
					    attn_output = scaled_dot_product_attention(
 | 
				
			||||||
 | 
					        query_states, key_states, value_states,
 | 
				
			||||||
 | 
					        attention_mask, q_len == kv_seq_len, self.softmax_scale
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    attn_output = attn_output[:, :, :, :self.v_head_dim]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    attn_output = attn_output.transpose(1, 2).contiguous()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    attn_output = self.o_proj(attn_output)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if not output_attentions:
 | 
				
			||||||
 | 
					        attn_weights = None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    return attn_output, attn_weights, past_key_value
 | 
				
			||||||
| 
						 | 
					@ -1,3 +1,25 @@
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# Copyright 2016 The BigDL Authors.
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
 | 
					# you may not use this file except in compliance with the License.
 | 
				
			||||||
 | 
					# You may obtain a copy of the License at
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					#     http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# Unless required by applicable law or agreed to in writing, software
 | 
				
			||||||
 | 
					# distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||||
 | 
					# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||||
 | 
					# See the License for the specific language governing permissions and
 | 
				
			||||||
 | 
					# limitations under the License.
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# Some parts of this file is adapted from
 | 
				
			||||||
 | 
					# https://hf-mirror.com/openbmb/MiniCPM3-4B/blob/main/modeling_minicpm.py
 | 
				
			||||||
 | 
					# which is licensed under Apache License 2.0:
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# https://github.com/OpenBMB/MiniCPM/blob/main/LICENSE
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import torch
 | 
					import torch
 | 
				
			||||||
import warnings
 | 
					import warnings
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -122,9 +144,6 @@ def minicpm3_attention_forward(
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
 | 
					    q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
 | 
				
			||||||
    q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)
 | 
					    q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)
 | 
				
			||||||
    q_nope, q_pe = torch.split(
 | 
					 | 
				
			||||||
        q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
 | 
					 | 
				
			||||||
    )
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
 | 
					    compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
 | 
				
			||||||
    compressed_kv, k_pe = torch.split(
 | 
					    compressed_kv, k_pe = torch.split(
 | 
				
			||||||
| 
						 | 
					@ -169,6 +188,9 @@ def minicpm3_attention_forward(
 | 
				
			||||||
        else:
 | 
					        else:
 | 
				
			||||||
            invalidInputError(f"unknown rope method: {self.rotary_emb.__class__.__name__}")
 | 
					            invalidInputError(f"unknown rope method: {self.rotary_emb.__class__.__name__}")
 | 
				
			||||||
    else:
 | 
					    else:
 | 
				
			||||||
 | 
					        q_nope, q_pe = torch.split(
 | 
				
			||||||
 | 
					            q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
 | 
					        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
 | 
				
			||||||
        q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids)
 | 
					        q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in a new issue