#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Some parts of this file are adapted from
# https://huggingface.co/deepseek-ai/DeepSeek-R1/blob/main/modeling_deepseek.py
# which is licensed under Apache License 2.0:
#
# https://github.com/OpenBMB/MiniCPM/blob/main/LICENSE
#

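# This file provides IPEX-LLM optimized replacements for the DeepSeek-V3/R1
# model, attention, MoE gate and MoE inference forwards. Each function below
# takes `self` as its first argument and mirrors the signature of the
# corresponding Hugging Face method, so it can be bound onto the model's
# modules in place of the original implementation.
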
import torch
import warnings

from typing import Optional, Tuple, List, Union
from transformers.cache_utils import Cache
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask

from ipex_llm.utils.common.log4Error import invalidInputError
from ipex_llm.transformers.kv import DynamicNormalCache
from ipex_llm.transformers.models.common import padding_mla_v_hd_base
from ipex_llm.transformers.models.common import scaled_dot_product_attention
from ipex_llm.transformers.models.utils import rotate_half


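# In MLA the value head dim (v_head_dim) is smaller than the query/key head
# dim (q_head_dim). Padding the value projection up to q_head_dim lets the
# fused SDPA kernel run with a uniform head size; the extra columns are
# sliced off again after attention (see `attn_output[..., :self.v_head_dim]`
# in `deepseek_attention_forward` below).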
def padding_mla_v_hd(module: torch.nn.Module):
    padding_mla_v_hd_base(module, "DeepseekV3Attention")


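# Replacement for DeepseekV3Model.forward, with IPEX-LLM specific handling of
# the KV cache and rotary embeddings marked by the `IPEX-LLM OPT` comments
# below.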
def deepseek_model_forward(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
    output_attentions = (
        output_attentions if output_attentions is not None
        else self.config.output_attentions
    )
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None
        else self.config.output_hidden_states
    )

    use_cache = use_cache if use_cache is not None else self.config.use_cache

    return_dict = (
        return_dict if return_dict is not None else self.config.use_return_dict
    )

    # retrieve input_ids and inputs_embeds
    invalidInputError((input_ids is None) ^ (inputs_embeds is None),
                      "You cannot specify both input_ids and inputs_embeds at the same time, "
                      "and must specify either one")

    if inputs_embeds is None:
        inputs_embeds = self.embed_tokens(input_ids)

    batch_size, seq_length = inputs_embeds.shape[:2]

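    # On XPU the KV cache is always enabled, and any legacy tuple cache is
    # wrapped in ipex-llm's DynamicNormalCache so the past length can be
    # queried with get_usable_length.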
    # IPEX-LLM OPT start: kv cache
    past_key_values_length = 0
    use_cache = True if inputs_embeds.device.type == "xpu" else use_cache
    if use_cache:
        if not isinstance(past_key_values, DynamicNormalCache):
            past_key_values = DynamicNormalCache.from_legacy_cache(past_key_values)
        past_key_values_length = past_key_values.get_usable_length(seq_length)
    # IPEX-LLM OPT end: kv cache

    if position_ids is None:
        position_ids = torch.arange(
            past_key_values_length,
            seq_length + past_key_values_length,
            dtype=torch.long,
            device=inputs_embeds.device,
        )
        position_ids = position_ids.unsqueeze(0)

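    # On XPU, compute the rotary cos/sin tables once for the whole forward and
    # hand them to every decoder layer via `position_embeddings`, so each
    # attention layer can use the fused in-place rotary kernel instead of
    # recomputing the tables per layer.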
    # IPEX-LLM OPT start: fuse rope
    if inputs_embeds.device.type == "xpu" and position_ids is not None:
        cos, sin = self.layers[0].self_attn.rotary_emb(inputs_embeds,
                                                       seq_length + past_key_values_length)
        cos = cos[position_ids[0]].contiguous()
        sin = sin[position_ids[0]].contiguous()
        position_embeddings = (cos, sin)
    else:
        position_embeddings = None
    # IPEX-LLM OPT end: fuse rope

    # 4d mask is passed through the layers
    attention_mask = _prepare_4d_causal_attention_mask(
        attention_mask,
        (batch_size, seq_length),
        inputs_embeds,
        past_key_values_length,
    )

    # embed positions
    hidden_states = inputs_embeds

    # decoder layers
    all_hidden_states = () if output_hidden_states else None
    all_self_attns = () if output_attentions else None
    next_decoder_cache = None

    for decoder_layer in self.layers:
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        layer_outputs = decoder_layer(
            hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_values,
            output_attentions=output_attentions,
            use_cache=use_cache,
            position_embeddings=position_embeddings,
        )

        hidden_states = layer_outputs[0]

        if use_cache:
            next_decoder_cache = layer_outputs[2 if output_attentions else 1]

        if output_attentions:
            all_self_attns += (layer_outputs[1],)

    hidden_states = self.norm(hidden_states)

    # add hidden states from the last decoder layer
    if output_hidden_states:
        all_hidden_states += (hidden_states,)

    next_cache = next_decoder_cache
    if not return_dict:
        return tuple(
            v
            for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
            if v is not None
        )
    return BaseModelOutputWithPast(
        last_hidden_state=hidden_states,
        past_key_values=next_cache,
        hidden_states=all_hidden_states,
        attentions=all_self_attns,
    )


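# Reference rotary implementation used on the non-fused path. The
# view/transpose/reshape below reorders the last dimension from interleaved
# pairs to the half-split layout expected by `rotate_half`, e.g. for head
# dim 4: [x0, x1, x2, x3] -> [x0, x2, x1, x3].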
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
    cos = cos[position_ids].unsqueeze(unsqueeze_dim)
    sin = sin[position_ids].unsqueeze(unsqueeze_dim)

    b, h, s, d = q.shape
    q = q.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)

    b, h, s, d = k.shape
    k = k.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)

    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


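# Replacement for DeepseekV3Attention.forward (MLA). Queries and keys are
# assembled from a no-position part and a rope part; when the model forward
# supplies precomputed `position_embeddings`, the rope part is rotated with
# the fused in-place kernel, otherwise the reference `apply_rotary_pos_emb`
# path is used. Values are padded to q_head_dim (see `padding_mla_v_hd`) and
# the extra columns are dropped after SDPA.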
def deepseek_attention_forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Cache] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
    **kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    if "padding_mask" in kwargs:
        warnings.warn(
            "Passing `padding_mask` is deprecated and will be removed in v4.37. "
            "Please make sure use `attention_mask` instead.`"
        )

    bsz, q_len, _ = hidden_states.size()

    if self.q_lora_rank is None:
        q = self.q_proj(hidden_states)
    else:
        q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
    q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)

    compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
    compressed_kv, k_pe = torch.split(
        compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
    )
    k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2)
    kv = (
        self.kv_b_proj(self.kv_a_layernorm(compressed_kv))
        .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.q_head_dim)
        .transpose(1, 2)
    )

    k_nope, value_states = torch.split(
        kv, [self.qk_nope_head_dim, self.q_head_dim], dim=-1
    )
    kv_seq_len = value_states.shape[-2]
    if past_key_value is not None:
        kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)

    position_embeddings = kwargs.get("position_embeddings", None)
    if position_embeddings is not None:
        query_states = q
        key_states = torch.cat(
            [k_nope, k_pe.expand([-1, self.num_heads, -1, -1])],
            dim=-1
        )
        cos, sin = position_embeddings
        from ipex_llm.transformers.models.common import rotary_two_with_cache_inplaced
        rotary_two_with_cache_inplaced(query_states[:, :, :, self.qk_nope_head_dim:],
                                       key_states[:, :, :, self.qk_nope_head_dim:],
                                       cos, sin, True)
    else:
        q_nope, q_pe = torch.split(
            q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
        )
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
        q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids)

        query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
        query_states[:, :, :, : self.qk_nope_head_dim] = q_nope
        query_states[:, :, :, self.qk_nope_head_dim:] = q_pe

        key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
        key_states[:, :, :, : self.qk_nope_head_dim] = k_nope
        key_states[:, :, :, self.qk_nope_head_dim:] = k_pe

    if past_key_value is not None:
        key_states, value_states = past_key_value.update(key_states, value_states,
                                                         self.layer_idx, None)

    attn_weights = None
    attn_output = scaled_dot_product_attention(
        query_states, key_states, value_states,
        attention_mask, q_len == kv_seq_len, self.softmax_scale
    )
    attn_output = attn_output[:, :, :, :self.v_head_dim]

    attn_output = attn_output.transpose(1, 2).contiguous()

    attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim)

    attn_output = self.o_proj(attn_output)

    if not output_attentions:
        attn_weights = None

    return attn_output, attn_weights, past_key_value


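# Replacement routing for the MoE gate. On XPU the sigmoid scores are computed
# in fp32 and the group-limited top-k selection (with the expert score
# correction bias and routed scaling factor) is done by the fused
# `moe_group_topk` kernel; on other devices the gate module's own forward is
# used unchanged.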
def fuse_gate_forward(self, x: torch.Tensor):
    if x.device.type == "xpu" and x.dtype in [torch.float, torch.half]:
        x = x.view(-1, x.size(-1))
        logits = torch.nn.functional.linear(
            x.type(torch.float32), self.weight.type(torch.float32), None
        )
        scores = logits.sigmoid()

        from ipex_llm.transformers.models.common import moe_group_topk
        topk_idx, topk_weight = moe_group_topk(
            scores, self.e_score_correction_bias,
            self.n_group, self.topk_group, self.top_k,
            self.norm_topk_prob, self.routed_scaling_factor
        )
    else:
        topk_idx, topk_weight = self(x)
    return topk_idx, topk_weight.to(x.dtype)


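# Decode-path MoE dispatch for a single token. On XPU, when the experts use
# the expected quantization format (qtype == 2), the raw weight pointers of
# every expert's gate/up/down projections are collected once, cached as
# non-persistent buffers, and handed to the fused `xe_linear.moe_forward_vec`
# kernel. Otherwise each selected expert is run in a plain Python loop and the
# outputs are combined with the routing weights.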
def moe_infer_decode(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor):
    if (
        x.device.type == "xpu"
        and x.dtype in [torch.float, torch.half]
        and self.experts[0].down_proj.qtype == 2
    ):
        if getattr(self, "gates", None) is None:
            gate_addrs = [expert.gate_proj.weight.data_ptr() for expert in self.experts]
            up_addrs = [expert.up_proj.weight.data_ptr() for expert in self.experts]
            down_addrs = [expert.down_proj.weight.data_ptr() for expert in self.experts]
            gates = torch.tensor(gate_addrs, dtype=torch.uint64, device=x.device)
            ups = torch.tensor(up_addrs, dtype=torch.uint64, device=x.device)
            downs = torch.tensor(down_addrs, dtype=torch.uint64, device=x.device)
            self.register_buffer("gates", gates, persistent=False)
            self.register_buffer("ups", ups, persistent=False)
            self.register_buffer("downs", downs, persistent=False)

        import xe_linear
        final_out = xe_linear.moe_forward_vec(
            x, topk_ids, topk_weight, self.gates, self.ups, self.downs,
            x.size(-1), self.experts[0].intermediate_size, 2
        )
    else:
        idxs = topk_ids.flatten().tolist()
        outputs = []
        for i in idxs:
            expert = self.experts[i]
            expert_out = expert(x)
            outputs.append(expert_out)
        outs = torch.cat(outputs, dim=0)
        reshaped_topk_weight = topk_weight.squeeze(0).unsqueeze(-1)
        final_out = (outs * reshaped_topk_weight).sum(dim=0, keepdim=True)
    return final_out


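# Replacement for the MoE block forward. Routing goes through
# `fuse_gate_forward` above; at inference time the single-token, single
# expert-parallel-rank case (topk_idx.size(0) == 1 and ep_size == 1) takes the
# fused `moe_infer_decode` path, everything else falls back to the module's
# own `moe_infer`. Shared experts, when present, are added on top of the
# routed output.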
def deepseek_moe_forward(self, hidden_states: torch.Tensor):
    identity = hidden_states
    orig_shape = hidden_states.shape
    # IPEX-LLM OPT start: fuse grouped topk in gate forward
    topk_idx, topk_weight = fuse_gate_forward(self.gate, hidden_states)
    # IPEX-LLM OPT end
    hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
    flat_topk_idx = topk_idx.view(-1)
    if not self.training:
        # IPEX-LLM OPT start: add special moe_infer implementation for decoding
        if topk_idx.size(0) == 1 and self.ep_size == 1:
            y = moe_infer_decode(self, hidden_states, topk_idx, topk_weight)
        else:
            y = self.moe_infer(hidden_states, topk_idx, topk_weight)
        y = y.view(*orig_shape)
        # IPEX-LLM OPT end
    if self.config.n_shared_experts is not None:
        y = y + self.shared_experts(identity)
    return y
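

# Usage sketch (illustrative only): these forwards take `self` explicitly, so
# they can be bound onto a loaded DeepSeek-V3 model's modules. The actual
# wiring lives in ipex-llm's model conversion code and is not shown in this
# file; the hypothetical snippet below uses plain method binding.
#
#     import types
#     padding_mla_v_hd(model)  # pad MLA value heads once, before binding
#     model.model.forward = types.MethodType(deepseek_model_forward, model.model)
#     for layer in model.model.layers:
#         layer.self_attn.forward = types.MethodType(deepseek_attention_forward,
#                                                    layer.self_attn)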