refactor to reduce old rope usage (#12219)
This commit is contained in:
		
							parent
							
								
									667f0db466
								
							
						
					
					
						commit
						324bcb057e
					
				
					 4 changed files with 31 additions and 161 deletions
				
			
		| 
						 | 
					@ -36,22 +36,14 @@
 | 
				
			||||||
# limitations under the License.
 | 
					# limitations under the License.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import math
 | 
					import math
 | 
				
			||||||
from typing import List, Optional, Tuple, Union
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
import torch
 | 
					import torch
 | 
				
			||||||
import torch.utils.checkpoint
 | 
					 | 
				
			||||||
from torch import nn
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
from ipex_llm.transformers.models.utils import extend_kv_cache, init_kv_cache, \
 | 
					from typing import Optional, Tuple
 | 
				
			||||||
    append_kv_cache, is_enough_kv_cache_room_4_31
 | 
					from ipex_llm.transformers.models.common import attention_softmax
 | 
				
			||||||
from ipex_llm.transformers.models.utils import apply_rotary_pos_emb
 | 
					from ipex_llm.transformers.models.utils import apply_rotary_pos_emb, should_use_fuse_rope
 | 
				
			||||||
from ipex_llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
 | 
					from ipex_llm.transformers.models.utils import update_past_key_value
 | 
				
			||||||
from ipex_llm.utils.common import log4Error
 | 
					from ipex_llm.utils.common import log4Error
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import os
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
KV_CACHE_ALLOC_BLOCK_LENGTH = int(os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256))
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
def aquila_attention_forward(
 | 
					def aquila_attention_forward(
 | 
				
			||||||
    self,
 | 
					    self,
 | 
				
			||||||
| 
						 | 
					@ -75,58 +67,27 @@ def aquila_attention_forward(
 | 
				
			||||||
        .transpose(1, 2)
 | 
					        .transpose(1, 2)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    kv_seq_len = key_states.shape[-2]
 | 
					    kv_seq_len = key_states.shape[-2]
 | 
				
			||||||
    enough_kv_room = True
 | 
					 | 
				
			||||||
    if past_key_value is not None:
 | 
					    if past_key_value is not None:
 | 
				
			||||||
        enough_kv_room = is_enough_kv_cache_room_4_31(past_key_value, seq_len=kv_seq_len)
 | 
					 | 
				
			||||||
        kv_seq_len += past_key_value[0].shape[-2]
 | 
					        kv_seq_len += past_key_value[0].shape[-2]
 | 
				
			||||||
    if query_states.device.type == "xpu" and not (self.training and query_states.requires_grad):
 | 
					
 | 
				
			||||||
        query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
 | 
					    if should_use_fuse_rope(hidden_states, position_ids, self.training):
 | 
				
			||||||
                                                                     key_states,
 | 
					        import xe_addons
 | 
				
			||||||
                                                                     position_ids,
 | 
					        xe_addons.rotary_half_inplaced(self.rotary_emb.inv_freq, position_ids,
 | 
				
			||||||
                                                                     "aquila")
 | 
					                                       query_states, key_states)
 | 
				
			||||||
    else:
 | 
					    else:
 | 
				
			||||||
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
 | 
					        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
 | 
				
			||||||
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
 | 
					        query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
 | 
				
			||||||
                                                        cos, sin, position_ids, "aquila")
 | 
					                                                        cos, sin, position_ids, "aquila")
 | 
				
			||||||
    # [bsz, nh, t, hd]
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    if past_key_value is not None:
 | 
					    key_states, value_states = update_past_key_value(
 | 
				
			||||||
        # reuse k, v, self_attention
 | 
					        past_key_value, key_states, value_states,
 | 
				
			||||||
        cache_k = past_key_value[0]
 | 
					        kv_seq_len, False, hidden_states.device
 | 
				
			||||||
        cache_v = past_key_value[1]
 | 
					    )
 | 
				
			||||||
        if not enough_kv_room:
 | 
					 | 
				
			||||||
            # allocate new
 | 
					 | 
				
			||||||
            new_cache_k, new_cache_v = extend_kv_cache(bsz,
 | 
					 | 
				
			||||||
                                                       self.num_heads,  # Support GQA
 | 
					 | 
				
			||||||
                                                       self.head_dim,
 | 
					 | 
				
			||||||
                                                       cache_k.size(2),
 | 
					 | 
				
			||||||
                                                       kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH,
 | 
					 | 
				
			||||||
                                                       dtype=cache_k.dtype,
 | 
					 | 
				
			||||||
                                                       device=hidden_states.device)
 | 
					 | 
				
			||||||
            new_cache_k[:] = cache_k
 | 
					 | 
				
			||||||
            new_cache_v[:] = cache_v
 | 
					 | 
				
			||||||
            cache_k = new_cache_k
 | 
					 | 
				
			||||||
            cache_v = new_cache_v
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        key_states, value_states = append_kv_cache(cache_k, cache_v, key_states, value_states)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    elif use_cache:
 | 
					 | 
				
			||||||
        max_cache_length = kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH
 | 
					 | 
				
			||||||
        new_key_states, new_value_states = init_kv_cache(bsz,
 | 
					 | 
				
			||||||
                                                         self.num_heads,
 | 
					 | 
				
			||||||
                                                         self.head_dim,
 | 
					 | 
				
			||||||
                                                         kv_seq_len,
 | 
					 | 
				
			||||||
                                                         max_cache_length,
 | 
					 | 
				
			||||||
                                                         dtype=key_states.dtype,
 | 
					 | 
				
			||||||
                                                         device=hidden_states.device)
 | 
					 | 
				
			||||||
        new_key_states[:] = key_states
 | 
					 | 
				
			||||||
        new_value_states[:] = value_states
 | 
					 | 
				
			||||||
        key_states = new_key_states
 | 
					 | 
				
			||||||
        value_states = new_value_states
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    past_key_value = (key_states, value_states) if use_cache else None
 | 
					    past_key_value = (key_states, value_states) if use_cache else None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
 | 
					    attn_weights = torch.matmul(query_states,
 | 
				
			||||||
 | 
					                                key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    attn_weights = torch.clamp(attn_weights, min=-1024., max=1024.)
 | 
					    attn_weights = torch.clamp(attn_weights, min=-1024., max=1024.)
 | 
				
			||||||
    if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
 | 
					    if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
 | 
				
			||||||
| 
						 | 
					@ -148,8 +109,7 @@ def aquila_attention_forward(
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # upcast attention to fp32
 | 
					    # upcast attention to fp32
 | 
				
			||||||
    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32)\
 | 
					    attn_weights = attention_softmax(attn_weights, self.training)
 | 
				
			||||||
        .to(query_states.dtype)
 | 
					 | 
				
			||||||
    attn_output = torch.matmul(attn_weights, value_states)
 | 
					    attn_output = torch.matmul(attn_weights, value_states)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
 | 
					    if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -34,11 +34,11 @@
 | 
				
			||||||
import torch
 | 
					import torch
 | 
				
			||||||
from typing import Optional, Tuple
 | 
					from typing import Optional, Tuple
 | 
				
			||||||
import torch.nn.functional as F
 | 
					import torch.nn.functional as F
 | 
				
			||||||
from ipex_llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
 | 
					from ipex_llm.transformers.models.utils import apply_rotary_pos_emb
 | 
				
			||||||
from ipex_llm.transformers.models.utils import is_enough_kv_cache_room_4_31, \
 | 
					 | 
				
			||||||
    apply_rotary_pos_emb
 | 
					 | 
				
			||||||
from ipex_llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
 | 
					from ipex_llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
 | 
				
			||||||
from ipex_llm.transformers.models.llama import should_use_fuse_rope, repeat_kv
 | 
					from ipex_llm.transformers.models.llama import repeat_kv
 | 
				
			||||||
 | 
					from ipex_llm.transformers.models.utils import should_use_fuse_rope
 | 
				
			||||||
 | 
					from ipex_llm.transformers.models.utils import update_past_key_value
 | 
				
			||||||
from ipex_llm.utils.common import invalidInputError
 | 
					from ipex_llm.utils.common import invalidInputError
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import os
 | 
					import os
 | 
				
			||||||
| 
						 | 
					@ -61,32 +61,9 @@ def decilm_attention_forward_4_35_2(
 | 
				
			||||||
    is_decode = past_key_value is not None
 | 
					    is_decode = past_key_value is not None
 | 
				
			||||||
    device = hidden_states.device
 | 
					    device = hidden_states.device
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
 | 
					    query_states = self.q_proj(hidden_states)
 | 
				
			||||||
    enough_kv_room = is_enough_kv_cache_room_4_31(past_key_value, seq_len=q_len)
 | 
					    key_states = self.k_proj(hidden_states)
 | 
				
			||||||
 | 
					    value_states = self.v_proj(hidden_states)
 | 
				
			||||||
    if self.config.pretraining_tp > 1:
 | 
					 | 
				
			||||||
        key_value_slicing = ((self.num_key_value_heads * self.head_dim) //
 | 
					 | 
				
			||||||
                             self.config.pretraining_tp)
 | 
					 | 
				
			||||||
        query_slices = self.q_proj.weight.split((self.num_heads * self.head_dim)
 | 
					 | 
				
			||||||
                                                // self.config.pretraining_tp, dim=0)
 | 
					 | 
				
			||||||
        key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
 | 
					 | 
				
			||||||
        value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        query_states = [F.linear(hidden_states, query_slices[i])
 | 
					 | 
				
			||||||
                        for i in range(self.config.pretraining_tp)]
 | 
					 | 
				
			||||||
        query_states = torch.cat(query_states, dim=-1)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        key_states = [F.linear(hidden_states, key_slices[i])
 | 
					 | 
				
			||||||
                      for i in range(self.config.pretraining_tp)]
 | 
					 | 
				
			||||||
        key_states = torch.cat(key_states, dim=-1)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        value_states = [F.linear(hidden_states, value_slices[i])
 | 
					 | 
				
			||||||
                        for i in range(self.config.pretraining_tp)]
 | 
					 | 
				
			||||||
        value_states = torch.cat(value_states, dim=-1)
 | 
					 | 
				
			||||||
    else:
 | 
					 | 
				
			||||||
        query_states = self.q_proj(hidden_states)
 | 
					 | 
				
			||||||
        key_states = self.k_proj(hidden_states)
 | 
					 | 
				
			||||||
        value_states = self.v_proj(hidden_states)
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    query_states = query_states.view(bsz, q_len,
 | 
					    query_states = query_states.view(bsz, q_len,
 | 
				
			||||||
                                     self.num_heads, self.head_dim).transpose(1, 2)
 | 
					                                     self.num_heads, self.head_dim).transpose(1, 2)
 | 
				
			||||||
| 
						 | 
					@ -99,7 +76,7 @@ def decilm_attention_forward_4_35_2(
 | 
				
			||||||
    if past_key_value is not None:
 | 
					    if past_key_value is not None:
 | 
				
			||||||
        kv_seq_len += past_key_value[0].shape[-2]
 | 
					        kv_seq_len += past_key_value[0].shape[-2]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    if use_fuse_rope:
 | 
					    if should_use_fuse_rope(hidden_states, position_ids, self.training):
 | 
				
			||||||
        query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
 | 
					        query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
 | 
				
			||||||
                                                                     key_states,
 | 
					                                                                     key_states,
 | 
				
			||||||
                                                                     position_ids,
 | 
					                                                                     position_ids,
 | 
				
			||||||
| 
						 | 
					@ -109,39 +86,10 @@ def decilm_attention_forward_4_35_2(
 | 
				
			||||||
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
 | 
					        query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
 | 
				
			||||||
                                                        cos, sin, position_ids, "llama")
 | 
					                                                        cos, sin, position_ids, "llama")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    if past_key_value is not None:
 | 
					    key_states, value_states = update_past_key_value(
 | 
				
			||||||
        # reuse k, v, self_attention
 | 
					        past_key_value, key_states, value_states,
 | 
				
			||||||
        cache_k = past_key_value[0]
 | 
					        kv_seq_len, False, device
 | 
				
			||||||
        cache_v = past_key_value[1]
 | 
					    )
 | 
				
			||||||
        if not enough_kv_room:
 | 
					 | 
				
			||||||
            # allocate new
 | 
					 | 
				
			||||||
            new_cache_k, new_cache_v = extend_kv_cache(bsz,
 | 
					 | 
				
			||||||
                                                       self.num_key_value_heads,  # Support GQA
 | 
					 | 
				
			||||||
                                                       self.head_dim,
 | 
					 | 
				
			||||||
                                                       cache_k.size(2),
 | 
					 | 
				
			||||||
                                                       kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH,
 | 
					 | 
				
			||||||
                                                       dtype=cache_k.dtype,
 | 
					 | 
				
			||||||
                                                       device=device)
 | 
					 | 
				
			||||||
            new_cache_k[:] = cache_k
 | 
					 | 
				
			||||||
            new_cache_v[:] = cache_v
 | 
					 | 
				
			||||||
            cache_k = new_cache_k
 | 
					 | 
				
			||||||
            cache_v = new_cache_v
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        key_states, value_states = append_kv_cache(cache_k, cache_v, key_states, value_states)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    elif use_cache:
 | 
					 | 
				
			||||||
        max_cache_length = kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH
 | 
					 | 
				
			||||||
        new_key_states, new_value_states = init_kv_cache(bsz,
 | 
					 | 
				
			||||||
                                                         self.num_key_value_heads,
 | 
					 | 
				
			||||||
                                                         self.head_dim,
 | 
					 | 
				
			||||||
                                                         kv_seq_len,
 | 
					 | 
				
			||||||
                                                         max_cache_length,
 | 
					 | 
				
			||||||
                                                         dtype=key_states.dtype,
 | 
					 | 
				
			||||||
                                                         device=device)
 | 
					 | 
				
			||||||
        new_key_states[:] = key_states
 | 
					 | 
				
			||||||
        new_value_states[:] = value_states
 | 
					 | 
				
			||||||
        key_states = new_key_states
 | 
					 | 
				
			||||||
        value_states = new_value_states
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    past_key_value = (key_states, value_states) if use_cache else None
 | 
					    past_key_value = (key_states, value_states) if use_cache else None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -167,14 +115,8 @@ def decilm_attention_forward_4_35_2(
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        attn_output = attn_output.transpose(1, 2).contiguous().view(bsz, q_len, self.hidden_size)
 | 
					        attn_output = attn_output.transpose(1, 2).contiguous().view(bsz, q_len, self.hidden_size)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    if self.config.pretraining_tp > 1:
 | 
					    attn_output = attn_output.to(hidden_states.dtype)
 | 
				
			||||||
        attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
 | 
					    attn_output = self.o_proj(attn_output)
 | 
				
			||||||
        o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp,
 | 
					 | 
				
			||||||
                                                 dim=1)
 | 
					 | 
				
			||||||
        attn_output = sum([F.linear(attn_output[i], o_proj_slices[i])
 | 
					 | 
				
			||||||
                           for i in range(self.config.pretraining_tp)])
 | 
					 | 
				
			||||||
    else:
 | 
					 | 
				
			||||||
        attn_output = self.o_proj(attn_output)
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    if not output_attentions:
 | 
					    if not output_attentions:
 | 
				
			||||||
        attn_weights = None
 | 
					        attn_weights = None
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -46,16 +46,14 @@ from ipex_llm.transformers.models.utils import init_fp8_kv_cache, append_fp8_kv_
 | 
				
			||||||
    get_compresskv_attn_mask
 | 
					    get_compresskv_attn_mask
 | 
				
			||||||
from ipex_llm.transformers.models.utils import is_enough_kv_cache_room_4_31, \
 | 
					from ipex_llm.transformers.models.utils import is_enough_kv_cache_room_4_31, \
 | 
				
			||||||
    apply_rotary_pos_emb, is_enough_kv_cache_room_4_36
 | 
					    apply_rotary_pos_emb, is_enough_kv_cache_room_4_36
 | 
				
			||||||
from ipex_llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
 | 
					 | 
				
			||||||
from ipex_llm.transformers.models.utils import use_flash_attention, use_sdp, use_sdp_causal
 | 
					from ipex_llm.transformers.models.utils import use_flash_attention, use_sdp, use_sdp_causal
 | 
				
			||||||
from ipex_llm.transformers.models.utils import mlp_fusion_check, fp16_fusion_check
 | 
					from ipex_llm.transformers.models.utils import mlp_fusion_check, fp16_fusion_check
 | 
				
			||||||
from ipex_llm.transformers.models.utils import use_decoding_fast_path, get_q_proj_or_qkv_proj
 | 
					from ipex_llm.transformers.models.utils import use_decoding_fast_path, get_q_proj_or_qkv_proj
 | 
				
			||||||
from transformers.modeling_outputs import BaseModelOutputWithPast
 | 
					from transformers.modeling_outputs import BaseModelOutputWithPast
 | 
				
			||||||
from transformers.models.llama.modeling_llama import LlamaModel, LlamaAttention
 | 
					from transformers.models.llama.modeling_llama import LlamaModel, LlamaAttention
 | 
				
			||||||
from ipex_llm.transformers.low_bit_linear import SYM_INT4, FP8E5, IQ2_XXS, FP4
 | 
					from ipex_llm.transformers.low_bit_linear import SYM_INT4, FP8E5, IQ2_XXS, FP4
 | 
				
			||||||
from ipex_llm.ggml.quantize import ggml_tensor_qtype
 | 
					 | 
				
			||||||
from ipex_llm.utils.common import invalidInputError
 | 
					from ipex_llm.utils.common import invalidInputError
 | 
				
			||||||
from ipex_llm.transformers.models.common import merge_qkv_base, fuse_mlp_base
 | 
					from ipex_llm.transformers.models.common import merge_qkv_base
 | 
				
			||||||
 | 
					
 | 
				
			||||||
try:
 | 
					try:
 | 
				
			||||||
    from transformers.cache_utils import Cache, DynamicCache
 | 
					    from transformers.cache_utils import Cache, DynamicCache
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -37,38 +37,8 @@
 | 
				
			||||||
# limitations under the License.
 | 
					# limitations under the License.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
""" PyTorch Phixtral model."""
 | 
					""" PyTorch Phixtral model."""
 | 
				
			||||||
import math
 | 
					 | 
				
			||||||
from typing import Optional, Tuple
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
import torch
 | 
					import torch
 | 
				
			||||||
from torch import nn
 | 
					 | 
				
			||||||
import torch.nn.functional as F
 | 
					import torch.nn.functional as F
 | 
				
			||||||
from ipex_llm.ggml.quantize import ggml_tensor_qtype
 | 
					 | 
				
			||||||
from ipex_llm.utils.common import invalidInputError
 | 
					 | 
				
			||||||
from ipex_llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
 | 
					 | 
				
			||||||
from ipex_llm.transformers.models.utils import apply_rotary_pos_emb,\
 | 
					 | 
				
			||||||
    apply_rotary_pos_emb_no_cache_xpu, is_enough_kv_cache_room_4_36
 | 
					 | 
				
			||||||
from ipex_llm.transformers.models.mistral import should_use_fuse_rope
 | 
					 | 
				
			||||||
from ipex_llm.transformers.models.utils import use_flash_attention
 | 
					 | 
				
			||||||
from ipex_llm.transformers.models.utils import mlp_fusion_check
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
import os
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
KV_CACHE_ALLOC_BLOCK_LENGTH = int(os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256))
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep).
 | 
					 | 
				
			||||||
    The hidden states go from (batch, num_key_value_heads, seqlen, head_dim)
 | 
					 | 
				
			||||||
    to (batch, num_attention_heads, seqlen, head_dim)
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
 | 
					 | 
				
			||||||
    if n_rep == 1:
 | 
					 | 
				
			||||||
        return hidden_states
 | 
					 | 
				
			||||||
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads,
 | 
					 | 
				
			||||||
                                                           n_rep, slen, head_dim)
 | 
					 | 
				
			||||||
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def phixtral_moeblock_forward(self, hidden_states: torch.Tensor):
 | 
					def phixtral_moeblock_forward(self, hidden_states: torch.Tensor):
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in a new issue