343 lines
		
	
	
	
		
			13 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			343 lines
		
	
	
	
		
			13 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
#
 | 
						|
# Copyright 2016 The BigDL Authors.
 | 
						|
#
 | 
						|
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
						|
# you may not use this file except in compliance with the License.
 | 
						|
# You may obtain a copy of the License at
 | 
						|
#
 | 
						|
#     http://www.apache.org/licenses/LICENSE-2.0
 | 
						|
#
 | 
						|
# Unless required by applicable law or agreed to in writing, software
 | 
						|
# distributed under the License is distributed on an "AS IS" BASIS,
 | 
						|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
						|
# See the License for the specific language governing permissions and
 | 
						|
# limitations under the License.
 | 
						|
#
 | 
						|
# Some parts of this file is adapted from
 | 
						|
# https://huggingface.co/deepseek-ai/DeepSeek-R1/blob/main/modeling_deepseek.py
 | 
						|
# which is licensed under Apache License 2.0:
 | 
						|
#
 | 
						|
# https://github.com/OpenBMB/MiniCPM/blob/main/LICENSE
 | 
						|
#
 | 
						|
 | 
						|
import torch
 | 
						|
import warnings
 | 
						|
 | 
						|
from typing import Optional, Tuple, List, Union
 | 
						|
from transformers.cache_utils import Cache
 | 
						|
from transformers.modeling_outputs import BaseModelOutputWithPast
 | 
						|
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
 | 
						|
 | 
						|
from ipex_llm.utils.common.log4Error import invalidInputError
 | 
						|
from ipex_llm.transformers.kv import DynamicNormalCache
 | 
						|
from ipex_llm.transformers.models.common import padding_mla_v_hd_base
 | 
						|
from ipex_llm.transformers.models.common import scaled_dot_product_attention
 | 
						|
from ipex_llm.transformers.models.utils import rotate_half, use_fuse_moe
 | 
						|
 | 
						|
 | 
						|
def padding_mla_v_hd(module: torch.nn.Module):
    """Pad the MLA value head dim for every ``DeepseekV3Attention`` in ``module``.

    Thin adapter that forwards to ``padding_mla_v_hd_base`` with the DeepSeek-V3
    attention class name; see that helper for what the padding actually does.
    """
    attention_class_name = "DeepseekV3Attention"
    padding_mla_v_hd_base(module, attention_class_name)
 | 
						|
 | 
						|
 | 
						|
def deepseek_model_forward(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
    """DeepSeek-V3 decoder-stack forward with IPEX-LLM optimizations.

    Differences from the stock forward:

    * On XPU the KV cache is always enabled. When caching, any non-
      ``DynamicNormalCache`` ``past_key_values`` (including ``None``) is converted
      with ``DynamicNormalCache.from_legacy_cache``.
    * On XPU the rotary cos/sin tables are computed once here, using layer 0's
      ``rotary_emb`` and the position ids of the first batch row. They are then
      handed to every layer as ``position_embeddings`` so attention can use the
      fused in-place rope kernel.

    Exactly one of ``input_ids`` / ``inputs_embeds`` must be given.

    Returns:
        ``BaseModelOutputWithPast`` when ``return_dict`` is truthy. Otherwise a
        tuple of the non-``None`` entries among (last hidden state, cache,
        all hidden states, all attentions).
    """
    # Unspecified flags fall back to the model config.
    if output_attentions is None:
        output_attentions = self.config.output_attentions
    if output_hidden_states is None:
        output_hidden_states = self.config.output_hidden_states
    if use_cache is None:
        use_cache = self.config.use_cache
    if return_dict is None:
        return_dict = self.config.use_return_dict

    invalidInputError((input_ids is None) ^ (inputs_embeds is None),
                      "You cannot specify both input_ids and inputs_embeds at the same time, "
                      "and must specify either one")
    if inputs_embeds is None:
        inputs_embeds = self.embed_tokens(input_ids)

    batch_size, seq_length = inputs_embeds.shape[:2]
    device = inputs_embeds.device
    on_xpu = device.type == "xpu"

    # IPEX-LLM OPT start: kv cache (forced on for XPU)
    if on_xpu:
        use_cache = True
    past_key_values_length = 0
    if use_cache:
        if not isinstance(past_key_values, DynamicNormalCache):
            past_key_values = DynamicNormalCache.from_legacy_cache(past_key_values)
        past_key_values_length = past_key_values.get_usable_length(seq_length)
    # IPEX-LLM OPT end: kv cache

    if position_ids is None:
        # Positions continue from the cached prefix: [past, past + seq).
        position_ids = torch.arange(
            past_key_values_length,
            seq_length + past_key_values_length,
            dtype=torch.long,
            device=device,
        ).unsqueeze(0)

    # IPEX-LLM OPT start: fuse rope
    position_embeddings = None
    if on_xpu and position_ids is not None:
        cos, sin = self.layers[0].self_attn.rotary_emb(inputs_embeds,
                                                       seq_length + past_key_values_length)
        # Only the first batch row's positions are used to gather the tables.
        first_row_positions = position_ids[0]
        position_embeddings = (
            cos[first_row_positions].contiguous(),
            sin[first_row_positions].contiguous(),
        )
    # IPEX-LLM OPT end: fuse rope

    # The 4d causal mask is shared by all decoder layers.
    attention_mask = _prepare_4d_causal_attention_mask(
        attention_mask,
        (batch_size, seq_length),
        inputs_embeds,
        past_key_values_length,
    )

    hidden_states = inputs_embeds
    all_hidden_states = () if output_hidden_states else None
    all_self_attns = () if output_attentions else None
    next_cache = None
    # A decoder layer's outputs are (hidden, [attn_weights], [cache]), so the
    # cache slot shifts by one when attentions are returned.
    cache_slot = 2 if output_attentions else 1

    for layer in self.layers:
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        layer_outputs = layer(
            hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_values,
            output_attentions=output_attentions,
            use_cache=use_cache,
            position_embeddings=position_embeddings,
        )
        hidden_states = layer_outputs[0]

        if use_cache:
            next_cache = layer_outputs[cache_slot]
        if output_attentions:
            all_self_attns += (layer_outputs[1],)

    hidden_states = self.norm(hidden_states)

    # Record the hidden state produced after the final norm as well.
    if output_hidden_states:
        all_hidden_states += (hidden_states,)

    if not return_dict:
        return tuple(
            item
            for item in (hidden_states, next_cache, all_hidden_states, all_self_attns)
            if item is not None
        )
    return BaseModelOutputWithPast(
        last_hidden_state=hidden_states,
        past_key_values=next_cache,
        hidden_states=all_hidden_states,
        attentions=all_self_attns,
    )
 | 
						|
 | 
						|
 | 
						|
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
    """Apply rotary position embedding to ``q`` and ``k`` (DeepSeek channel layout).

    ``cos``/``sin`` are gathered at ``position_ids`` and unsqueezed at
    ``unsqueeze_dim`` so they broadcast against the head axis. ``q`` and ``k``
    must be 4-D ``(b, h, s, d)`` (enforced by the shape unpacking below).

    Before rotation, each tensor's last dim is reordered so that the
    even-indexed channels come first, followed by the odd-indexed ones. This
    converts the interleaved pair layout into the half-split layout that
    ``rotate_half`` expects.

    Returns:
        ``(q_embed, k_embed)``.
    """
    cos = cos[position_ids].unsqueeze(unsqueeze_dim)
    sin = sin[position_ids].unsqueeze(unsqueeze_dim)

    def _evens_then_odds(t):
        # view(..., d // 2, 2) pairs channels (2i, 2i + 1); transposing the last
        # two axes and flattening yields [c0, c2, c4, ..., c1, c3, c5, ...].
        b, h, s, d = t.shape
        return t.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)

    q = _evens_then_odds(q)
    k = _evens_then_odds(k)

    q_rotated = (q * cos) + (rotate_half(q) * sin)
    k_rotated = (k * cos) + (rotate_half(k) * sin)
    return q_rotated, k_rotated
 | 
						|
 | 
						|
 | 
						|
def deepseek_attention_forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Cache] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
    **kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    """Multi-head latent attention (MLA) forward for ``DeepseekV3Attention``.

    Queries come either from ``q_proj`` or, when ``q_lora_rank`` is set, from
    the ``q_a_proj -> q_a_layernorm -> q_b_proj`` chain. Keys and values are
    expanded from the compressed latent produced by ``kv_a_proj_with_mqa``.
    The shared rope key part ``k_pe`` has a single head that is broadcast over
    all heads.

    ``kv_b_proj`` is viewed with ``qk_nope_head_dim + q_head_dim`` channels per
    head, so the value head dim is treated as ``q_head_dim``. Presumably the
    value projection was padded by ``padding_mla_v_hd`` (TODO confirm). Output
    channels beyond ``v_head_dim`` are dropped after attention.

    If ``kwargs["position_embeddings"]`` holds precomputed ``(cos, sin)``, rope
    is applied in place by a fused kernel. Otherwise ``self.rotary_emb`` and
    ``apply_rotary_pos_emb`` are used.

    Returns:
        ``(attn_output, None, past_key_value)``. Attention weights are never
        materialized, regardless of ``output_attentions``.
    """
    if "padding_mask" in kwargs:
        warnings.warn(
            "Passing `padding_mask` is deprecated and will be removed in v4.37. "
            "Please make sure use `attention_mask` instead.`"
        )

    bsz, q_len, _ = hidden_states.size()
    n_heads = self.num_heads
    nope_dim = self.qk_nope_head_dim
    rope_dim = self.qk_rope_head_dim
    head_dim = self.q_head_dim

    # ---- query projection -> (bsz, heads, q_len, head_dim)
    if self.q_lora_rank is not None:
        q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
    else:
        q = self.q_proj(hidden_states)
    q = q.view(bsz, q_len, n_heads, head_dim).transpose(1, 2)

    # ---- key/value latent, split into compressed kv and the single-head rope key
    latent, k_pe = torch.split(
        self.kv_a_proj_with_mqa(hidden_states), [self.kv_lora_rank, rope_dim], dim=-1
    )
    k_pe = k_pe.view(bsz, q_len, 1, rope_dim).transpose(1, 2)

    kv = self.kv_b_proj(self.kv_a_layernorm(latent))
    kv = kv.view(bsz, q_len, n_heads, nope_dim + head_dim).transpose(1, 2)
    k_nope, value_states = torch.split(kv, [nope_dim, head_dim], dim=-1)

    kv_seq_len = value_states.shape[-2]
    if past_key_value is not None:
        kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)

    position_embeddings = kwargs.get("position_embeddings", None)
    if position_embeddings is None:
        # Reference path: rotate the rope slices, then assemble full q/k heads.
        q_nope, q_pe = torch.split(q, [nope_dim, rope_dim], dim=-1)
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
        q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids)

        query_states = k_pe.new_empty(bsz, n_heads, q_len, head_dim)
        query_states[:, :, :, :nope_dim] = q_nope
        query_states[:, :, :, nope_dim:] = q_pe

        # k_pe has one head; slice assignment broadcasts it across all heads.
        key_states = k_pe.new_empty(bsz, n_heads, q_len, head_dim)
        key_states[:, :, :, :nope_dim] = k_nope
        key_states[:, :, :, nope_dim:] = k_pe
    else:
        # Fused path: query_states aliases q, and the kernel rotates the rope
        # slices of both tensors in place.
        query_states = q
        key_states = torch.cat(
            [k_nope, k_pe.expand([-1, n_heads, -1, -1])],
            dim=-1
        )
        cos, sin = position_embeddings
        from ipex_llm.transformers.models.common import rotary_two_with_cache_inplaced
        rotary_two_with_cache_inplaced(query_states[:, :, :, nope_dim:],
                                       key_states[:, :, :, nope_dim:],
                                       cos, sin, True)

    if past_key_value is not None:
        key_states, value_states = past_key_value.update(key_states, value_states,
                                                         self.layer_idx, None)

    # Causal masking is requested only when there is no cached prefix.
    attn_output = scaled_dot_product_attention(
        query_states, key_states, value_states,
        attention_mask, q_len == kv_seq_len, self.softmax_scale
    )

    # Drop padded value channels, then merge heads -> (bsz, q_len, heads * v_head_dim).
    attn_output = (
        attn_output[:, :, :, :self.v_head_dim]
        .transpose(1, 2)
        .contiguous()
        .reshape(bsz, q_len, n_heads * self.v_head_dim)
    )
    attn_output = self.o_proj(attn_output)

    return attn_output, None, past_key_value
 | 
						|
 | 
						|
 | 
						|
def fuse_gate_forward(self, x: torch.Tensor):
    """Compute MoE routing ``(topk_idx, topk_weight)`` for ``x`` using gate ``self``.

    Fused path (XPU with float32/float16 input only):
    * Flatten ``x`` to 2-D.
    * Compute sigmoid scores of a float32 linear projection onto ``self.weight``.
    * Pass the scores to ``moe_group_topk`` together with the gate's
      bias/grouping/top-k/normalization/scaling settings.

    Any other device or dtype falls back to calling the gate module itself.

    Returns:
        ``(topk_idx, topk_weight)``. ``topk_weight`` is cast to ``x.dtype``.
    """
    fused = x.device.type == "xpu" and x.dtype in [torch.float, torch.half]
    if not fused:
        topk_idx, topk_weight = self(x)
        return topk_idx, topk_weight.to(x.dtype)

    tokens = x.view(-1, x.size(-1))
    scores = torch.nn.functional.linear(
        tokens.type(torch.float32), self.weight.type(torch.float32), None
    ).sigmoid()

    from ipex_llm.transformers.models.common import moe_group_topk
    topk_idx, topk_weight = moe_group_topk(
        scores, self.e_score_correction_bias,
        self.n_group, self.topk_group, self.top_k,
        self.norm_topk_prob, self.routed_scaling_factor,
    )
    return topk_idx, topk_weight.to(x.dtype)
 | 
						|
 | 
						|
 | 
						|
def moe_infer_decode(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor):
    """Routed-expert MoE output for the single-token (decode) case.

    When ``use_fuse_moe(x, qtype)`` allows it, ``xe_linear.moe_forward_vec`` is
    called. It receives device tensors of raw weight pointers (``gates``,
    ``ups``, ``downs``), which are built on first use and cached as
    non-persistent buffers. Otherwise, each selected expert is run on ``x`` and
    the results are combined as a ``topk_weight``-weighted sum, keeping a
    leading dim of 1.

    NOTE(review): the fallback assumes ``topk_weight`` has a leading dim of 1
    (it is squeezed at dim 0), matching the single-token caller.
    """
    qtype = self.experts[0].down_proj.qtype

    if not use_fuse_moe(x, qtype):
        selected = torch.cat(
            [self.experts[idx](x) for idx in topk_ids.flatten().tolist()], dim=0
        )
        weights = topk_weight.squeeze(0).unsqueeze(-1)
        return (selected * weights).sum(dim=0, keepdim=True)

    if getattr(self, "gates", None) is None:
        # One uint64 pointer per expert for each projection, created only once.
        pointer_tables = {
            buffer_name: torch.tensor(
                [getattr(expert, proj_name).weight.data_ptr() for expert in self.experts],
                dtype=torch.uint64, device=x.device,
            )
            for buffer_name, proj_name in (
                ("gates", "gate_proj"),
                ("ups", "up_proj"),
                ("downs", "down_proj"),
            )
        }
        for buffer_name, table in pointer_tables.items():
            self.register_buffer(buffer_name, table, persistent=False)

    import xe_linear
    return xe_linear.moe_forward_vec(
        x, topk_ids, topk_weight, self.gates, self.ups, self.downs,
        x.size(-1), self.experts[0].intermediate_size, qtype,
    )
 | 
						|
 | 
						|
 | 
						|
def deepseek_moe_forward(self, hidden_states: torch.Tensor):
    """Inference-only forward for the DeepSeek-V3 MoE block.

    Routing goes through ``fuse_gate_forward``. When the gate yields a single
    routed row and ``ep_size == 1``, ``moe_infer_decode`` is used; otherwise
    the module's own ``moe_infer`` is used. If ``config.n_shared_experts`` is
    set, the shared experts' output on the original input is added.

    Raises:
        Whatever ``invalidInputError`` raises, when the module is in training
        mode. Previously this surfaced as an opaque ``UnboundLocalError``,
        because no training path exists.
    """
    # No training implementation exists below; fail early with a clear message
    # (and before spending time on the gate) instead of crashing on unset ``y``.
    invalidInputError(not self.training,
                      "deepseek_moe_forward only supports inference; "
                      "call model.eval() before running the forward pass")

    identity = hidden_states
    orig_shape = hidden_states.shape
    # IPEX-LLM OPT start: fuse grouped topk in gate forward
    topk_idx, topk_weight = fuse_gate_forward(self.gate, hidden_states)
    # IPEX-LLM OPT end
    hidden_states = hidden_states.view(-1, hidden_states.shape[-1])

    # IPEX-LLM OPT start: add special moe_infer implementation for decoding
    if topk_idx.size(0) == 1 and self.ep_size == 1:
        y = moe_infer_decode(self, hidden_states, topk_idx, topk_weight)
    else:
        y = self.moe_infer(hidden_states, topk_idx, topk_weight)
    y = y.view(*orig_shape)
    # IPEX-LLM OPT end

    if self.config.n_shared_experts is not None:
        y = y + self.shared_experts(identity)
    return y
 |