optimize attention part of moonlight-14B-A3B (#12886)
This commit is contained in:
		
							parent
							
								
									dd30d12cb6
								
							
						
					
					
						commit
						ab3fc66eb7
					
				
					 4 changed files with 335 additions and 4 deletions
				
			
		| 
						 | 
				
			
			@ -1070,7 +1070,9 @@ def _optimize_pre(model, qtype=None):
 | 
			
		|||
        model.apply(pre_register_inv_freq)
 | 
			
		||||
    elif model.config.model_type == "multi_modality":
 | 
			
		||||
        _optimize_pre(model.language_model)
 | 
			
		||||
 | 
			
		||||
    elif model.config.model_type == "deepseek_v3" and model.config.hidden_size == 2048:
 | 
			
		||||
        from ipex_llm.transformers.models.deepseek import padding_mla_v_hd
 | 
			
		||||
        model.apply(padding_mla_v_hd)
 | 
			
		||||
    return model
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -2023,6 +2025,15 @@ def _optimize_post(model):
 | 
			
		|||
 | 
			
		||||
        # llm
 | 
			
		||||
        _optimize_post(model.language_model)
 | 
			
		||||
    elif model.config.model_type == "deepseek_v3" and model.config.hidden_size == 2048:
 | 
			
		||||
        modeling_module_name = model.__class__.__module__
 | 
			
		||||
        module = importlib.import_module(modeling_module_name)
 | 
			
		||||
        from ipex_llm.transformers.models.common import rms_norm_forward
 | 
			
		||||
        from ipex_llm.transformers.models.deepseek import deepseek_model_forward
 | 
			
		||||
        from ipex_llm.transformers.models.deepseek import deepseek_attention_forward
 | 
			
		||||
        convert_forward(model, module.DeepseekV3RMSNorm, rms_norm_forward)
 | 
			
		||||
        convert_forward(model, module.DeepseekV3Model, deepseek_model_forward)
 | 
			
		||||
        convert_forward(model, module.DeepseekV3Attention, deepseek_attention_forward)
 | 
			
		||||
 | 
			
		||||
    return model
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -95,6 +95,33 @@ def padding_attention_hd_base(module: torch.nn.Module, attention_class,
 | 
			
		|||
        module.old_head_dim = old_head_dim
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def padding_mla_v_hd_base(module: torch.nn.Module, attention_class):
 | 
			
		||||
    if (
 | 
			
		||||
        isinstance(attention_class, str) and module.__class__.__name__ == attention_class
 | 
			
		||||
        or not isinstance(attention_class, str) and isinstance(module, attention_class)
 | 
			
		||||
    ):
 | 
			
		||||
        k_head_dim = module.q_head_dim
 | 
			
		||||
        v_head_dim = module.v_head_dim
 | 
			
		||||
        if v_head_dim < k_head_dim:
 | 
			
		||||
            kv_b_proj = module.kv_b_proj
 | 
			
		||||
            w = kv_b_proj.weight.data.view(module.num_heads,
 | 
			
		||||
                                           module.qk_nope_head_dim + module.v_head_dim,
 | 
			
		||||
                                           module.kv_lora_rank)
 | 
			
		||||
            k_w, v_w = w.split([module.qk_nope_head_dim, module.v_head_dim], dim=1)
 | 
			
		||||
            new_v_w = torch.zeros([module.num_heads, k_head_dim, module.kv_lora_rank],
 | 
			
		||||
                                  dtype=v_w.dtype, device=v_w.device)
 | 
			
		||||
            new_v_w[:, :v_head_dim, :] = v_w
 | 
			
		||||
            new_w = torch.cat([k_w, new_v_w], dim=1).view(-1, module.kv_lora_rank)
 | 
			
		||||
 | 
			
		||||
            new_kv_b_proj = torch.nn.Linear(0, 0, bias=False,
 | 
			
		||||
                                            dtype=new_w.dtype, device=new_w.device)
 | 
			
		||||
            new_kv_b_proj.in_features = new_w.size(1)
 | 
			
		||||
            new_kv_b_proj.out_features = new_w.size(0)
 | 
			
		||||
            new_kv_b_proj.weight = torch.nn.Parameter(new_w, False)
 | 
			
		||||
 | 
			
		||||
            module.kv_b_proj = new_kv_b_proj
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def padding_states_hd(states: torch.Tensor, old_head_dim: int, new_head_dim: int):
 | 
			
		||||
    bsz, num_heads, seq_len, head_dim = states.size()
 | 
			
		||||
    if head_dim == old_head_dim and old_head_dim < new_head_dim:
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
							
								
								
									
										271
									
								
								python/llm/src/ipex_llm/transformers/models/deepseek.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										271
									
								
								python/llm/src/ipex_llm/transformers/models/deepseek.py
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1,271 @@
 | 
			
		|||
#
 | 
			
		||||
# Copyright 2016 The BigDL Authors.
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
# you may not use this file except in compliance with the License.
 | 
			
		||||
# You may obtain a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
#     http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
# Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
# distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
# See the License for the specific language governing permissions and
 | 
			
		||||
# limitations under the License.
 | 
			
		||||
#
 | 
			
		||||
# Some parts of this file is adapted from
 | 
			
		||||
# https://huggingface.co/deepseek-ai/DeepSeek-R1/blob/main/modeling_deepseek.py
 | 
			
		||||
# which is licensed under Apache License 2.0:
 | 
			
		||||
#
 | 
			
		||||
# https://github.com/OpenBMB/MiniCPM/blob/main/LICENSE
 | 
			
		||||
#
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
import warnings
 | 
			
		||||
 | 
			
		||||
from typing import Optional, Tuple, List, Union
 | 
			
		||||
from transformers.cache_utils import Cache
 | 
			
		||||
from transformers.modeling_outputs import BaseModelOutputWithPast
 | 
			
		||||
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
 | 
			
		||||
 | 
			
		||||
from ipex_llm.utils.common.log4Error import invalidInputError
 | 
			
		||||
from ipex_llm.transformers.kv import DynamicNormalCache
 | 
			
		||||
from ipex_llm.transformers.models.common import padding_mla_v_hd_base
 | 
			
		||||
from ipex_llm.transformers.models.common import scaled_dot_product_attention
 | 
			
		||||
from ipex_llm.transformers.models.utils import rotate_half
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def padding_mla_v_hd(module: torch.nn.Module):
 | 
			
		||||
    padding_mla_v_hd_base(module, "DeepseekV3Attention")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def deepseek_model_forward(
 | 
			
		||||
    self,
 | 
			
		||||
    input_ids: torch.LongTensor = None,
 | 
			
		||||
    attention_mask: Optional[torch.Tensor] = None,
 | 
			
		||||
    position_ids: Optional[torch.LongTensor] = None,
 | 
			
		||||
    past_key_values: Optional[List[torch.FloatTensor]] = None,
 | 
			
		||||
    inputs_embeds: Optional[torch.FloatTensor] = None,
 | 
			
		||||
    use_cache: Optional[bool] = None,
 | 
			
		||||
    output_attentions: Optional[bool] = None,
 | 
			
		||||
    output_hidden_states: Optional[bool] = None,
 | 
			
		||||
    return_dict: Optional[bool] = None,
 | 
			
		||||
) -> Union[Tuple, BaseModelOutputWithPast]:
 | 
			
		||||
    output_attentions = (
 | 
			
		||||
        output_attentions if output_attentions is not None
 | 
			
		||||
        else self.config.output_attentions
 | 
			
		||||
    )
 | 
			
		||||
    output_hidden_states = (
 | 
			
		||||
        output_hidden_states if output_hidden_states is not None
 | 
			
		||||
        else self.config.output_hidden_states
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    use_cache = use_cache if use_cache is not None else self.config.use_cache
 | 
			
		||||
 | 
			
		||||
    return_dict = (
 | 
			
		||||
        return_dict if return_dict is not None else self.config.use_return_dict
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    # retrieve input_ids and inputs_embeds
 | 
			
		||||
    invalidInputError((input_ids is None) ^ (inputs_embeds is None),
 | 
			
		||||
                      "You cannot specify both input_ids and inputs_embeds at the same time, "
 | 
			
		||||
                      "and must specify either one")
 | 
			
		||||
 | 
			
		||||
    if inputs_embeds is None:
 | 
			
		||||
        inputs_embeds = self.embed_tokens(input_ids)
 | 
			
		||||
 | 
			
		||||
    batch_size, seq_length = inputs_embeds.shape[:2]
 | 
			
		||||
 | 
			
		||||
    # IPEX-LLM OPT start: kv cache
 | 
			
		||||
    past_key_values_length = 0
 | 
			
		||||
    use_cache = True if inputs_embeds.device.type == "xpu" else use_cache
 | 
			
		||||
    if use_cache:
 | 
			
		||||
        if not isinstance(past_key_values, DynamicNormalCache):
 | 
			
		||||
            past_key_values = DynamicNormalCache.from_legacy_cache(past_key_values)
 | 
			
		||||
        past_key_values_length = past_key_values.get_usable_length(seq_length)
 | 
			
		||||
    # IPEX-LLM OPT end: kv cache
 | 
			
		||||
 | 
			
		||||
    if position_ids is None:
 | 
			
		||||
        position_ids = torch.arange(
 | 
			
		||||
            past_key_values_length,
 | 
			
		||||
            seq_length + past_key_values_length,
 | 
			
		||||
            dtype=torch.long,
 | 
			
		||||
            device=inputs_embeds.device,
 | 
			
		||||
        )
 | 
			
		||||
        position_ids = position_ids.unsqueeze(0)
 | 
			
		||||
 | 
			
		||||
    # IPEX-LLM OPT start: fuse rope
 | 
			
		||||
    if inputs_embeds.device.type == "xpu" and position_ids is not None:
 | 
			
		||||
        cos, sin = self.layers[0].self_attn.rotary_emb(inputs_embeds,
 | 
			
		||||
                                                       seq_length + past_key_values_length)
 | 
			
		||||
        cos = cos[position_ids[0]].contiguous()
 | 
			
		||||
        sin = sin[position_ids[0]].contiguous()
 | 
			
		||||
        position_embeddings = (cos, sin)
 | 
			
		||||
    else:
 | 
			
		||||
        position_embeddings = None
 | 
			
		||||
    # IPEX-LLM OPT end: fuse rope
 | 
			
		||||
 | 
			
		||||
    # 4d mask is passed through the layers
 | 
			
		||||
    attention_mask = _prepare_4d_causal_attention_mask(
 | 
			
		||||
        attention_mask,
 | 
			
		||||
        (batch_size, seq_length),
 | 
			
		||||
        inputs_embeds,
 | 
			
		||||
        past_key_values_length,
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    # embed positions
 | 
			
		||||
    hidden_states = inputs_embeds
 | 
			
		||||
 | 
			
		||||
    # decoder layers
 | 
			
		||||
    all_hidden_states = () if output_hidden_states else None
 | 
			
		||||
    all_self_attns = () if output_attentions else None
 | 
			
		||||
    next_decoder_cache = None
 | 
			
		||||
 | 
			
		||||
    for decoder_layer in self.layers:
 | 
			
		||||
        if output_hidden_states:
 | 
			
		||||
            all_hidden_states += (hidden_states,)
 | 
			
		||||
 | 
			
		||||
        layer_outputs = decoder_layer(
 | 
			
		||||
            hidden_states,
 | 
			
		||||
            attention_mask=attention_mask,
 | 
			
		||||
            position_ids=position_ids,
 | 
			
		||||
            past_key_value=past_key_values,
 | 
			
		||||
            output_attentions=output_attentions,
 | 
			
		||||
            use_cache=use_cache,
 | 
			
		||||
            position_embeddings=position_embeddings,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        hidden_states = layer_outputs[0]
 | 
			
		||||
 | 
			
		||||
        if use_cache:
 | 
			
		||||
            next_decoder_cache = layer_outputs[2 if output_attentions else 1]
 | 
			
		||||
 | 
			
		||||
        if output_attentions:
 | 
			
		||||
            all_self_attns += (layer_outputs[1],)
 | 
			
		||||
 | 
			
		||||
    hidden_states = self.norm(hidden_states)
 | 
			
		||||
 | 
			
		||||
    # add hidden states from the last decoder layer
 | 
			
		||||
    if output_hidden_states:
 | 
			
		||||
        all_hidden_states += (hidden_states,)
 | 
			
		||||
 | 
			
		||||
    next_cache = next_decoder_cache
 | 
			
		||||
    if not return_dict:
 | 
			
		||||
        return tuple(
 | 
			
		||||
            v
 | 
			
		||||
            for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
 | 
			
		||||
            if v is not None
 | 
			
		||||
        )
 | 
			
		||||
    return BaseModelOutputWithPast(
 | 
			
		||||
        last_hidden_state=hidden_states,
 | 
			
		||||
        past_key_values=next_cache,
 | 
			
		||||
        hidden_states=all_hidden_states,
 | 
			
		||||
        attentions=all_self_attns,
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
 | 
			
		||||
    cos = cos[position_ids].unsqueeze(unsqueeze_dim)
 | 
			
		||||
    sin = sin[position_ids].unsqueeze(unsqueeze_dim)
 | 
			
		||||
 | 
			
		||||
    b, h, s, d = q.shape
 | 
			
		||||
    q = q.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)
 | 
			
		||||
 | 
			
		||||
    b, h, s, d = k.shape
 | 
			
		||||
    k = k.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)
 | 
			
		||||
 | 
			
		||||
    q_embed = (q * cos) + (rotate_half(q) * sin)
 | 
			
		||||
    k_embed = (k * cos) + (rotate_half(k) * sin)
 | 
			
		||||
    return q_embed, k_embed
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def deepseek_attention_forward(
 | 
			
		||||
    self,
 | 
			
		||||
    hidden_states: torch.Tensor,
 | 
			
		||||
    attention_mask: Optional[torch.Tensor] = None,
 | 
			
		||||
    position_ids: Optional[torch.LongTensor] = None,
 | 
			
		||||
    past_key_value: Optional[Cache] = None,
 | 
			
		||||
    output_attentions: bool = False,
 | 
			
		||||
    use_cache: bool = False,
 | 
			
		||||
    **kwargs,
 | 
			
		||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
 | 
			
		||||
    if "padding_mask" in kwargs:
 | 
			
		||||
        warnings.warn(
 | 
			
		||||
            "Passing `padding_mask` is deprecated and will be removed in v4.37. "
 | 
			
		||||
            "Please make sure use `attention_mask` instead.`"
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    bsz, q_len, _ = hidden_states.size()
 | 
			
		||||
 | 
			
		||||
    if self.q_lora_rank is None:
 | 
			
		||||
        q = self.q_proj(hidden_states)
 | 
			
		||||
    else:
 | 
			
		||||
        q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
 | 
			
		||||
    q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)
 | 
			
		||||
 | 
			
		||||
    compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
 | 
			
		||||
    compressed_kv, k_pe = torch.split(
 | 
			
		||||
        compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
 | 
			
		||||
    )
 | 
			
		||||
    k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2)
 | 
			
		||||
    kv = (
 | 
			
		||||
        self.kv_b_proj(self.kv_a_layernorm(compressed_kv))
 | 
			
		||||
        .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.q_head_dim)
 | 
			
		||||
        .transpose(1, 2)
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    k_nope, value_states = torch.split(
 | 
			
		||||
        kv, [self.qk_nope_head_dim, self.q_head_dim], dim=-1
 | 
			
		||||
    )
 | 
			
		||||
    kv_seq_len = value_states.shape[-2]
 | 
			
		||||
    if past_key_value is not None:
 | 
			
		||||
        kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
 | 
			
		||||
 | 
			
		||||
    position_embeddings = kwargs.get("position_embeddings", None)
 | 
			
		||||
    if position_embeddings is not None:
 | 
			
		||||
        query_states = q
 | 
			
		||||
        key_states = torch.cat(
 | 
			
		||||
            [k_nope, k_pe.expand([-1, self.num_heads, -1, -1])],
 | 
			
		||||
            dim=-1
 | 
			
		||||
        )
 | 
			
		||||
        import xe_addons
 | 
			
		||||
        cos, sin = position_embeddings
 | 
			
		||||
        xe_addons.rotary_two_with_cache_inplaced(query_states[:, :, :, self.qk_nope_head_dim:],
 | 
			
		||||
                                                 key_states[:, :, :, self.qk_nope_head_dim:],
 | 
			
		||||
                                                 cos, sin, True)
 | 
			
		||||
    else:
 | 
			
		||||
        q_nope, q_pe = torch.split(
 | 
			
		||||
            q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
 | 
			
		||||
        )
 | 
			
		||||
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
 | 
			
		||||
        q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids)
 | 
			
		||||
 | 
			
		||||
        query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
 | 
			
		||||
        query_states[:, :, :, : self.qk_nope_head_dim] = q_nope
 | 
			
		||||
        query_states[:, :, :, self.qk_nope_head_dim:] = q_pe
 | 
			
		||||
 | 
			
		||||
        key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
 | 
			
		||||
        key_states[:, :, :, : self.qk_nope_head_dim] = k_nope
 | 
			
		||||
        key_states[:, :, :, self.qk_nope_head_dim:] = k_pe
 | 
			
		||||
 | 
			
		||||
    if past_key_value is not None:
 | 
			
		||||
        key_states, value_states = past_key_value.update(key_states, value_states,
 | 
			
		||||
                                                         self.layer_idx, None)
 | 
			
		||||
 | 
			
		||||
    attn_weights = None
 | 
			
		||||
    attn_output = scaled_dot_product_attention(
 | 
			
		||||
        query_states, key_states, value_states,
 | 
			
		||||
        attention_mask, q_len == kv_seq_len, self.softmax_scale
 | 
			
		||||
    )
 | 
			
		||||
    attn_output = attn_output[:, :, :, :self.v_head_dim]
 | 
			
		||||
 | 
			
		||||
    attn_output = attn_output.transpose(1, 2).contiguous()
 | 
			
		||||
 | 
			
		||||
    attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim)
 | 
			
		||||
 | 
			
		||||
    attn_output = self.o_proj(attn_output)
 | 
			
		||||
 | 
			
		||||
    if not output_attentions:
 | 
			
		||||
        attn_weights = None
 | 
			
		||||
 | 
			
		||||
    return attn_output, attn_weights, past_key_value
 | 
			
		||||
| 
						 | 
				
			
			@ -1,3 +1,25 @@
 | 
			
		|||
#
 | 
			
		||||
# Copyright 2016 The BigDL Authors.
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
# you may not use this file except in compliance with the License.
 | 
			
		||||
# You may obtain a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
#     http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
# Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
# distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
# See the License for the specific language governing permissions and
 | 
			
		||||
# limitations under the License.
 | 
			
		||||
#
 | 
			
		||||
# Some parts of this file is adapted from
 | 
			
		||||
# https://hf-mirror.com/openbmb/MiniCPM3-4B/blob/main/modeling_minicpm.py
 | 
			
		||||
# which is licensed under Apache License 2.0:
 | 
			
		||||
#
 | 
			
		||||
# https://github.com/OpenBMB/MiniCPM/blob/main/LICENSE
 | 
			
		||||
#
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
import warnings
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -122,9 +144,6 @@ def minicpm3_attention_forward(
 | 
			
		|||
 | 
			
		||||
    q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
 | 
			
		||||
    q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)
 | 
			
		||||
    q_nope, q_pe = torch.split(
 | 
			
		||||
        q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
 | 
			
		||||
    compressed_kv, k_pe = torch.split(
 | 
			
		||||
| 
						 | 
				
			
			@ -169,6 +188,9 @@ def minicpm3_attention_forward(
 | 
			
		|||
        else:
 | 
			
		||||
            invalidInputError(f"unknown rope method: {self.rotary_emb.__class__.__name__}")
 | 
			
		||||
    else:
 | 
			
		||||
        q_nope, q_pe = torch.split(
 | 
			
		||||
            q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
 | 
			
		||||
        )
 | 
			
		||||
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
 | 
			
		||||
        q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids)
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in a new issue