optimize attention part of moonlight-14B-A3B (#12886)
This commit is contained in:
parent
dd30d12cb6
commit
ab3fc66eb7
4 changed files with 335 additions and 4 deletions
|
|
@ -1070,7 +1070,9 @@ def _optimize_pre(model, qtype=None):
|
|||
model.apply(pre_register_inv_freq)
|
||||
elif model.config.model_type == "multi_modality":
|
||||
_optimize_pre(model.language_model)
|
||||
|
||||
elif model.config.model_type == "deepseek_v3" and model.config.hidden_size == 2048:
|
||||
from ipex_llm.transformers.models.deepseek import padding_mla_v_hd
|
||||
model.apply(padding_mla_v_hd)
|
||||
return model
|
||||
|
||||
|
||||
|
|
@ -2023,6 +2025,15 @@ def _optimize_post(model):
|
|||
|
||||
# llm
|
||||
_optimize_post(model.language_model)
|
||||
elif model.config.model_type == "deepseek_v3" and model.config.hidden_size == 2048:
|
||||
modeling_module_name = model.__class__.__module__
|
||||
module = importlib.import_module(modeling_module_name)
|
||||
from ipex_llm.transformers.models.common import rms_norm_forward
|
||||
from ipex_llm.transformers.models.deepseek import deepseek_model_forward
|
||||
from ipex_llm.transformers.models.deepseek import deepseek_attention_forward
|
||||
convert_forward(model, module.DeepseekV3RMSNorm, rms_norm_forward)
|
||||
convert_forward(model, module.DeepseekV3Model, deepseek_model_forward)
|
||||
convert_forward(model, module.DeepseekV3Attention, deepseek_attention_forward)
|
||||
|
||||
return model
|
||||
|
||||
|
|
|
|||
|
|
@ -95,6 +95,33 @@ def padding_attention_hd_base(module: torch.nn.Module, attention_class,
|
|||
module.old_head_dim = old_head_dim
|
||||
|
||||
|
||||
def padding_mla_v_hd_base(module: torch.nn.Module, attention_class):
|
||||
if (
|
||||
isinstance(attention_class, str) and module.__class__.__name__ == attention_class
|
||||
or not isinstance(attention_class, str) and isinstance(module, attention_class)
|
||||
):
|
||||
k_head_dim = module.q_head_dim
|
||||
v_head_dim = module.v_head_dim
|
||||
if v_head_dim < k_head_dim:
|
||||
kv_b_proj = module.kv_b_proj
|
||||
w = kv_b_proj.weight.data.view(module.num_heads,
|
||||
module.qk_nope_head_dim + module.v_head_dim,
|
||||
module.kv_lora_rank)
|
||||
k_w, v_w = w.split([module.qk_nope_head_dim, module.v_head_dim], dim=1)
|
||||
new_v_w = torch.zeros([module.num_heads, k_head_dim, module.kv_lora_rank],
|
||||
dtype=v_w.dtype, device=v_w.device)
|
||||
new_v_w[:, :v_head_dim, :] = v_w
|
||||
new_w = torch.cat([k_w, new_v_w], dim=1).view(-1, module.kv_lora_rank)
|
||||
|
||||
new_kv_b_proj = torch.nn.Linear(0, 0, bias=False,
|
||||
dtype=new_w.dtype, device=new_w.device)
|
||||
new_kv_b_proj.in_features = new_w.size(1)
|
||||
new_kv_b_proj.out_features = new_w.size(0)
|
||||
new_kv_b_proj.weight = torch.nn.Parameter(new_w, False)
|
||||
|
||||
module.kv_b_proj = new_kv_b_proj
|
||||
|
||||
|
||||
def padding_states_hd(states: torch.Tensor, old_head_dim: int, new_head_dim: int):
|
||||
bsz, num_heads, seq_len, head_dim = states.size()
|
||||
if head_dim == old_head_dim and old_head_dim < new_head_dim:
|
||||
|
|
|
|||
271
python/llm/src/ipex_llm/transformers/models/deepseek.py
Normal file
271
python/llm/src/ipex_llm/transformers/models/deepseek.py
Normal file
|
|
@ -0,0 +1,271 @@
|
|||
#
|
||||
# Copyright 2016 The BigDL Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
# Some parts of this file is adapted from
|
||||
# https://huggingface.co/deepseek-ai/DeepSeek-R1/blob/main/modeling_deepseek.py
|
||||
# which is licensed under Apache License 2.0:
|
||||
#
|
||||
# https://github.com/OpenBMB/MiniCPM/blob/main/LICENSE
|
||||
#
|
||||
|
||||
import torch
|
||||
import warnings
|
||||
|
||||
from typing import Optional, Tuple, List, Union
|
||||
from transformers.cache_utils import Cache
|
||||
from transformers.modeling_outputs import BaseModelOutputWithPast
|
||||
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
|
||||
|
||||
from ipex_llm.utils.common.log4Error import invalidInputError
|
||||
from ipex_llm.transformers.kv import DynamicNormalCache
|
||||
from ipex_llm.transformers.models.common import padding_mla_v_hd_base
|
||||
from ipex_llm.transformers.models.common import scaled_dot_product_attention
|
||||
from ipex_llm.transformers.models.utils import rotate_half
|
||||
|
||||
|
||||
def padding_mla_v_hd(module: torch.nn.Module):
|
||||
padding_mla_v_hd_base(module, "DeepseekV3Attention")
|
||||
|
||||
|
||||
def deepseek_model_forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[Tuple, BaseModelOutputWithPast]:
|
||||
output_attentions = (
|
||||
output_attentions if output_attentions is not None
|
||||
else self.config.output_attentions
|
||||
)
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None
|
||||
else self.config.output_hidden_states
|
||||
)
|
||||
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
|
||||
return_dict = (
|
||||
return_dict if return_dict is not None else self.config.use_return_dict
|
||||
)
|
||||
|
||||
# retrieve input_ids and inputs_embeds
|
||||
invalidInputError((input_ids is None) ^ (inputs_embeds is None),
|
||||
"You cannot specify both input_ids and inputs_embeds at the same time, "
|
||||
"and must specify either one")
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
|
||||
batch_size, seq_length = inputs_embeds.shape[:2]
|
||||
|
||||
# IPEX-LLM OPT start: kv cache
|
||||
past_key_values_length = 0
|
||||
use_cache = True if inputs_embeds.device.type == "xpu" else use_cache
|
||||
if use_cache:
|
||||
if not isinstance(past_key_values, DynamicNormalCache):
|
||||
past_key_values = DynamicNormalCache.from_legacy_cache(past_key_values)
|
||||
past_key_values_length = past_key_values.get_usable_length(seq_length)
|
||||
# IPEX-LLM OPT end: kv cache
|
||||
|
||||
if position_ids is None:
|
||||
position_ids = torch.arange(
|
||||
past_key_values_length,
|
||||
seq_length + past_key_values_length,
|
||||
dtype=torch.long,
|
||||
device=inputs_embeds.device,
|
||||
)
|
||||
position_ids = position_ids.unsqueeze(0)
|
||||
|
||||
# IPEX-LLM OPT start: fuse rope
|
||||
if inputs_embeds.device.type == "xpu" and position_ids is not None:
|
||||
cos, sin = self.layers[0].self_attn.rotary_emb(inputs_embeds,
|
||||
seq_length + past_key_values_length)
|
||||
cos = cos[position_ids[0]].contiguous()
|
||||
sin = sin[position_ids[0]].contiguous()
|
||||
position_embeddings = (cos, sin)
|
||||
else:
|
||||
position_embeddings = None
|
||||
# IPEX-LLM OPT end: fuse rope
|
||||
|
||||
# 4d mask is passed through the layers
|
||||
attention_mask = _prepare_4d_causal_attention_mask(
|
||||
attention_mask,
|
||||
(batch_size, seq_length),
|
||||
inputs_embeds,
|
||||
past_key_values_length,
|
||||
)
|
||||
|
||||
# embed positions
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
# decoder layers
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attns = () if output_attentions else None
|
||||
next_decoder_cache = None
|
||||
|
||||
for decoder_layer in self.layers:
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
layer_outputs = decoder_layer(
|
||||
hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_values,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
position_embeddings=position_embeddings,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if use_cache:
|
||||
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
|
||||
|
||||
if output_attentions:
|
||||
all_self_attns += (layer_outputs[1],)
|
||||
|
||||
hidden_states = self.norm(hidden_states)
|
||||
|
||||
# add hidden states from the last decoder layer
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
next_cache = next_decoder_cache
|
||||
if not return_dict:
|
||||
return tuple(
|
||||
v
|
||||
for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
|
||||
if v is not None
|
||||
)
|
||||
return BaseModelOutputWithPast(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=next_cache,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_self_attns,
|
||||
)
|
||||
|
||||
|
||||
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
|
||||
cos = cos[position_ids].unsqueeze(unsqueeze_dim)
|
||||
sin = sin[position_ids].unsqueeze(unsqueeze_dim)
|
||||
|
||||
b, h, s, d = q.shape
|
||||
q = q.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)
|
||||
|
||||
b, h, s, d = k.shape
|
||||
k = k.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)
|
||||
|
||||
q_embed = (q * cos) + (rotate_half(q) * sin)
|
||||
k_embed = (k * cos) + (rotate_half(k) * sin)
|
||||
return q_embed, k_embed
|
||||
|
||||
|
||||
def deepseek_attention_forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Cache] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
if "padding_mask" in kwargs:
|
||||
warnings.warn(
|
||||
"Passing `padding_mask` is deprecated and will be removed in v4.37. "
|
||||
"Please make sure use `attention_mask` instead.`"
|
||||
)
|
||||
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
if self.q_lora_rank is None:
|
||||
q = self.q_proj(hidden_states)
|
||||
else:
|
||||
q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
|
||||
q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)
|
||||
|
||||
compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
|
||||
compressed_kv, k_pe = torch.split(
|
||||
compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
|
||||
)
|
||||
k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2)
|
||||
kv = (
|
||||
self.kv_b_proj(self.kv_a_layernorm(compressed_kv))
|
||||
.view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.q_head_dim)
|
||||
.transpose(1, 2)
|
||||
)
|
||||
|
||||
k_nope, value_states = torch.split(
|
||||
kv, [self.qk_nope_head_dim, self.q_head_dim], dim=-1
|
||||
)
|
||||
kv_seq_len = value_states.shape[-2]
|
||||
if past_key_value is not None:
|
||||
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
|
||||
|
||||
position_embeddings = kwargs.get("position_embeddings", None)
|
||||
if position_embeddings is not None:
|
||||
query_states = q
|
||||
key_states = torch.cat(
|
||||
[k_nope, k_pe.expand([-1, self.num_heads, -1, -1])],
|
||||
dim=-1
|
||||
)
|
||||
import xe_addons
|
||||
cos, sin = position_embeddings
|
||||
xe_addons.rotary_two_with_cache_inplaced(query_states[:, :, :, self.qk_nope_head_dim:],
|
||||
key_states[:, :, :, self.qk_nope_head_dim:],
|
||||
cos, sin, True)
|
||||
else:
|
||||
q_nope, q_pe = torch.split(
|
||||
q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
|
||||
)
|
||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||
q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids)
|
||||
|
||||
query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
|
||||
query_states[:, :, :, : self.qk_nope_head_dim] = q_nope
|
||||
query_states[:, :, :, self.qk_nope_head_dim:] = q_pe
|
||||
|
||||
key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
|
||||
key_states[:, :, :, : self.qk_nope_head_dim] = k_nope
|
||||
key_states[:, :, :, self.qk_nope_head_dim:] = k_pe
|
||||
|
||||
if past_key_value is not None:
|
||||
key_states, value_states = past_key_value.update(key_states, value_states,
|
||||
self.layer_idx, None)
|
||||
|
||||
attn_weights = None
|
||||
attn_output = scaled_dot_product_attention(
|
||||
query_states, key_states, value_states,
|
||||
attention_mask, q_len == kv_seq_len, self.softmax_scale
|
||||
)
|
||||
attn_output = attn_output[:, :, :, :self.v_head_dim]
|
||||
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
|
||||
attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim)
|
||||
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
if not output_attentions:
|
||||
attn_weights = None
|
||||
|
||||
return attn_output, attn_weights, past_key_value
|
||||
|
|
@ -1,3 +1,25 @@
|
|||
#
|
||||
# Copyright 2016 The BigDL Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
# Some parts of this file is adapted from
|
||||
# https://hf-mirror.com/openbmb/MiniCPM3-4B/blob/main/modeling_minicpm.py
|
||||
# which is licensed under Apache License 2.0:
|
||||
#
|
||||
# https://github.com/OpenBMB/MiniCPM/blob/main/LICENSE
|
||||
#
|
||||
|
||||
import torch
|
||||
import warnings
|
||||
|
||||
|
|
@ -122,9 +144,6 @@ def minicpm3_attention_forward(
|
|||
|
||||
q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
|
||||
q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)
|
||||
q_nope, q_pe = torch.split(
|
||||
q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
|
||||
)
|
||||
|
||||
compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
|
||||
compressed_kv, k_pe = torch.split(
|
||||
|
|
@ -169,6 +188,9 @@ def minicpm3_attention_forward(
|
|||
else:
|
||||
invalidInputError(f"unknown rope method: {self.rotary_emb.__class__.__name__}")
|
||||
else:
|
||||
q_nope, q_pe = torch.split(
|
||||
q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
|
||||
)
|
||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||
q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue