refactor yuan2 (#11235)
This commit is contained in:
parent
6be24fdd28
commit
ba27e750b1
2 changed files with 72 additions and 303 deletions
|
|
@ -682,39 +682,8 @@ def _optimize_pre(model):
|
||||||
model.lm_head.weight.data = norm_weight
|
model.lm_head.weight.data = norm_weight
|
||||||
# for yuan 2.0
|
# for yuan 2.0
|
||||||
if model.config.model_type == "yuan":
|
if model.config.model_type == "yuan":
|
||||||
def merge_qk_proj_func(module):
|
from ipex_llm.transformers.models.yuan import merge_qk
|
||||||
if "YuanAttention" in module.__class__.__name__:
|
model.apply(merge_qk)
|
||||||
q_weight = module.q_proj.weight.data
|
|
||||||
k_weight = module.k_proj.weight.data
|
|
||||||
num_heads = module.num_heads
|
|
||||||
head_dim = module.head_dim
|
|
||||||
hidden_size = module.hidden_size
|
|
||||||
|
|
||||||
weight_q = torch.cat([
|
|
||||||
q_weight.view(num_heads, head_dim, hidden_size)[0::2, :, :],
|
|
||||||
k_weight.view(num_heads, head_dim, hidden_size)[0::2, :, :],
|
|
||||||
], dim=0).view(num_heads * head_dim, hidden_size)
|
|
||||||
|
|
||||||
weight_k = torch.cat([
|
|
||||||
q_weight.view(num_heads, head_dim, hidden_size)[1::2, :, :],
|
|
||||||
k_weight.view(num_heads, head_dim, hidden_size)[1::2, :, :],
|
|
||||||
], dim=0).view(num_heads * head_dim, hidden_size)
|
|
||||||
|
|
||||||
merged_q_proj = torch.nn.Linear(0, 0, False)
|
|
||||||
merged_q_proj.weight = torch.nn.Parameter(weight_q, requires_grad=False)
|
|
||||||
merged_q_proj.in_features = hidden_size
|
|
||||||
merged_q_proj.out_features = num_heads * head_dim
|
|
||||||
module.merged_q_proj = merged_q_proj
|
|
||||||
|
|
||||||
merged_k_proj = torch.nn.Linear(0, 0, False)
|
|
||||||
merged_k_proj.weight = torch.nn.Parameter(weight_k, requires_grad=False)
|
|
||||||
merged_k_proj.in_features = hidden_size
|
|
||||||
merged_k_proj.out_features = num_heads * head_dim
|
|
||||||
module.merged_k_proj = merged_k_proj
|
|
||||||
|
|
||||||
del module.q_proj
|
|
||||||
del module.k_proj
|
|
||||||
model.apply(merge_qk_proj_func)
|
|
||||||
# for bge-large
|
# for bge-large
|
||||||
if model.config.model_type == 'bert' and (
|
if model.config.model_type == 'bert' and (
|
||||||
not model.config.is_decoder and
|
not model.config.is_decoder and
|
||||||
|
|
|
||||||
|
|
@ -20,32 +20,41 @@
|
||||||
# https://huggingface.co/IEITYuan/Yuan2-2B-hf/blob/7ab7b3c18eb8e5232ce2a3f720d4e6f4b53a2806/README.md#%E5%A3%B0%E6%98%8E%E4%B8%8E%E5%8D%8F%E8%AE%AEterms-and-conditions
|
# https://huggingface.co/IEITYuan/Yuan2-2B-hf/blob/7ab7b3c18eb8e5232ce2a3f720d4e6f4b53a2806/README.md#%E5%A3%B0%E6%98%8E%E4%B8%8E%E5%8D%8F%E8%AE%AEterms-and-conditions
|
||||||
#
|
#
|
||||||
|
|
||||||
import copy
|
|
||||||
import math
|
import math
|
||||||
from einops import rearrange
|
|
||||||
from typing import Optional, Tuple
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
|
||||||
|
|
||||||
from ipex_llm.utils.common import invalidInputError
|
from ipex_llm.utils.common import invalidInputError
|
||||||
from ipex_llm.transformers.models.utils import apply_rotary_pos_emb, \
|
from ipex_llm.transformers.models.utils import apply_rotary_pos_emb, \
|
||||||
apply_rotary_pos_emb_cache_freq_xpu, mlp_fusion_check, fp16_fusion_check
|
mlp_fusion_check, fp16_fusion_check
|
||||||
from ipex_llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
|
from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache
|
||||||
from ipex_llm.transformers.models.utils import init_fp8_kv_cache, append_fp8_kv_cache, \
|
from ipex_llm.transformers.models.utils import SILU, update_past_key_value
|
||||||
restore_fp8_kv_cache, use_quantize_kv_cache
|
from ipex_llm.transformers.models.utils import should_use_fuse_rope, use_sdp, use_sdp_causal
|
||||||
from ipex_llm.transformers.models.utils import is_enough_kv_cache_room_4_31, SILU
|
|
||||||
|
|
||||||
import os
|
|
||||||
|
|
||||||
KV_CACHE_ALLOC_BLOCK_LENGTH = int(os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256))
|
|
||||||
|
|
||||||
|
|
||||||
def should_use_fuse_rope(self, hidden_states, position_ids):
|
def merge_qk(module: torch.nn.Module):
|
||||||
use_fuse_rope = hidden_states.device.type == "xpu"
|
if "YuanAttention" in module.__class__.__name__:
|
||||||
use_fuse_rope = use_fuse_rope and not (self.training and hidden_states.requires_grad)
|
q_weight = module.q_proj.weight.data
|
||||||
use_fuse_rope = use_fuse_rope and position_ids is not None
|
k_weight = module.k_proj.weight.data
|
||||||
return use_fuse_rope
|
num_heads = module.num_heads
|
||||||
|
head_dim = module.head_dim
|
||||||
|
hidden_size = module.hidden_size
|
||||||
|
|
||||||
|
merged_qk_proj = torch.nn.Linear(0, 0, False)
|
||||||
|
weight = torch.cat([
|
||||||
|
q_weight.view(num_heads, head_dim, hidden_size)[0::2, :, :],
|
||||||
|
k_weight.view(num_heads, head_dim, hidden_size)[0::2, :, :],
|
||||||
|
q_weight.view(num_heads, head_dim, hidden_size)[1::2, :, :],
|
||||||
|
k_weight.view(num_heads, head_dim, hidden_size)[1::2, :, :],
|
||||||
|
], dim=0).view(num_heads * head_dim * 2, hidden_size)
|
||||||
|
merged_qk_proj.weight = torch.nn.Parameter(weight, requires_grad=False)
|
||||||
|
merged_qk_proj.in_features = hidden_size
|
||||||
|
merged_qk_proj.out_features = num_heads * head_dim * 2
|
||||||
|
module.qk_proj = merged_qk_proj
|
||||||
|
|
||||||
|
del module.q_proj
|
||||||
|
del module.k_proj
|
||||||
|
|
||||||
|
|
||||||
def yuan_localized_filtering_forward(
|
def yuan_localized_filtering_forward(
|
||||||
|
|
@ -142,43 +151,14 @@ def yuan_attention_forward(
|
||||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||||
output_attentions: bool = False,
|
output_attentions: bool = False,
|
||||||
use_cache: bool = False,
|
use_cache: bool = False,
|
||||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
|
||||||
if use_quantize_kv_cache(self.merged_q_proj, hidden_states):
|
|
||||||
forward_function = yuan_attention_forward_quantized
|
|
||||||
else:
|
|
||||||
forward_function = yuan_attention_forward_origin
|
|
||||||
return forward_function(
|
|
||||||
self=self,
|
|
||||||
hidden_states=hidden_states,
|
|
||||||
attention_mask=attention_mask,
|
|
||||||
position_ids=position_ids,
|
|
||||||
past_key_value=past_key_value,
|
|
||||||
output_attentions=output_attentions,
|
|
||||||
use_cache=use_cache,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def yuan_attention_forward_quantized(
|
|
||||||
self,
|
|
||||||
hidden_states: torch.Tensor,
|
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
|
||||||
position_ids: Optional[torch.LongTensor] = None,
|
|
||||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
|
||||||
output_attentions: bool = False,
|
|
||||||
use_cache: bool = False,
|
|
||||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||||
bsz, q_len, _ = hidden_states.size()
|
bsz, q_len, _ = hidden_states.size()
|
||||||
device = hidden_states.device
|
device = hidden_states.device
|
||||||
before_hidden_states = None
|
|
||||||
is_first_step = False
|
|
||||||
|
|
||||||
use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
|
|
||||||
|
|
||||||
invalidInputError(use_cache, "use_cache=True is needed")
|
invalidInputError(use_cache, "use_cache=True is needed")
|
||||||
invalidInputError(not self.use_shareqk, "use_shareqk is not supported for now")
|
invalidInputError(not self.use_shareqk, "use_shareqk is not supported for now")
|
||||||
|
|
||||||
if past_key_value is None:
|
if past_key_value is None:
|
||||||
is_first_step = True
|
|
||||||
if q_len >= 2:
|
if q_len >= 2:
|
||||||
before_hidden_states = hidden_states[:, -2:, :].transpose(0, 1).half()
|
before_hidden_states = hidden_states[:, -2:, :].transpose(0, 1).half()
|
||||||
else:
|
else:
|
||||||
|
|
@ -193,112 +173,75 @@ def yuan_attention_forward_quantized(
|
||||||
], dim=0)
|
], dim=0)
|
||||||
before_hidden_states = this_hidden_states[-2:, :, ]
|
before_hidden_states = this_hidden_states[-2:, :, ]
|
||||||
|
|
||||||
value_states = \
|
value_states = self.v_proj(hidden_states)
|
||||||
self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
value_states = value_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||||
|
|
||||||
if is_first_step:
|
if past_key_value is None:
|
||||||
hidden_states = yuan_localized_filtering_forward(self.lf_gate, hidden_states,
|
hidden_states = yuan_localized_filtering_forward(self.lf_gate, hidden_states,
|
||||||
None, hidden_states.dtype)
|
None, hidden_states.dtype)
|
||||||
else:
|
else:
|
||||||
hidden_states = yuan_localized_filtering_forward(self.lf_gate, hidden_states,
|
hidden_states = yuan_localized_filtering_forward(self.lf_gate, hidden_states,
|
||||||
this_hidden_states, hidden_states.dtype)
|
this_hidden_states, hidden_states.dtype)
|
||||||
query_states = self.merged_q_proj(hidden_states)
|
|
||||||
key_states = self.merged_k_proj(hidden_states)
|
qk_states = self.qk_proj(hidden_states)
|
||||||
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
qk_states = qk_states.view(bsz, q_len, self.num_heads * 2, self.head_dim)
|
||||||
key_states = key_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
qk_states = qk_states.transpose(1, 2)
|
||||||
|
query_states, key_states = torch.chunk(qk_states, 2, dim=1)
|
||||||
|
|
||||||
kv_seq_len = key_states.shape[-2]
|
kv_seq_len = key_states.shape[-2]
|
||||||
if past_key_value is not None:
|
if past_key_value is not None:
|
||||||
kv_seq_len += past_key_value[0].shape[-2]
|
kv_seq_len += past_key_value[0].shape[-2]
|
||||||
|
|
||||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
if should_use_fuse_rope(hidden_states, position_ids, self.training):
|
||||||
if use_fuse_rope:
|
import xe_addons
|
||||||
query_states, key_states = apply_rotary_pos_emb_cache_freq_xpu(query_states,
|
xe_addons.rotary_half_inplaced(self.rotary_emb.inv_freq, position_ids,
|
||||||
key_states,
|
query_states, key_states)
|
||||||
sin, cos,
|
|
||||||
"yuan",
|
|
||||||
position_ids)
|
|
||||||
else:
|
else:
|
||||||
|
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||||
query_states, key_states = apply_rotary_pos_emb(query_states,
|
query_states, key_states = apply_rotary_pos_emb(query_states,
|
||||||
key_states,
|
key_states,
|
||||||
cos, sin,
|
cos, sin,
|
||||||
position_ids,
|
position_ids,
|
||||||
"yuan")
|
"yuan")
|
||||||
|
|
||||||
if past_key_value is None:
|
# IPEX-LLM OPT: kv cache and quantzie kv cache
|
||||||
# should use origin attn here
|
use_quantize_kv = use_quantize_kv_cache(self.qk_proj, hidden_states)
|
||||||
attn_weights = torch.matmul(query_states,
|
key_states, value_states = update_past_key_value(
|
||||||
key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
|
None if past_key_value is None else (past_key_value[0], past_key_value[1]),
|
||||||
|
key_states, value_states,
|
||||||
invalidInputError(attn_weights.size() == (bsz, self.num_heads, q_len, kv_seq_len),
|
kv_seq_len, use_quantize_kv, device
|
||||||
"Attention weights should be of size "
|
)
|
||||||
f"{(bsz, self.num_heads, q_len, kv_seq_len)}, "
|
past_key_value = (key_states, value_states, before_hidden_states) if use_cache else None
|
||||||
f"but is {attn_weights.size()}")
|
|
||||||
|
|
||||||
if attention_mask is not None:
|
|
||||||
invalidInputError(attention_mask.size() == (bsz, 1, q_len, kv_seq_len),
|
|
||||||
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, "
|
|
||||||
f"but is {attention_mask.size()}")
|
|
||||||
attn_weights = attn_weights + attention_mask
|
|
||||||
attn_weights = torch.max(attn_weights,
|
|
||||||
torch.tensor(torch.finfo(attn_weights.dtype).min))
|
|
||||||
|
|
||||||
# upcast attention to fp32
|
|
||||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1,
|
|
||||||
dtype=torch.float32).to(query_states.dtype)
|
|
||||||
attn_output = torch.matmul(attn_weights, value_states)
|
|
||||||
|
|
||||||
if use_cache:
|
|
||||||
k_cache, v_cache = init_fp8_kv_cache(
|
|
||||||
bsz, self.num_heads, kv_seq_len, self.head_dim, device=device
|
|
||||||
)
|
|
||||||
key_states, value_states = append_fp8_kv_cache(k_cache, v_cache,
|
|
||||||
key_states, value_states)
|
|
||||||
past_key_value = (key_states, value_states, before_hidden_states)
|
|
||||||
|
|
||||||
|
# IPEX-LLM OPT: sdp
|
||||||
|
if use_sdp(q_len, kv_seq_len, self.head_dim, query_states):
|
||||||
|
import xe_addons
|
||||||
|
if use_quantize_kv:
|
||||||
|
attn_output = xe_addons.sdp_fp8(query_states, key_states, value_states,
|
||||||
|
attention_mask)
|
||||||
|
else:
|
||||||
|
attn_output = xe_addons.sdp(query_states, key_states, value_states,
|
||||||
|
attention_mask)
|
||||||
|
elif use_sdp_causal(q_len, kv_seq_len, self.head_dim, query_states, self.training):
|
||||||
|
import xe_addons
|
||||||
|
if use_quantize_kv:
|
||||||
|
attn_output = xe_addons.sdp_fp8_causal(query_states, key_states,
|
||||||
|
value_states, attention_mask)
|
||||||
|
else:
|
||||||
|
attn_output = xe_addons.sdp_causal(query_states, key_states,
|
||||||
|
value_states, attention_mask)
|
||||||
else:
|
else:
|
||||||
k_cache, v_cache, _ = past_key_value
|
if use_quantize_kv:
|
||||||
key_states, value_states = append_fp8_kv_cache(k_cache, v_cache,
|
|
||||||
key_states, value_states)
|
|
||||||
past_key_value = (key_states, value_states, before_hidden_states)
|
|
||||||
|
|
||||||
# torch.matmul
|
|
||||||
if query_states.size(2) != 1 or device.type != 'xpu':
|
|
||||||
key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
|
key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
|
||||||
query_states.dtype)
|
query_states.dtype)
|
||||||
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3))
|
attn_weights = torch.matmul(query_states,
|
||||||
else:
|
key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
|
||||||
import xe_addons
|
|
||||||
attn_weights = xe_addons.query_key_fp8_matmul(query_states, key_states)
|
|
||||||
|
|
||||||
attn_weights = attn_weights / math.sqrt(self.head_dim)
|
|
||||||
|
|
||||||
invalidInputError(attn_weights.size() == (bsz, self.num_heads, q_len, kv_seq_len),
|
|
||||||
"Attention weights should be of size "
|
|
||||||
f"{(bsz, self.num_heads, q_len, kv_seq_len)}, "
|
|
||||||
f"but is {attn_weights.size()}")
|
|
||||||
|
|
||||||
if attention_mask is not None:
|
if attention_mask is not None:
|
||||||
invalidInputError(attention_mask.size() == (bsz, 1, q_len, kv_seq_len),
|
|
||||||
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, "
|
|
||||||
f"but is {attention_mask.size()}")
|
|
||||||
attn_weights = attn_weights + attention_mask
|
attn_weights = attn_weights + attention_mask
|
||||||
attn_weights = torch.max(attn_weights,
|
|
||||||
torch.tensor(torch.finfo(attn_weights.dtype).min))
|
|
||||||
|
|
||||||
# upcast attention to fp32
|
# upcast attention to fp32
|
||||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1,
|
attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1,
|
||||||
dtype=torch.float32).to(query_states.dtype)
|
dtype=torch.float32).to(value_states.dtype)
|
||||||
if query_states.size(2) != 1 or device.type != 'xpu':
|
attn_output = torch.matmul(attn_weights, value_states)
|
||||||
attn_output = torch.matmul(attn_weights, value_states)
|
|
||||||
else:
|
|
||||||
import xe_addons
|
|
||||||
attn_output = xe_addons.attn_value_fp8_matmul(attn_weights, value_states)
|
|
||||||
|
|
||||||
invalidInputError(attn_output.size() == (bsz, self.num_heads, q_len, self.head_dim),
|
|
||||||
"`attn_output` should be of size "
|
|
||||||
f"{(bsz, self.num_heads, q_len, self.head_dim)}, "
|
|
||||||
f"but is {attn_output.size()}")
|
|
||||||
|
|
||||||
attn_output = attn_output.transpose(1, 2)
|
attn_output = attn_output.transpose(1, 2)
|
||||||
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
||||||
|
|
@ -307,146 +250,3 @@ def yuan_attention_forward_quantized(
|
||||||
attn_weights = None
|
attn_weights = None
|
||||||
|
|
||||||
return attn_output, attn_weights, past_key_value
|
return attn_output, attn_weights, past_key_value
|
||||||
|
|
||||||
|
|
||||||
def yuan_attention_forward_origin(
|
|
||||||
self,
|
|
||||||
hidden_states: torch.Tensor,
|
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
|
||||||
position_ids: Optional[torch.LongTensor] = None,
|
|
||||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
|
||||||
output_attentions: bool = False,
|
|
||||||
use_cache: bool = False,
|
|
||||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
|
||||||
use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
|
|
||||||
bsz, q_len, _ = hidden_states.size()
|
|
||||||
device = hidden_states.device
|
|
||||||
before_hidden_states = None
|
|
||||||
is_first_step = False
|
|
||||||
self.use_shareqk = False
|
|
||||||
|
|
||||||
enough_kv_room = is_enough_kv_cache_room_4_31(past_key_value)
|
|
||||||
|
|
||||||
invalidInputError(use_cache, "use_cache=True is needed")
|
|
||||||
invalidInputError(not self.use_shareqk, "use_shareqk is not supported for now")
|
|
||||||
|
|
||||||
if past_key_value is None:
|
|
||||||
is_first_step = True
|
|
||||||
if q_len >= 2:
|
|
||||||
before_hidden_states = hidden_states[:, -2:, :].transpose(0, 1).half()
|
|
||||||
else:
|
|
||||||
before_hidden_states = torch.zeros(2, bsz, self.hidden_size,
|
|
||||||
dtype=torch.half, device=hidden_states.device)
|
|
||||||
before_hidden_states[-1:, :, :] = hidden_states[:, -1:, :].transpose(0, 1)
|
|
||||||
else:
|
|
||||||
before_hidden_states = past_key_value[2]
|
|
||||||
this_hidden_states = torch.cat([
|
|
||||||
before_hidden_states,
|
|
||||||
hidden_states.transpose(0, 1).half(),
|
|
||||||
], dim=0)
|
|
||||||
before_hidden_states = this_hidden_states[-2:, :, ]
|
|
||||||
|
|
||||||
value_states = \
|
|
||||||
self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
|
||||||
|
|
||||||
if is_first_step:
|
|
||||||
hidden_states = yuan_localized_filtering_forward(self.lf_gate, hidden_states,
|
|
||||||
None, hidden_states.dtype)
|
|
||||||
else:
|
|
||||||
hidden_states = yuan_localized_filtering_forward(self.lf_gate, hidden_states,
|
|
||||||
this_hidden_states, hidden_states.dtype)
|
|
||||||
query_states = self.merged_q_proj(hidden_states)
|
|
||||||
key_states = self.merged_k_proj(hidden_states)
|
|
||||||
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
|
||||||
key_states = key_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
|
||||||
|
|
||||||
kv_seq_len = key_states.shape[-2]
|
|
||||||
if past_key_value is not None:
|
|
||||||
kv_seq_len += past_key_value[0].shape[-2]
|
|
||||||
|
|
||||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
|
||||||
if use_fuse_rope:
|
|
||||||
query_states, key_states = apply_rotary_pos_emb_cache_freq_xpu(query_states,
|
|
||||||
key_states,
|
|
||||||
sin, cos,
|
|
||||||
"yuan",
|
|
||||||
position_ids)
|
|
||||||
else:
|
|
||||||
query_states, key_states = apply_rotary_pos_emb(query_states,
|
|
||||||
key_states,
|
|
||||||
cos, sin,
|
|
||||||
position_ids,
|
|
||||||
"yuan")
|
|
||||||
|
|
||||||
if past_key_value is not None:
|
|
||||||
# reuse k, v, self_attention
|
|
||||||
cache_k = past_key_value[0]
|
|
||||||
cache_v = past_key_value[1]
|
|
||||||
if not enough_kv_room:
|
|
||||||
# allocate new
|
|
||||||
new_cache_k, new_cache_v = extend_kv_cache(bsz,
|
|
||||||
self.num_heads,
|
|
||||||
self.head_dim,
|
|
||||||
cache_k.size(2),
|
|
||||||
kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH,
|
|
||||||
dtype=cache_k.dtype,
|
|
||||||
device=device)
|
|
||||||
new_cache_k[:] = cache_k
|
|
||||||
new_cache_v[:] = cache_v
|
|
||||||
cache_k = new_cache_k
|
|
||||||
cache_v = new_cache_v
|
|
||||||
|
|
||||||
key_states, value_states = append_kv_cache(cache_k, cache_v, key_states, value_states)
|
|
||||||
|
|
||||||
elif use_cache:
|
|
||||||
max_cache_length = kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH
|
|
||||||
new_key_states, new_value_states = init_kv_cache(bsz,
|
|
||||||
self.num_heads,
|
|
||||||
self.head_dim,
|
|
||||||
kv_seq_len,
|
|
||||||
max_cache_length,
|
|
||||||
dtype=key_states.dtype,
|
|
||||||
device=device)
|
|
||||||
new_key_states[:] = key_states
|
|
||||||
new_value_states[:] = value_states
|
|
||||||
key_states = new_key_states
|
|
||||||
value_states = new_value_states
|
|
||||||
|
|
||||||
past_key_value = \
|
|
||||||
(key_states, value_states, before_hidden_states) if use_cache else None
|
|
||||||
|
|
||||||
attn_weights = \
|
|
||||||
torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
|
|
||||||
|
|
||||||
invalidInputError(attn_weights.size() == (bsz, self.num_heads, q_len, kv_seq_len),
|
|
||||||
"Attention weights should be of size "
|
|
||||||
f"{(bsz, self.num_heads, q_len, kv_seq_len)}, "
|
|
||||||
f"but is {attn_weights.size()}")
|
|
||||||
|
|
||||||
if attention_mask is not None:
|
|
||||||
invalidInputError(attention_mask.size() == (bsz, 1, q_len, kv_seq_len),
|
|
||||||
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, "
|
|
||||||
f"but is {attention_mask.size()}")
|
|
||||||
attn_weights = attn_weights + attention_mask
|
|
||||||
attn_weights = torch.max(attn_weights,
|
|
||||||
torch.tensor(torch.finfo(attn_weights.dtype).min))
|
|
||||||
|
|
||||||
# upcast attention to fp32
|
|
||||||
attn_weights = \
|
|
||||||
torch.nn.functional.softmax(attn_weights,
|
|
||||||
dim=-1,
|
|
||||||
dtype=torch.float32).to(query_states.dtype)
|
|
||||||
attn_output = torch.matmul(attn_weights, value_states)
|
|
||||||
|
|
||||||
invalidInputError(attn_output.size() == (bsz, self.num_heads, q_len, self.head_dim),
|
|
||||||
"`attn_output` should be of size "
|
|
||||||
f"{(bsz, self.num_heads, q_len, self.head_dim)}, "
|
|
||||||
f"but is {attn_output.size()}")
|
|
||||||
|
|
||||||
attn_output = attn_output.transpose(1, 2)
|
|
||||||
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
|
||||||
attn_output = self.o_proj(attn_output)
|
|
||||||
|
|
||||||
if not output_attentions:
|
|
||||||
attn_weights = None
|
|
||||||
return attn_output, attn_weights, past_key_value
|
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue