LLM: Support hybrid convert for DeepSeek V3/R1 (#12834)

Xiangyu Tian, 2025-02-19 11:31:19 +08:00, committed by GitHub
parent 637543e135
commit 93c10be762
6 changed files with 354 additions and 278 deletions
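This commit moves the DeepSeek V3/R1 hybrid CPU/XPU execution path out of the standalone example script and into the ipex_llm library as `convert_model_hybrid`: attention and the leading dense MLP layers run on an Intel XPU in float16, while everything else, including the MoE experts, stays on the CPU in bfloat16. A minimal usage sketch, assembled from the new generate_hybrid.py added below (the checkpoint path is illustrative):

```
import torch
from ipex_llm.transformers import AutoModelForCausalLM, convert_model_hybrid
from transformers import AutoTokenizer

# Illustrative path; any DeepSeek V3/R1 low-bit checkpoint works the same way.
model = AutoModelForCausalLM.load_low_bit("/llm/models/DeepSeek-R1-int4-1/",
                                          trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained("/llm/models/DeepSeek-R1-int4-1/",
                                          trust_remote_code=True)
model = convert_model_hybrid(model)  # attention + leading dense MLPs -> XPU

with torch.inference_mode():
    input_ids = tokenizer.encode("What is AI?", return_tensors="pt")
    output = model.generate(input_ids, max_new_tokens=32)
    print(tokenizer.decode(output[0], skip_special_tokens=True))
```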


@@ -1,276 +0,0 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import List, Optional, Tuple, Union
import warnings
import os
import torch
from torch import nn
import time
import argparse
import ipex_llm
from ipex_llm.transformers import AutoModelForCausalLM
from transformers import AutoTokenizer, GenerationConfig
from transformers.cache_utils import Cache, DynamicCache
deepseek_prompt = r"""
A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think> <answer> answer here </answer>. User: Question: If \( a > 1 \), then the sum of the real solutions of \( \sqrt{a} - \sqrt{a + x} = x \) is equal to:. Assistant: <think>
"""
def convert_forward_to_xpu(m, target_m, new_forward):
# print(m.__class__.__name__)
if m.__class__.__name__ == target_m:
bound_method = new_forward.__get__(m, m.__class__)
setattr(m, "forward", bound_method)
m = m.to(device="xpu", dtype=torch.float16)
for _, sub_m in m.named_children():
convert_forward_to_xpu(sub_m, target_m, new_forward)
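# Hybrid MLP: move the CPU bf16 activations to the XPU, run the SwiGLU-style
# gate/up/down projections there in fp16, and return the result to CPU bf16.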
def hybrid_MLP_forward(self, x):
x = x.to(device="xpu", dtype=torch.float16)
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
return down_proj.to(device="cpu", dtype=torch.bfloat16)
# Copied from transformers.models.llama.modeling_llama.rotate_half
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
Args:
q (`torch.Tensor`): The query tensor.
k (`torch.Tensor`): The key tensor.
cos (`torch.Tensor`): The cosine part of the rotary embedding.
sin (`torch.Tensor`): The sine part of the rotary embedding.
position_ids (`torch.Tensor`):
The position indices of the tokens corresponding to the query and key tensors. For example, this can be
used to pass offsetted position ids when working with a KV-cache.
unsqueeze_dim (`int`, *optional*, defaults to 1):
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
Returns:
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
"""
cos = cos[position_ids].unsqueeze(unsqueeze_dim)
sin = sin[position_ids].unsqueeze(unsqueeze_dim)
b, h, s, d = q.shape
q = q.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)
b, h, s, d = k.shape
k = k.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
# Copied from modeling_deepseek.DeepseekV3Attention
def hybrid_DeepseekV3Attention_forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if "padding_mask" in kwargs:
warnings.warn(
"Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure to use `attention_mask` instead."
)
# ipex-llm modify: to xpu
hidden_states = hidden_states.to(device="xpu", dtype=torch.float16)
attention_mask = attention_mask.to(device="xpu", dtype=torch.float16)
position_ids = position_ids.to(device="xpu")
if past_key_value is not None:
past_key_value = past_key_value.to(device="xpu", dtype=torch.float16)
# end of ipex-llm modify
bsz, q_len, _ = hidden_states.size()
if self.q_lora_rank is None:
q = self.q_proj(hidden_states)
else:
q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)
q_nope, q_pe = torch.split(
q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
)
compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
compressed_kv, k_pe = torch.split(
compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
)
k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2)
kv = (
self.kv_b_proj(self.kv_a_layernorm(compressed_kv))
.view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
.transpose(1, 2)
)
k_nope, value_states = torch.split(
kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1
)
kv_seq_len = value_states.shape[-2]
if past_key_value is not None:
if self.layer_idx is None:
raise ValueError(
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
"with a layer index."
)
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids)
query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
query_states[:, :, :, : self.qk_nope_head_dim] = q_nope
query_states[:, :, :, self.qk_nope_head_dim :] = q_pe
key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
key_states[:, :, :, : self.qk_nope_head_dim] = k_nope
key_states[:, :, :, self.qk_nope_head_dim :] = k_pe
if past_key_value is not None:
cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
key_states, value_states = past_key_value.update(
key_states, value_states, self.layer_idx, cache_kwargs
)
attn_weights = (
torch.matmul(query_states, key_states.transpose(2, 3)) * self.softmax_scale
)
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
raise ValueError(
f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
f" {attn_weights.size()}"
)
assert attention_mask is not None
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
)
attn_weights = attn_weights + attention_mask
# upcast attention to fp32
attn_weights = nn.functional.softmax(
attn_weights, dim=-1, dtype=torch.float32
).to(query_states.dtype)
attn_weights = nn.functional.dropout(
attn_weights, p=self.attention_dropout, training=self.training
)
attn_output = torch.matmul(attn_weights, value_states)
if attn_output.size() != (bsz, self.num_heads, q_len, self.v_head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.v_head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim)
attn_output = self.o_proj(attn_output)
if not output_attentions:
attn_weights = None
# ipex-llm modify: to cpu
if attn_output is not None:
attn_output = attn_output.to(device="cpu", dtype=torch.bfloat16)
if attn_weights is not None:
attn_weights = attn_weights.to(device="cpu", dtype=torch.bfloat16)
if past_key_value is not None:
past_key_value = past_key_value.to(device="cpu", dtype=torch.bfloat16)
torch.xpu.empty_cache()
# end of ipex-llm modify
return attn_output, attn_weights, past_key_value
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Predict Tokens using `generate()` API for DeepSeek model')
parser.add_argument('--repo-id-or-model-path', type=str, default="deepseek-ai/DeepSeek-R1",
help='The huggingface repo id for the DeepSeek model (e.g. `deepseek-ai/DeepSeek-R1`) to be downloaded'
', or the path to the huggingface checkpoint folder')
parser.add_argument('--prompt', type=str, default="What is AI?",
help='Prompt to infer')
parser.add_argument('--n-predict', type=int, default=32,
help='Max tokens to predict')
parser.add_argument('--load-path', type=str, default=None,
help='The path to load the low-bit model.')
args = parser.parse_args()
model_path = args.repo_id_or_model_path
load_path = args.load_path
if load_path:
model = AutoModelForCausalLM.load_low_bit(load_path, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(load_path,
trust_remote_code=True)
else:
model = AutoModelForCausalLM.from_pretrained(model_path,
load_in_4bit=True,
optimize_model=True,
trust_remote_code=True,
use_cache=True)
tokenizer = AutoTokenizer.from_pretrained(model_path,
trust_remote_code=True)
#model = model.bfloat16()
convert_forward_to_xpu(model.model, "DeepseekV3Attention", hybrid_DeepseekV3Attention_forward)
convert_forward_to_xpu(model.model.layers[:3], "DeepseekV3MLP", hybrid_MLP_forward)
print(model)
print("load completed")
# model = BenchmarkWrapper(model, do_print=True)
# Generate predicted tokens
with torch.inference_mode():
prompt = deepseek_prompt
input_ids = tokenizer.encode(prompt, return_tensors="pt")
# ipex_llm model needs a warmup run before the inference time can be measured accurately
output = model.generate(input_ids,
max_new_tokens=args.n_predict)
# start inference
st = time.time()
output = model.generate(input_ids,
max_new_tokens=args.n_predict)
torch.xpu.synchronize()
end = time.time()
output = output.cpu()
output_str = tokenizer.decode(output[0], skip_special_tokens=True)
print(f'Inference time: {end-st} s')
print('-'*20, 'Prompt', '-'*20)
print(prompt)
print('-'*20, 'Output', '-'*20)
print(output_str)
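(The 276-line standalone example above is removed by this commit: its hybrid attention/MLP helpers move into the ipex_llm library, per the convert.py and models/deepseek_v3.py hunks below, and the script is superseded by the slimmer generate_hybrid.py that follows the README change.)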


@@ -4,7 +4,7 @@
# From docker image: intelanalytics/ipex-llm-serving-xpu:2.2.0-b13
pip install transformers==4.48.3 accelerate trl==0.12.0
-python generate.py --load-path /llm/models/DeepSeek-R1-int4-1/ --n-predict 128
+python generate_hybrid.py --load-path /llm/models/DeepSeek-R1-int4-1/ --n-predict 128
```
Sample output:


@@ -0,0 +1,92 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import List, Optional, Tuple, Union
import warnings
import os
import torch
from torch import nn
import time
import argparse
import ipex_llm
from ipex_llm.transformers import convert_model_hybrid
from ipex_llm.transformers import AutoModelForCausalLM
from transformers import AutoTokenizer, GenerationConfig
from transformers.cache_utils import Cache, DynamicCache
PROMPT_FORMAT = """
A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think> <answer> answer here </answer>.
User: {prompt}.
Assistant: <think>
"""
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Predict Tokens using `generate()` API for DeepSeek model')
parser.add_argument('--repo-id-or-model-path', type=str, default="deepseek-ai/DeepSeek-R1",
help='The huggingface repo id for the DeepSeek model (e.g. `deepseek-ai/DeepSeek-R1`) to be downloaded'
', or the path to the huggingface checkpoint folder')
parser.add_argument('--prompt', type=str, default=r"If \( a > 1 \), then the sum of the real solutions of \( \sqrt{a} - \sqrt{a + x} = x \) is equal to:",
help='Prompt to infer')
parser.add_argument('--n-predict', type=int, default=32,
help='Max tokens to predict')
parser.add_argument('--load-path', type=str, default=None,
help='The path to load the low-bit model.')
args = parser.parse_args()
model_path = args.repo_id_or_model_path
load_path = args.load_path
if load_path:
model = AutoModelForCausalLM.load_low_bit(load_path, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(load_path,
trust_remote_code=True)
else:
model = AutoModelForCausalLM.from_pretrained(model_path,
load_in_4bit=True,
optimize_model=True,
trust_remote_code=True,
use_cache=True)
tokenizer = AutoTokenizer.from_pretrained(model_path,
trust_remote_code=True)
#model = model.bfloat16()
model = convert_model_hybrid(model)
print(model)
# Generate predicted tokens
with torch.inference_mode():
prompt = PROMPT_FORMAT.format(prompt=args.prompt)
input_ids = tokenizer.encode(prompt, return_tensors="pt")
# ipex_llm model needs a warmup run before the inference time can be measured accurately
output = model.generate(input_ids,
max_new_tokens=args.n_predict)
# start inference
st = time.time()
output = model.generate(input_ids,
max_new_tokens=args.n_predict)
torch.xpu.synchronize()
end = time.time()
output = output.cpu()
output_str = tokenizer.decode(output[0], skip_special_tokens=True)
print(f'Inference time: {end-st} s')
print('-'*20, 'Prompt', '-'*20)
print(prompt)
print('-'*20, 'Output', '-'*20)
print(output_str)


@@ -15,7 +15,7 @@
#
-from .convert import ggml_convert_low_bit, get_enable_ipex
+from .convert import ggml_convert_low_bit, get_enable_ipex, convert_model_hybrid
from .model import AutoModelForCausalLM, AutoModel, AutoModelForSeq2SeqLM, \
AutoModelForSpeechSeq2Seq, AutoModelForQuestionAnswering, \
AutoModelForSequenceClassification, AutoModelForMaskedLM, \


@@ -2025,3 +2025,30 @@ def _optimize_post(model):
_optimize_post(model.language_model)
return model
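# Recursively walk the module tree: every instance of `target_m` gets
# `new_forward` bound as its forward and is moved to the XPU in fp16;
# all other modules stay untouched on the CPU.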
def convert_forward_to_xpu(m, target_m, new_forward):
# print(m.__class__.__name__)
if m.__class__ == target_m:
bound_method = new_forward.__get__(m, m.__class__)
setattr(m, "forward", bound_method)
m = m.to(device="xpu", dtype=torch.float16)
for _, sub_m in m.named_children():
convert_forward_to_xpu(sub_m, target_m, new_forward)
def convert_model_hybrid(model):
if model.config.model_type == "deepseek_v3":
modeling_module_name = model.__class__.__module__
module = importlib.import_module(modeling_module_name)
from ipex_llm.transformers.models.deepseek_v3 import (
hybrid_DeepseekV3Attention_forward,
hybrid_DeepseekV3MLP_forward,
)
first_k_dense_replace = model.config.first_k_dense_replace
convert_forward_to_xpu(model, module.DeepseekV3Attention,
hybrid_DeepseekV3Attention_forward)
convert_forward_to_xpu(model.model.layers[:first_k_dense_replace], module.DeepseekV3MLP,
hybrid_DeepseekV3MLP_forward)
return model
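The conversion relies on instance-level monkey-patching: `new_forward.__get__(m, m.__class__)` uses the descriptor protocol to turn a free function into a method bound to that one module instance. A self-contained sketch of the mechanism on a toy module (ToyMLP and traced_forward are illustrative, not part of ipex-llm):

```
import torch
from torch import nn

class ToyMLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4)

    def forward(self, x):
        return self.proj(x)

def traced_forward(self, x):
    # Replacement forward: same computation, plus a device trace.
    print(f"patched forward on {x.device}")
    return self.proj(x)

m = ToyMLP()
# Bind the free function to this instance only; the class is unchanged.
m.forward = traced_forward.__get__(m, m.__class__)
m(torch.randn(2, 4))         # prints the trace
ToyMLP()(torch.randn(2, 4))  # other instances keep the original forward
```

In DeepSeek V3, `first_k_dense_replace` is the number of leading layers whose MLP is dense rather than MoE, so the conversion offloads attention in every layer but only the dense MLPs; the MoE layers keep their CPU forward.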


@@ -0,0 +1,233 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Some parts of this file is adapted from
# https://github.com/huggingface/transformers/blob/v4.31.0/src/transformers/models/llama/modeling_llama.py
# which is licensed under Apache License 2.0:
#
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional, Tuple, Union
import warnings
import torch
from torch import nn
from transformers.cache_utils import Cache
from ipex_llm.utils.common import invalidInputError
def hybrid_DeepseekV3MLP_forward(self, x):
x = x.to(device="xpu", dtype=torch.float16)
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
return down_proj.to(device="cpu", dtype=torch.bfloat16)
# Copied from transformers.models.llama.modeling_llama.rotate_half
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2:]
return torch.cat((-x2, x1), dim=-1)
# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
Args:
q (`torch.Tensor`): The query tensor.
k (`torch.Tensor`): The key tensor.
cos (`torch.Tensor`): The cosine part of the rotary embedding.
sin (`torch.Tensor`): The sine part of the rotary embedding.
position_ids (`torch.Tensor`):
The position indices of the tokens corresponding to the query and key tensors. For example, this can be
used to pass offsetted position ids when working with a KV-cache. # noqa
unsqueeze_dim (`int`, *optional*, defaults to 1):
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
Returns:
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
"""
cos = cos[position_ids].unsqueeze(unsqueeze_dim)
sin = sin[position_ids].unsqueeze(unsqueeze_dim)
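# DeepSeek stores the rotary dims of q/k interleaved as (x0, x1), (x2, x3), ...;
# the view/transpose below regroups each head to [even dims | odd dims] so the
# llama-style rotate_half above applies correctly.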
b, h, s, d = q.shape
q = q.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)
b, h, s, d = k.shape
k = k.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
# Copied from modeling_deepseek.DeepseekV3Attention
def DeepseekV3Attention_forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if "padding_mask" in kwargs:
warnings.warn(
"Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure to use `attention_mask` instead."  # noqa
)
bsz, q_len, _ = hidden_states.size()
if self.q_lora_rank is None:
q = self.q_proj(hidden_states)
else:
q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)
q_nope, q_pe = torch.split(
q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
)
compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
compressed_kv, k_pe = torch.split(
compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
)
k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2)
kv = (
self.kv_b_proj(self.kv_a_layernorm(compressed_kv))
.view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
.transpose(1, 2)
)
k_nope, value_states = torch.split(
kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1
)
kv_seq_len = value_states.shape[-2]
if past_key_value is not None:
if self.layer_idx is None:
raise ValueError( # noqa
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " # noqa
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " # noqa
"with a layer index."
)
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids)
query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
query_states[:, :, :, : self.qk_nope_head_dim] = q_nope
query_states[:, :, :, self.qk_nope_head_dim:] = q_pe
key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
key_states[:, :, :, : self.qk_nope_head_dim] = k_nope
key_states[:, :, :, self.qk_nope_head_dim:] = k_pe
if past_key_value is not None:
cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
key_states, value_states = past_key_value.update(
key_states, value_states, self.layer_idx, cache_kwargs
)
attn_weights = (
torch.matmul(query_states, key_states.transpose(2, 3)) * self.softmax_scale
)
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
raise ValueError( # noqa
f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" # noqa
f" {attn_weights.size()}"
)
assert attention_mask is not None # noqa
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
raise ValueError( # noqa
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" # noqa
)
attn_weights = attn_weights + attention_mask
# upcast attention to fp32
attn_weights = nn.functional.softmax(
attn_weights, dim=-1, dtype=torch.float32
).to(query_states.dtype)
attn_weights = nn.functional.dropout(
attn_weights, p=self.attention_dropout, training=self.training
)
attn_output = torch.matmul(attn_weights, value_states)
if attn_output.size() != (bsz, self.num_heads, q_len, self.v_head_dim):
raise ValueError( # noqa
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.v_head_dim)}, but is" # noqa
f" {attn_output.size()}"
)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim)
attn_output = self.o_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value
def hybrid_DeepseekV3Attention_forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
xpu_device: str = "xpu",
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
hidden_states = hidden_states.to(device=xpu_device, dtype=torch.float16)
attention_mask = attention_mask.to(device=xpu_device, dtype=torch.float16)
position_ids = position_ids.to(device=xpu_device)
if past_key_value is not None:
past_key_value = past_key_value.to(device=xpu_device, dtype=torch.float16)
attn_output, attn_weights, past_key_value = DeepseekV3Attention_forward(
self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache, **kwargs # noqa
)
if attn_output is not None:
attn_output = attn_output.to(device="cpu", dtype=torch.bfloat16)
if attn_weights is not None:
attn_weights = attn_weights.to(device="cpu", dtype=torch.bfloat16)
if past_key_value is not None:
past_key_value = past_key_value.to(device="cpu", dtype=torch.bfloat16)
torch.xpu.empty_cache()
return attn_output, attn_weights, past_key_value
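# Note on the hybrid data flow: every patched forward receives CPU bf16 tensors,
# computes on the XPU in fp16, and returns CPU bf16, so unpatched CPU modules
# never see an XPU tensor; torch.xpu.empty_cache() after each attention call
# keeps XPU memory bounded at the cost of some allocator churn.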