refactor gemma to reduce old fuse rope usage (#12215)
This commit is contained in:
parent
9104a168f6
commit
a4a758656a
2 changed files with 150 additions and 348 deletions
|
|
@ -1017,6 +1017,10 @@ def _optimize_pre(model, qtype=None):
|
|||
model.apply(pre_process_attn_and_mlp)
|
||||
if model.config.model_type == "internvl_chat":
|
||||
_optimize_pre(model.language_model, qtype=qtype)
|
||||
if model.config.model_type == "gemma":
|
||||
from ipex_llm.transformers.models.gemma import merge_qkv, pre_compute_inv_freq
|
||||
model.apply(merge_qkv)
|
||||
model.apply(pre_compute_inv_freq)
|
||||
if model.config.model_type == "gemma2":
|
||||
from ipex_llm.transformers.models.gemma2 import merge_qkv
|
||||
model.apply(merge_qkv)
|
||||
|
|
@ -1846,32 +1850,16 @@ def _optimize_post(model, lightweight_bmm=False):
|
|||
module.MistralMLP,
|
||||
llama_mlp_forward)
|
||||
elif model.config.model_type == "gemma":
|
||||
invalidInputError(version.parse(trans_version) >= version.parse("4.38.0"),
|
||||
"Please upgrade transformers to 4.38.0 or higher version "
|
||||
"to run Mixtral models.")
|
||||
modeling_module_name = model.__class__.__module__
|
||||
module = importlib.import_module(modeling_module_name)
|
||||
if version.parse(trans_version) >= version.parse("4.39.0"):
|
||||
from ipex_llm.transformers.models.gemma import gemma_attention_forward_4_39
|
||||
convert_forward(model,
|
||||
module.GemmaAttention,
|
||||
gemma_attention_forward_4_39
|
||||
)
|
||||
else:
|
||||
from ipex_llm.transformers.models.gemma import gemma_attention_forward
|
||||
convert_forward(model,
|
||||
module.GemmaAttention,
|
||||
gemma_attention_forward,
|
||||
)
|
||||
from ipex_llm.transformers.models.gemma import gemma_model_forward
|
||||
from ipex_llm.transformers.models.gemma import gemma_attention_forward
|
||||
from ipex_llm.transformers.models.gemma import gemma_rms_norm_forward
|
||||
from ipex_llm.transformers.models.gemma import gemma_mlp_forward
|
||||
convert_forward(model,
|
||||
module.GemmaRMSNorm,
|
||||
gemma_rms_norm_forward)
|
||||
convert_forward(model,
|
||||
module.GemmaMLP,
|
||||
gemma_mlp_forward)
|
||||
|
||||
from ipex_llm.transformers.models.common import mlp_gelu_forward
|
||||
convert_forward(model, module.GemmaModel, gemma_model_forward)
|
||||
convert_forward(model, module.GemmaAttention, gemma_attention_forward)
|
||||
convert_forward(model, module.GemmaRMSNorm, gemma_rms_norm_forward)
|
||||
convert_forward(model, module.GemmaMLP, mlp_gelu_forward)
|
||||
elif model.config.model_type == "gemma2":
|
||||
modeling_module_name = model.__class__.__module__
|
||||
module = importlib.import_module(modeling_module_name)
|
||||
|
|
|
|||
|
|
@ -31,50 +31,31 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import math
|
||||
from typing import Optional, Tuple
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from ipex_llm.utils.common import invalidInputError
|
||||
from ipex_llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
|
||||
from ipex_llm.transformers.models.utils import apply_rotary_pos_emb_cache_freq_xpu
|
||||
from ipex_llm.transformers.models.utils import mlp_fusion_check, GELU
|
||||
from ipex_llm.transformers.models.utils import is_enough_kv_cache_room_4_36, rotate_half
|
||||
from ipex_llm.transformers.low_bit_linear import SYM_INT4, FP8E5
|
||||
from ipex_llm.transformers.models.utils import use_decoding_fast_path
|
||||
from ipex_llm.transformers.kv import DynamicNormalCache
|
||||
from ipex_llm.transformers.models.common import merge_qkv_base, attention_softmax
|
||||
from ipex_llm.transformers.models.utils import should_use_fuse_rope
|
||||
|
||||
import os
|
||||
|
||||
KV_CACHE_ALLOC_BLOCK_LENGTH = int(os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256))
|
||||
from transformers.cache_utils import Cache
|
||||
from transformers.modeling_outputs import BaseModelOutputWithPast
|
||||
from transformers.models.gemma.modeling_gemma import apply_rotary_pos_emb, repeat_kv
|
||||
from transformers.models.gemma.modeling_gemma import GemmaRotaryEmbedding, GemmaAttention
|
||||
|
||||
|
||||
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
|
||||
cos = cos.unsqueeze(unsqueeze_dim)
|
||||
sin = sin.unsqueeze(unsqueeze_dim)
|
||||
q_embed = (q * cos) + (rotate_half(q) * sin)
|
||||
k_embed = (k * cos) + (rotate_half(k) * sin)
|
||||
return q_embed, k_embed
|
||||
def merge_qkv(module: torch.nn.Module):
|
||||
merge_qkv_base(module, GemmaAttention)
|
||||
|
||||
|
||||
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
||||
"""
|
||||
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep).
|
||||
The hidden states go from (batch, num_key_value_heads, seqlen, head_dim)
|
||||
to (batch, num_attention_heads, seqlen, head_dim)
|
||||
"""
|
||||
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
|
||||
if n_rep == 1:
|
||||
return hidden_states
|
||||
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads,
|
||||
n_rep, slen, head_dim)
|
||||
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
|
||||
|
||||
|
||||
def should_use_fuse_rope(self, hidden_states, position_ids):
|
||||
use_fuse_rope = hidden_states.device.type == "xpu"
|
||||
use_fuse_rope = use_fuse_rope and not (self.training and hidden_states.requires_grad)
|
||||
use_fuse_rope = use_fuse_rope and position_ids is not None
|
||||
return use_fuse_rope
|
||||
def pre_compute_inv_freq(module: torch.nn.Module):
|
||||
if isinstance(module, GemmaRotaryEmbedding):
|
||||
module.inv_freq = 1.0 / (
|
||||
module.base **
|
||||
(torch.arange(0, module.dim, 2, dtype=torch.int64).float() / module.dim)
|
||||
)
|
||||
|
||||
|
||||
def gemma_rms_norm_forward(self, hidden_states):
|
||||
|
|
@ -91,29 +72,110 @@ def gemma_rms_norm_forward(self, hidden_states):
|
|||
return (1 + self.weight) * hidden_states.to(input_dtype)
|
||||
|
||||
|
||||
def gemma_mlp_forward(
|
||||
def gemma_model_forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
residual=None
|
||||
) -> torch.Tensor:
|
||||
x_2d = x.view(-1, x.shape[-1])
|
||||
bsz, hidden_size = x_2d.shape
|
||||
qtype = getattr(self.gate_proj, "qtype", None)
|
||||
if mlp_fusion_check(x_2d, qtype, self.training) and not self.down_proj.enable_xetla:
|
||||
import xe_linear
|
||||
if not x_2d.is_contiguous():
|
||||
x_2d = x_2d.contiguous()
|
||||
out = self.down_proj(xe_linear.mlp_forward_xpu(
|
||||
x_2d, self.gate_proj.weight.data, self.up_proj.weight.data,
|
||||
x_2d.shape[0], x_2d.shape[1], self.gate_proj.out_len,
|
||||
GELU, qtype
|
||||
))
|
||||
else:
|
||||
out = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
|
||||
if residual is not None:
|
||||
return out + residual
|
||||
else:
|
||||
return out
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[Cache] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
) -> Union[Tuple, BaseModelOutputWithPast]:
|
||||
output_attentions = (
|
||||
output_attentions if output_attentions is not None
|
||||
else self.config.output_attentions
|
||||
)
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None
|
||||
else self.config.output_hidden_states
|
||||
)
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
|
||||
# IPEX-LLM OPT start: kv cache and quantize kv cache
|
||||
if use_cache and not isinstance(past_key_values, DynamicNormalCache):
|
||||
past_key_values = DynamicNormalCache.from_legacy_cache(past_key_values)
|
||||
# IPEX-LLM OPT end
|
||||
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
invalidInputError((input_ids is None) ^ (inputs_embeds is None),
|
||||
"You cannot specify both input_ids and inputs_embeds at the same time, "
|
||||
"and must specify either one")
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
|
||||
if cache_position is None:
|
||||
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||
cache_position = torch.arange(
|
||||
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
|
||||
)
|
||||
|
||||
if position_ids is None:
|
||||
position_ids = cache_position.unsqueeze(0)
|
||||
|
||||
# IPEX-LLM changes start: support both transformers 4.38.1 and 4.39
|
||||
try:
|
||||
causal_mask = self._update_causal_mask(attention_mask, inputs_embeds)
|
||||
causal_mask = causal_mask[:, :, cache_position, :]
|
||||
except TypeError as _e:
|
||||
causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)
|
||||
# IPEX-LLM changes end
|
||||
|
||||
# embed positions
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
# normalized
|
||||
hidden_states = hidden_states * (self.config.hidden_size**0.5)
|
||||
|
||||
# decoder layers
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attns = () if output_attentions else None
|
||||
next_decoder_cache = None
|
||||
|
||||
for decoder_layer in self.layers:
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
layer_outputs = decoder_layer(
|
||||
hidden_states,
|
||||
attention_mask=causal_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_values,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if use_cache:
|
||||
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
|
||||
|
||||
if output_attentions:
|
||||
all_self_attns += (layer_outputs[1],)
|
||||
|
||||
hidden_states = self.norm(hidden_states)
|
||||
|
||||
# add hidden states from the last decoder layer
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
next_cache = next_decoder_cache if use_cache else None
|
||||
|
||||
if not return_dict:
|
||||
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
|
||||
if v is not None)
|
||||
return BaseModelOutputWithPast(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=next_cache,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_self_attns,
|
||||
)
|
||||
|
||||
|
||||
def gemma_attention_forward(
|
||||
|
|
@ -126,111 +188,27 @@ def gemma_attention_forward(
|
|||
use_cache: bool=False,
|
||||
cache_position: Optional[torch.Tensor]=None,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
bsz, q_len, hidden_size = hidden_states.size()
|
||||
device = hidden_states.device
|
||||
# for flash attention
|
||||
original_dtype = hidden_states.dtype
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
|
||||
enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx)
|
||||
decoding_fast_path = use_decoding_fast_path(self.q_proj,
|
||||
use_fuse_rope,
|
||||
enough_kv_room,
|
||||
bsz * q_len)
|
||||
|
||||
if decoding_fast_path:
|
||||
hidden_states = hidden_states.view(1, -1)
|
||||
|
||||
cache_k = past_key_value.key_cache[self.layer_idx]
|
||||
cache_v = past_key_value.value_cache[self.layer_idx]
|
||||
|
||||
kv_seq_len = cache_k.shape[-2]
|
||||
|
||||
import xe_linear
|
||||
query_states, key_states, value_states = xe_linear.forward_qkv(hidden_states,
|
||||
self.q_proj.weight,
|
||||
self.k_proj.weight,
|
||||
self.v_proj.weight,
|
||||
position_ids,
|
||||
cache_k, cache_v,
|
||||
self.q_proj.weight.qtype,
|
||||
self.v_proj.weight.qtype,
|
||||
kv_seq_len,
|
||||
self.head_dim)
|
||||
kv_seq_len += 1
|
||||
|
||||
# update past_key_value's seem_tokens and kv caches.
|
||||
if self.layer_idx == 0:
|
||||
past_key_value.seen_tokens = kv_seq_len
|
||||
past_key_value.key_cache[self.layer_idx] = key_states
|
||||
past_key_value.value_cache[self.layer_idx] = value_states
|
||||
qkv = self.qkv_proj(hidden_states)
|
||||
qkv = qkv.view(bsz, q_len, self.num_heads + 2 * self.num_key_value_heads, self.head_dim)
|
||||
qkv = qkv.transpose(1, 2)
|
||||
query_states, key_states, value_states = qkv.split([self.num_heads,
|
||||
self.num_key_value_heads,
|
||||
self.num_key_value_heads], dim=1)
|
||||
|
||||
if should_use_fuse_rope(hidden_states, position_ids, self.training):
|
||||
import xe_addons
|
||||
xe_addons.rotary_half_inplaced(self.rotary_emb.inv_freq, position_ids,
|
||||
query_states, key_states)
|
||||
else:
|
||||
query_states = self.q_proj(hidden_states)
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None)
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
|
||||
cos, sin, None)
|
||||
|
||||
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
key_states = key_states.view(bsz, q_len,
|
||||
self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
value_states = value_states.view(bsz, q_len,
|
||||
self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
kv_seq_len = key_states.shape[-2]
|
||||
|
||||
if past_key_value is not None:
|
||||
if self.layer_idx is None:
|
||||
invalidInputError(False,
|
||||
"The cache structure has changed since version v4.36. "
|
||||
f"If you are using {self.__class__.__name__} for "
|
||||
"auto-regressive decodingwith k/v caching, please make sure "
|
||||
"to initialize the attention class with a layer index.")
|
||||
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
|
||||
|
||||
if use_fuse_rope:
|
||||
cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None)
|
||||
query_states, key_states = apply_rotary_pos_emb_cache_freq_xpu(query_states, key_states,
|
||||
sin, cos, "gemma")
|
||||
else:
|
||||
cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None)
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
|
||||
cos, sin, None)
|
||||
|
||||
if past_key_value is not None:
|
||||
# update the number of seen tokens
|
||||
if self.layer_idx == 0:
|
||||
past_key_value.seen_tokens += key_states.shape[-2]
|
||||
|
||||
# reuse k, v, self_attention
|
||||
# update `past_key_value` with `key_states` and `value_states` for layer `layer_idx`
|
||||
if len(past_key_value.key_cache) <= self.layer_idx:
|
||||
past_key_value.key_cache.append(key_states)
|
||||
past_key_value.value_cache.append(value_states)
|
||||
else:
|
||||
cache_k = past_key_value.key_cache[self.layer_idx]
|
||||
cache_v = past_key_value.value_cache[self.layer_idx]
|
||||
|
||||
if not enough_kv_room:
|
||||
# allocate new
|
||||
new_c_k, new_c_v = extend_kv_cache(bsz,
|
||||
self.num_key_value_heads, # Support GQA
|
||||
self.head_dim,
|
||||
cache_k.size(2),
|
||||
kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH,
|
||||
dtype=cache_k.dtype,
|
||||
device=device)
|
||||
|
||||
new_c_k[:] = cache_k
|
||||
new_c_v[:] = cache_v
|
||||
cache_k = new_c_k
|
||||
cache_v = new_c_v
|
||||
|
||||
key_states, value_states = append_kv_cache(cache_k, cache_v,
|
||||
key_states, value_states)
|
||||
|
||||
# update past_key_value
|
||||
past_key_value.key_cache[self.layer_idx] = key_states
|
||||
past_key_value.value_cache[self.layer_idx] = value_states
|
||||
if past_key_value is not None:
|
||||
key_states, value_states = past_key_value.update(key_states, value_states,
|
||||
self.layer_idx, None)
|
||||
|
||||
# repeat k/v heads if n_kv_heads < n_heads
|
||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||
|
|
@ -238,26 +216,15 @@ def gemma_attention_forward(
|
|||
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
|
||||
|
||||
if attention_mask is not None: # no matter the length, we just slice it
|
||||
if cache_position is not None:
|
||||
causal_mask = attention_mask[:, :, cache_position, : key_states.shape[-2]]
|
||||
else:
|
||||
causal_mask = attention_mask
|
||||
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
||||
attn_weights = attn_weights + causal_mask
|
||||
|
||||
# upcast attention to fp32
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1,
|
||||
dtype=torch.float32).to(query_states.dtype)
|
||||
attn_weights = attention_softmax(attn_weights, self.training)
|
||||
attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout,
|
||||
training=self.training)
|
||||
attn_output = torch.matmul(attn_weights, value_states)
|
||||
|
||||
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
|
||||
invalidInputError(
|
||||
False,
|
||||
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
|
||||
f" {attn_output.size()}"
|
||||
)
|
||||
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
|
||||
attn_output = attn_output.view(bsz, q_len, -1)
|
||||
|
|
@ -266,157 +233,4 @@ def gemma_attention_forward(
|
|||
if not output_attentions:
|
||||
attn_weights = None
|
||||
|
||||
return attn_output.to(original_dtype), attn_weights, past_key_value
|
||||
|
||||
|
||||
def gemma_attention_forward_4_39(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor]=None,
|
||||
position_ids: Optional[torch.LongTensor]=None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]]=None,
|
||||
output_attentions: bool=False,
|
||||
use_cache: bool=False,
|
||||
cache_position: Optional[torch.Tensor]=None,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
bsz, q_len, hidden_size = hidden_states.size()
|
||||
device = hidden_states.device
|
||||
# for flash attention
|
||||
original_dtype = hidden_states.dtype
|
||||
|
||||
use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
|
||||
enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx)
|
||||
decoding_fast_path = use_decoding_fast_path(self.q_proj,
|
||||
use_fuse_rope,
|
||||
enough_kv_room,
|
||||
bsz * q_len)
|
||||
|
||||
if decoding_fast_path:
|
||||
hidden_states = hidden_states.view(1, -1)
|
||||
|
||||
cache_k = past_key_value.key_cache[self.layer_idx]
|
||||
cache_v = past_key_value.value_cache[self.layer_idx]
|
||||
|
||||
kv_seq_len = cache_k.shape[-2]
|
||||
|
||||
import xe_linear
|
||||
query_states, key_states, value_states = xe_linear.forward_qkv(hidden_states,
|
||||
self.q_proj.weight,
|
||||
self.k_proj.weight,
|
||||
self.v_proj.weight,
|
||||
position_ids,
|
||||
cache_k, cache_v,
|
||||
self.q_proj.weight.qtype,
|
||||
self.v_proj.weight.qtype,
|
||||
kv_seq_len,
|
||||
self.head_dim)
|
||||
kv_seq_len += 1
|
||||
|
||||
# update past_key_value's seem_tokens and kv caches.
|
||||
if self.layer_idx == 0:
|
||||
past_key_value._seen_tokens = kv_seq_len
|
||||
past_key_value.key_cache[self.layer_idx] = key_states
|
||||
past_key_value.value_cache[self.layer_idx] = value_states
|
||||
|
||||
else:
|
||||
query_states = self.q_proj(hidden_states)
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
|
||||
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
key_states = key_states.view(bsz, q_len,
|
||||
self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
value_states = value_states.view(bsz, q_len,
|
||||
self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
kv_seq_len = key_states.shape[-2]
|
||||
|
||||
if past_key_value is not None:
|
||||
if self.layer_idx is None:
|
||||
invalidInputError(False,
|
||||
"The cache structure has changed since version v4.36. "
|
||||
f"If you are using {self.__class__.__name__} for "
|
||||
"auto-regressive decodingwith k/v caching, please make sure "
|
||||
"to initialize the attention class with a layer index.")
|
||||
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
|
||||
|
||||
if use_fuse_rope:
|
||||
cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None)
|
||||
query_states, key_states = apply_rotary_pos_emb_cache_freq_xpu(query_states, key_states,
|
||||
sin, cos, "gemma")
|
||||
else:
|
||||
cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None)
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
|
||||
cos, sin, None)
|
||||
|
||||
if past_key_value is not None:
|
||||
# update the number of seen tokens
|
||||
if self.layer_idx == 0:
|
||||
past_key_value._seen_tokens += key_states.shape[-2]
|
||||
|
||||
# reuse k, v, self_attention
|
||||
# update `past_key_value` with `key_states` and `value_states` for layer `layer_idx`
|
||||
if len(past_key_value.key_cache) <= self.layer_idx:
|
||||
past_key_value.key_cache.append(key_states)
|
||||
past_key_value.value_cache.append(value_states)
|
||||
else:
|
||||
cache_k = past_key_value.key_cache[self.layer_idx]
|
||||
cache_v = past_key_value.value_cache[self.layer_idx]
|
||||
|
||||
if not enough_kv_room:
|
||||
# allocate new
|
||||
new_c_k, new_c_v = extend_kv_cache(bsz,
|
||||
self.num_key_value_heads, # Support GQA
|
||||
self.head_dim,
|
||||
cache_k.size(2),
|
||||
kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH,
|
||||
dtype=cache_k.dtype,
|
||||
device=device)
|
||||
|
||||
new_c_k[:] = cache_k
|
||||
new_c_v[:] = cache_v
|
||||
cache_k = new_c_k
|
||||
cache_v = new_c_v
|
||||
|
||||
key_states, value_states = append_kv_cache(cache_k, cache_v,
|
||||
key_states, value_states)
|
||||
|
||||
# update past_key_value
|
||||
past_key_value.key_cache[self.layer_idx] = key_states
|
||||
past_key_value.value_cache[self.layer_idx] = value_states
|
||||
|
||||
# repeat k/v heads if n_kv_heads < n_heads
|
||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
|
||||
|
||||
if attention_mask is not None: # no matter the length, we just slice it
|
||||
if cache_position is not None:
|
||||
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
||||
else:
|
||||
causal_mask = attention_mask
|
||||
attn_weights = attn_weights + causal_mask
|
||||
|
||||
# upcast attention to fp32
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1,
|
||||
dtype=torch.float32).to(query_states.dtype)
|
||||
attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout,
|
||||
training=self.training)
|
||||
attn_output = torch.matmul(attn_weights, value_states)
|
||||
|
||||
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
|
||||
invalidInputError(
|
||||
False,
|
||||
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
|
||||
f" {attn_output.size()}"
|
||||
)
|
||||
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
|
||||
attn_output = attn_output.view(bsz, q_len, -1)
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
if not output_attentions:
|
||||
attn_weights = None
|
||||
|
||||
return attn_output.to(original_dtype), attn_weights, past_key_value
|
||||
return attn_output, attn_weights, past_key_value
|
||||
|
|
|
|||
Loading…
Reference in a new issue