ipex-llm/python/llm/src/ipex_llm/transformers/models/deepseek.py
2025-04-22 14:45:31 +08:00

343 lines
13 KiB
Python

#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Some parts of this file are adapted from
# https://huggingface.co/deepseek-ai/DeepSeek-R1/blob/main/modeling_deepseek.py
# which is licensed under Apache License 2.0:
#
# https://github.com/OpenBMB/MiniCPM/blob/main/LICENSE
#
import torch
import warnings
from typing import Optional, Tuple, List, Union
from transformers.cache_utils import Cache
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
from ipex_llm.utils.common.log4Error import invalidInputError
from ipex_llm.transformers.kv import DynamicNormalCache
from ipex_llm.transformers.models.common import padding_mla_v_hd_base
from ipex_llm.transformers.models.common import scaled_dot_product_attention
from ipex_llm.transformers.models.utils import rotate_half, use_fuse_moe
def padding_mla_v_hd(module: torch.nn.Module):
    """Run ``padding_mla_v_hd_base`` on ``module`` for ``DeepseekV3Attention``.

    NOTE(review): ``deepseek_attention_forward`` in this file splits the
    ``kv_b_proj`` output with a value width of ``q_head_dim`` and later slices
    the attention output back to ``v_head_dim``, which suggests this helper
    pads the value head dim of those attention layers — confirm against
    ``padding_mla_v_hd_base``.
    """
    target_attention_class = "DeepseekV3Attention"
    padding_mla_v_hd_base(module, target_attention_class)
def deepseek_model_forward(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
    """Decoder-stack forward with IPEX-LLM KV-cache and fused-RoPE changes.

    Behaviour specific to this implementation:

    * On XPU ``use_cache`` is forced on. Whenever ``use_cache`` is true, a
      ``past_key_values`` that is not already a ``DynamicNormalCache`` is
      converted with ``DynamicNormalCache.from_legacy_cache``.
    * On XPU the rotary ``cos``/``sin`` tables are computed once from the
      first layer's ``self_attn.rotary_emb`` and handed to every layer as
      ``position_embeddings``. Only the first row of ``position_ids`` is used
      for that lookup.
    * The returned cache is whatever the last decoder layer returned. It is
      not converted back to a legacy tuple.

    Exactly one of ``input_ids`` / ``inputs_embeds`` must be provided; the
    ``None`` flags fall back to ``self.config``.

    Returns:
        ``BaseModelOutputWithPast`` when ``return_dict`` is truthy, otherwise
        a tuple of the non-None entries among (final hidden states, cache,
        all hidden states, all attentions).
    """
    cfg = self.config
    if output_attentions is None:
        output_attentions = cfg.output_attentions
    if output_hidden_states is None:
        output_hidden_states = cfg.output_hidden_states
    if use_cache is None:
        use_cache = cfg.use_cache
    if return_dict is None:
        return_dict = cfg.use_return_dict

    # retrieve input_ids and inputs_embeds
    invalidInputError((input_ids is None) ^ (inputs_embeds is None),
                      "You cannot specify both input_ids and inputs_embeds at the same time, "
                      "and must specify either one")
    if inputs_embeds is None:
        inputs_embeds = self.embed_tokens(input_ids)
    batch_size, seq_length = inputs_embeds.shape[:2]
    on_xpu = inputs_embeds.device.type == "xpu"

    # IPEX-LLM OPT start: kv cache
    if on_xpu:
        use_cache = True
    past_key_values_length = 0
    if use_cache:
        if not isinstance(past_key_values, DynamicNormalCache):
            past_key_values = DynamicNormalCache.from_legacy_cache(past_key_values)
        past_key_values_length = past_key_values.get_usable_length(seq_length)
    # IPEX-LLM OPT end: kv cache

    total_length = past_key_values_length + seq_length
    if position_ids is None:
        # Default positions continue right after the cached prefix.
        position_ids = torch.arange(
            past_key_values_length,
            total_length,
            dtype=torch.long,
            device=inputs_embeds.device,
        ).unsqueeze(0)

    # IPEX-LLM OPT start: fuse rope
    position_embeddings = None
    if on_xpu and position_ids is not None:
        cos, sin = self.layers[0].self_attn.rotary_emb(inputs_embeds, total_length)
        first_row = position_ids[0]
        position_embeddings = (cos[first_row].contiguous(),
                               sin[first_row].contiguous())
    # IPEX-LLM OPT end: fuse rope

    # 4d mask is passed through the layers
    attention_mask = _prepare_4d_causal_attention_mask(
        attention_mask,
        (batch_size, seq_length),
        inputs_embeds,
        past_key_values_length,
    )

    hidden_states = inputs_embeds
    all_hidden_states = () if output_hidden_states else None
    all_self_attns = () if output_attentions else None
    next_decoder_cache = None
    # The cache sits after the optional attention weights in each layer's output.
    cache_slot = 2 if output_attentions else 1

    for decoder_layer in self.layers:
        if output_hidden_states:
            all_hidden_states += (hidden_states,)
        layer_outputs = decoder_layer(
            hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_values,
            output_attentions=output_attentions,
            use_cache=use_cache,
            position_embeddings=position_embeddings,
        )
        hidden_states = layer_outputs[0]
        if use_cache:
            next_decoder_cache = layer_outputs[cache_slot]
        if output_attentions:
            all_self_attns += (layer_outputs[1],)

    hidden_states = self.norm(hidden_states)
    # add hidden states from the last decoder layer
    if output_hidden_states:
        all_hidden_states += (hidden_states,)

    if return_dict:
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_decoder_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )
    candidates = (hidden_states, next_decoder_cache, all_hidden_states, all_self_attns)
    return tuple(item for item in candidates if item is not None)
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
    """Apply rotary position embedding to ``q`` and ``k``.

    ``cos``/``sin`` are indexed by ``position_ids`` and unsqueezed at
    ``unsqueeze_dim`` so they broadcast against ``q``/``k``. Both ``q`` and
    ``k`` are unpacked as 4-D ``(b, h, s, d)``.

    Before rotation, the last dim is reordered from interleaved pairs
    ``(x0, x1, x2, x3, ...)`` to split halves
    ``(x0, x2, ..., x1, x3, ...)``, which is the layout ``rotate_half``
    operates on.

    Returns:
        ``(q_embed, k_embed)`` with the same shapes as ``q`` and ``k``.
    """
    # NOTE(review): assumes cos/sin are (max_seq, d) tables and position_ids
    # is (b, s), giving (b, 1, s, d) after the unsqueeze — confirm with caller.
    cos_sel = cos[position_ids].unsqueeze(unsqueeze_dim)
    sin_sel = sin[position_ids].unsqueeze(unsqueeze_dim)

    def _deinterleave(t):
        # (.., d) -> (.., d//2, 2) -> (.., 2, d//2) -> (.., d): evens first, then odds.
        b, h, s, d = t.shape
        return t.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)

    q_halves = _deinterleave(q)
    k_halves = _deinterleave(k)
    return (
        (q_halves * cos_sel) + (rotate_half(q_halves) * sin_sel),
        (k_halves * cos_sel) + (rotate_half(k_halves) * sin_sel),
    )
def deepseek_attention_forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Cache] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
    **kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    """DeepSeek attention forward using compressed (low-rank) KV projections.

    Steps:

    * Build per-head queries and keys/values from ``q_proj`` (or the
      ``q_a_proj``/``q_b_proj`` pair) and ``kv_a_proj_with_mqa``/``kv_b_proj``.
    * Rotate the rope part of queries and keys. This uses the fused in-place
      kernel when ``position_embeddings`` is supplied, otherwise
      ``apply_rotary_pos_emb``.
    * Update ``past_key_value`` if one is given.
    * Run ``scaled_dot_product_attention`` and project with ``o_proj``.

    Args:
        hidden_states: input activations, unpacked as ``(bsz, q_len, _)``.
        attention_mask: forwarded unchanged to ``scaled_dot_product_attention``.
        position_ids: used to index ``cos``/``sin`` on the non-fused path only.
        past_key_value: cache whose ``update`` is called for ``self.layer_idx``;
            may be ``None``.
        output_attentions: attention weights are never materialized here, so
            the returned weights are always ``None``.
        use_cache: accepted for interface compatibility; not read in this body.
        **kwargs: ``position_embeddings`` (a ``(cos, sin)`` pair) selects the
            fused rotary path; ``padding_mask`` only triggers a warning.

    Returns:
        ``(attn_output, None, past_key_value)``.
    """
    if "padding_mask" in kwargs:
        warnings.warn(
            "Passing `padding_mask` is deprecated and will be removed in v4.37. "
            "Please make sure use `attention_mask` instead.`"
        )
    bsz, q_len, _ = hidden_states.size()
    # Query projection: direct q_proj when q_lora_rank is None, otherwise the
    # low-rank q_a_proj -> q_a_layernorm -> q_b_proj chain.
    if self.q_lora_rank is None:
        q = self.q_proj(hidden_states)
    else:
        q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
    # -> (bsz, num_heads, q_len, q_head_dim)
    q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)
    # Compressed KV latent plus one rope key per position shared by all heads.
    compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
    compressed_kv, k_pe = torch.split(
        compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
    )
    # -> (bsz, 1, q_len, qk_rope_head_dim); the singleton head dim is
    # expanded/broadcast over num_heads below.
    k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2)
    # Values use width q_head_dim (not v_head_dim). The value head dim appears
    # to have been padded up to q_head_dim (see padding_mla_v_hd); the extra
    # columns are sliced off after attention.
    kv = (
        self.kv_b_proj(self.kv_a_layernorm(compressed_kv))
        .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.q_head_dim)
        .transpose(1, 2)
    )
    k_nope, value_states = torch.split(
        kv, [self.qk_nope_head_dim, self.q_head_dim], dim=-1
    )
    # Key length = new tokens + tokens already usable in the cache for this layer.
    kv_seq_len = value_states.shape[-2]
    if past_key_value is not None:
        kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
    position_embeddings = kwargs.get("position_embeddings", None)
    if position_embeddings is not None:
        # Fused path: (cos, sin) were precomputed by the caller.
        # query_states aliases q, so the in-place rotation writes into q.
        query_states = q
        key_states = torch.cat(
            [k_nope, k_pe.expand([-1, self.num_heads, -1, -1])],
            dim=-1
        )
        cos, sin = position_embeddings
        from ipex_llm.transformers.models.common import rotary_two_with_cache_inplaced
        # Rotates the trailing rope slices (after qk_nope_head_dim) of query
        # and key in place. NOTE(review): the meaning of the final ``True`` is
        # not visible here — presumably it selects the interleaved layout that
        # apply_rotary_pos_emb handles on the other path; confirm.
        rotary_two_with_cache_inplaced(query_states[:, :, :, self.qk_nope_head_dim:],
                                       key_states[:, :, :, self.qk_nope_head_dim:],
                                       cos, sin, True)
    else:
        # Reference path: rotate only the rope parts, then assemble full
        # query/key tensors of width q_head_dim (= nope part + rope part).
        q_nope, q_pe = torch.split(
            q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
        )
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
        q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids)
        query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
        query_states[:, :, :, : self.qk_nope_head_dim] = q_nope
        query_states[:, :, :, self.qk_nope_head_dim:] = q_pe
        # k_pe has a singleton head dim, so this assignment broadcasts it
        # across all heads.
        key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
        key_states[:, :, :, : self.qk_nope_head_dim] = k_nope
        key_states[:, :, :, self.qk_nope_head_dim:] = k_pe
    if past_key_value is not None:
        # Presumably returns the cached prefix concatenated with this step's
        # keys/values — confirm in DynamicNormalCache.update.
        key_states, value_states = past_key_value.update(key_states, value_states,
                                                         self.layer_idx, None)
    attn_weights = None
    # NOTE(review): the 5th argument looks like an is_causal flag — true only
    # when there is no cached prefix (q_len == kv_seq_len); confirm in
    # scaled_dot_product_attention.
    attn_output = scaled_dot_product_attention(
        query_states, key_states, value_states,
        attention_mask, q_len == kv_seq_len, self.softmax_scale
    )
    # Drop any value-dim padding before merging heads.
    attn_output = attn_output[:, :, :, :self.v_head_dim]
    attn_output = attn_output.transpose(1, 2).contiguous()
    attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim)
    attn_output = self.o_proj(attn_output)
    # Redundant: attn_weights is already None on every path above.
    if not output_attentions:
        attn_weights = None
    return attn_output, attn_weights, past_key_value
def fuse_gate_forward(self, x: torch.Tensor):
    """Route tokens through the MoE gate ``self``, fusing grouped top-k on XPU.

    On XPU with float32/float16 input, ``x`` is flattened to
    ``(-1, hidden)`` and scored as ``sigmoid(x @ self.weight.T)`` in float32.
    Top-k selection is delegated to ``moe_group_topk`` using the gate's
    ``e_score_correction_bias``, ``n_group``, ``topk_group``, ``top_k``,
    ``norm_topk_prob`` and ``routed_scaling_factor``. Every other case
    (including bfloat16 or non-XPU devices) simply calls ``self(x)``.

    Returns:
        ``(topk_idx, topk_weight)`` with ``topk_weight`` cast to ``x.dtype``.
    """
    fusable = x.device.type == "xpu" and x.dtype in (torch.float, torch.half)
    if not fusable:
        topk_idx, topk_weight = self(x)
        return topk_idx, topk_weight.to(x.dtype)

    tokens = x.view(-1, x.size(-1))
    router_logits = torch.nn.functional.linear(
        tokens.type(torch.float32), self.weight.type(torch.float32), None
    )
    router_scores = router_logits.sigmoid()
    from ipex_llm.transformers.models.common import moe_group_topk
    topk_idx, topk_weight = moe_group_topk(
        router_scores, self.e_score_correction_bias,
        self.n_group, self.topk_group, self.top_k,
        self.norm_topk_prob, self.routed_scaling_factor,
    )
    return topk_idx, topk_weight.to(tokens.dtype)
def moe_infer_decode(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor):
    """Mix the selected experts' outputs on the decode path.

    ``deepseek_moe_forward`` calls this when ``topk_idx`` has a single row.
    Two strategies exist:

    * When ``use_fuse_moe(x, qtype)`` allows it, the gate/up/down weight data
      pointers of all experts are cached on first use as non-persistent
      uint64 buffers (``gates``, ``ups``, ``downs``), and the whole mixture
      runs in ``xe_linear.moe_forward_vec``.
    * Otherwise each selected expert is run on ``x`` in ``topk_ids`` order,
      and the outputs are summed with ``topk_weight`` as coefficients.

    NOTE(review): the fallback assumes ``topk_weight`` is ``(1, top_k)`` and
    each expert returns ``(1, hidden)`` — confirm with callers.

    Returns:
        The combined expert output (``(1, hidden)`` on the fallback path).
    """
    experts = self.experts
    qtype = experts[0].down_proj.qtype

    if not use_fuse_moe(x, qtype):
        selected = topk_ids.flatten().tolist()
        stacked = torch.cat([experts[i](x) for i in selected], dim=0)
        coeffs = topk_weight.squeeze(0).unsqueeze(-1)
        return (stacked * coeffs).sum(dim=0, keepdim=True)

    if getattr(self, "gates", None) is None:
        # Build all three pointer tables before registering any of them.
        projections = (("gates", "gate_proj"), ("ups", "up_proj"), ("downs", "down_proj"))
        pointer_tables = {
            buf_name: torch.tensor(
                [getattr(expert, proj_name).weight.data_ptr() for expert in experts],
                dtype=torch.uint64,
                device=x.device,
            )
            for buf_name, proj_name in projections
        }
        for buf_name, table in pointer_tables.items():
            self.register_buffer(buf_name, table, persistent=False)

    import xe_linear
    return xe_linear.moe_forward_vec(
        x, topk_ids, topk_weight, self.gates, self.ups, self.downs,
        x.size(-1), experts[0].intermediate_size, qtype
    )
def deepseek_moe_forward(self, hidden_states: torch.Tensor):
    """Inference-only forward for the DeepSeek MoE block.

    Tokens are routed with ``fuse_gate_forward`` on ``self.gate``. The
    selected experts are then mixed with ``moe_infer_decode`` when
    ``topk_idx`` has a single row and ``ep_size == 1``, and with the module's
    own ``moe_infer`` otherwise. When ``config.n_shared_experts`` is set, the
    output of ``shared_experts`` on the original input is added.

    Args:
        hidden_states: input activations; the result is viewed back to
            this tensor's shape.

    Returns:
        Tensor with the same shape as ``hidden_states``.

    Raises:
        Whatever ``invalidInputError`` raises if the module is in training
        mode. Only the inference path is implemented; previously this case
        crashed with an ``UnboundLocalError`` on ``y``.
    """
    invalidInputError(not self.training,
                      "deepseek_moe_forward only supports inference, "
                      "please call model.eval() before running forward")
    identity = hidden_states
    orig_shape = hidden_states.shape
    # IPEX-LLM OPT start: fuse grouped topk in gate forward
    topk_idx, topk_weight = fuse_gate_forward(self.gate, hidden_states)
    # IPEX-LLM OPT end
    hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
    # IPEX-LLM OPT start: add special moe_infer implementation for decoding
    if topk_idx.size(0) == 1 and self.ep_size == 1:
        y = moe_infer_decode(self, hidden_states, topk_idx, topk_weight)
    else:
        y = self.moe_infer(hidden_states, topk_idx, topk_weight)
    y = y.view(*orig_shape)
    # IPEX-LLM OPT end
    if self.config.n_shared_experts is not None:
        y = y + self.shared_experts(identity)
    return y