diff --git a/python/llm/src/ipex_llm/transformers/convert.py b/python/llm/src/ipex_llm/transformers/convert.py index f2877bb8..22985743 100644 --- a/python/llm/src/ipex_llm/transformers/convert.py +++ b/python/llm/src/ipex_llm/transformers/convert.py @@ -717,6 +717,10 @@ def _optimize_pre(model): # baichuan2-7B from ipex_llm.transformers.models.baichuan2 import pre_compute_inv_freq model.apply(pre_compute_inv_freq) + # for qwen2 + if model.config.model_type == "qwen2": + from ipex_llm.transformers.models.qwen2 import merge_qkv + model.apply(merge_qkv) if model.config.model_type == "stablelm": # For stablelm-zephyr-3b and stablelm-2-zephyr-1_6b from ipex_llm.transformers.models.stablelm import merge_qkv diff --git a/python/llm/src/ipex_llm/transformers/models/qwen2.py b/python/llm/src/ipex_llm/transformers/models/qwen2.py index d599152e..c02d28f9 100644 --- a/python/llm/src/ipex_llm/transformers/models/qwen2.py +++ b/python/llm/src/ipex_llm/transformers/models/qwen2.py @@ -42,59 +42,24 @@ import warnings from typing import TYPE_CHECKING, Optional, Tuple, Union, Callable, List import torch -import torch.nn as nn -import torch.nn.functional as F +from torch.nn.functional import scaled_dot_product_attention as sdpa -from ipex_llm.transformers.models.llama import repeat_kv -from ipex_llm.transformers.models.utils import extend_kv_cache, append_kv_cache +from ipex_llm.transformers.models.utils import should_use_fuse_rope from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache -from ipex_llm.transformers.models.utils import is_enough_kv_cache_room_4_36 -from ipex_llm.transformers.models.utils import apply_rotary_pos_emb_cache_freq_xpu -from ipex_llm.transformers.kv import DynamicFp8Cache +from ipex_llm.transformers.models.utils import use_flash_attention, use_sdp, use_sdp_causal +from ipex_llm.transformers.kv import DynamicFp8Cache, DynamicNormalCache from ipex_llm.utils.common import invalidInputError -from ipex_llm.transformers.models.utils import use_flash_attention, use_sdp -from transformers.models.qwen2.modeling_qwen2 import Qwen2Model, apply_rotary_pos_emb + +from transformers.models.qwen2.modeling_qwen2 import Qwen2Attention, apply_rotary_pos_emb, repeat_kv from transformers.models.qwen2.modeling_qwen2 import _prepare_4d_causal_attention_mask_for_sdpa from transformers.models.qwen2.modeling_qwen2 import _prepare_4d_causal_attention_mask from transformers.modeling_outputs import BaseModelOutputWithPast -from ipex_llm.transformers.models.utils import use_decoding_fast_path - -try: - from transformers.cache_utils import Cache, DynamicCache -except ImportError: - Cache = Tuple[torch.Tensor] -import logging +from transformers.cache_utils import Cache, DynamicCache from transformers import logging logger = logging.get_logger(__name__) -import os - -KV_CACHE_ALLOC_BLOCK_LENGTH = int(os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256)) - - -def should_split_qkv_tensor(query_states, bsz, num_heads, q_len, kv_seq_len, output_attentions): - if not output_attentions: - if os.environ.get("IPEX_LLM_SPLIT_QKV", None) is not None: - return os.environ.get("IPEX_LLM_SPLIT_QKV", None) == "1" - elif query_states.dtype == torch.float16 and \ - query_states.shape[2] >= 5000: - # split tensor for memory block limitation - # support fp16 and set input length threshold at 5000 for now - return True - elif query_states.element_size()*bsz*num_heads*q_len*kv_seq_len >= 4*1024**3: - # attn_weight size larger than memory block limitation 4GB - return True - return False - - -def should_use_fuse_rope(self, query_states, position_ids): - use_fuse_rope = query_states.device.type == "xpu" - use_fuse_rope = use_fuse_rope and not (self.training and query_states.requires_grad) - use_fuse_rope = use_fuse_rope and position_ids is not None - return use_fuse_rope - def qwen2_model_forward( self, @@ -109,9 +74,12 @@ def qwen2_model_forward( return_dict: Optional[bool] = None, ): use_cache = use_cache if use_cache is not None else self.config.use_cache - if use_cache and use_quantize_kv_cache(self.layers[0].mlp.up_proj, input_ids): - if not isinstance(past_key_values, DynamicFp8Cache): + use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.up_proj, input_ids) + if use_cache: + if use_quantize_kv and not isinstance(past_key_values, DynamicFp8Cache): past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values) + if not use_quantize_kv and not isinstance(past_key_values, DynamicNormalCache): + past_key_values = DynamicNormalCache.from_legacy_cache(past_key_values) return qwen2_model_forward_internal( self=self, input_ids=input_ids, @@ -248,13 +216,13 @@ def qwen2_model_forward_internal( use_cache, ) else: - # bigdl-llm changes + # ipex-llm changes curr_device = decoder_layer.input_layernorm.weight.device if attention_mask is not None: attention_mask = attention_mask.to(curr_device) if position_ids is not None: position_ids = position_ids.to(curr_device) - # bigdl-llm changes end + # ipex-llm changes end layer_outputs = decoder_layer( hidden_states, attention_mask=attention_mask, @@ -294,325 +262,111 @@ def qwen2_model_forward_internal( ) +def merge_qkv(module: torch.nn.Module): + if isinstance(module, Qwen2Attention): + new_weight = torch.cat([ + module.q_proj.weight.data, + module.k_proj.weight.data, + module.v_proj.weight.data, + ], dim=0) + new_bias = torch.cat([ + module.q_proj.bias.data, + module.k_proj.bias.data, + module.v_proj.bias.data, + ], dim=-1) + + qkv_proj = torch.nn.Linear(0, 0, bias=True) + qkv_proj.weight = torch.nn.Parameter(new_weight, requires_grad=False) + qkv_proj.bias = torch.nn.Parameter(new_bias, requires_grad=False) + qkv_proj.in_features = new_weight.size(1) + qkv_proj.out_features = new_weight.size(0) + module.qkv_proj = qkv_proj + + del module.q_proj, module.k_proj, module.v_proj + + def qwen2_attention_forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - if use_quantize_kv_cache(self.q_proj, hidden_states): - forward_function = qwen2_attention_forward_quantized - elif hidden_states.device.type == "cpu": - forward_function = qwen2_sdpa_attention_forward - else: - forward_function = qwen2_attention_forward_origin - return forward_function( - self=self, - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - **kwargs, - ) - - -def qwen2_attention_forward_quantized( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[DynamicFp8Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - **kwargs, -) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - if "padding_mask" in kwargs: - warnings.warn( - "Passing `padding_mask` is deprecated and will be removed in v4.37. " - "Please make sure use `attention_mask` instead.`" - ) - use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids) bsz, q_len, _ = hidden_states.size() + device = hidden_states.device - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, q_len, - self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, - self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, - self.num_key_value_heads, self.head_dim).transpose(1, 2) + qkv = self.qkv_proj(hidden_states) + qkv = qkv.view(bsz, q_len, self.num_heads + 2 * self.num_key_value_heads, self.head_dim) + qkv = qkv.transpose(1, 2) + query_states, key_states, value_states = qkv.split([self.num_heads, + self.num_key_value_heads, + self.num_key_value_heads], dim=1) kv_seq_len = key_states.shape[-2] if past_key_value is not None: - invalidInputError(self.layer_idx is not None, - "The cache structure has changed since version v4.36. " - f"If you are using {self.__class__.__name__} " - "for auto-regressive decoding with k/v caching, " - "please make sure to initialize the attention class " - "with a layer index.") kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - if use_fuse_rope: - query_states, key_states = apply_rotary_pos_emb_cache_freq_xpu(query_states, key_states, - sin, cos, "qwen2", - position_ids) + if should_use_fuse_rope(hidden_states, position_ids, self.training): + import linear_q4_0 + linear_q4_0.rotary_half_inplaced(self.rotary_emb.inv_freq, position_ids, + query_states, key_states) else: + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: - cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models key_states, value_states = past_key_value.update(key_states, value_states, - self.layer_idx, cache_kwargs) + self.layer_idx, None) - if q_len == 1 and query_states.device.type == 'xpu' and not self.training \ - and not hidden_states.requires_grad: - import linear_q4_0 - attn_output = linear_q4_0.sdp_fp8(query_states, key_states, value_states, - attention_mask) - attn_weights = None - else: - key, value = restore_fp8_kv_cache(key_states, value_states, query_states.dtype) - key = repeat_kv(key, self.num_key_value_groups) - value = repeat_kv(value, self.num_key_value_groups) - if should_split_qkv_tensor(query_states, bsz, self.num_heads, - q_len, kv_seq_len, output_attentions): - attn_output, attn_weights = native_sdp_split_qkv_tensor(query_states, key, - value, attention_mask, - bsz, q_len, kv_seq_len, - self.head_dim, self.num_heads, - self.attention_dropout, - self.training) - else: - attn_weights = torch.matmul(query_states, key.transpose(2, 3)) - attn_weights = attn_weights / math.sqrt(self.head_dim) - - invalidInputError(attn_weights.size() == (bsz, self.num_heads, q_len, kv_seq_len), - ("Attention weights should be of size " - f"{(bsz, self.num_heads, q_len, kv_seq_len)}," - "but is {attn_weights.size()}")) - - if attention_mask is not None: - invalidInputError(attention_mask.size() == (bsz, 1, q_len, kv_seq_len), - (f"Attention mask should be of size " - f"{(bsz, 1, q_len, kv_seq_len)}," - f" but is {attention_mask.size()}")) - - attn_weights = attn_weights + attention_mask - - if kv_seq_len >= 2048 or bsz >= 64: - # for memory considerations, do not upcast attention to fp32 - # for long sequences or large batches - attn_weights = nn.functional.softmax(attn_weights, dim=-1) - else: - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, - dtype=torch.float32).to(query_states.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, - training=self.training) - - attn_output = torch.matmul(attn_weights, value) - - invalidInputError(attn_output.size() == (bsz, self.num_heads, q_len, self.head_dim), - "`attn_output` should be of size " - f"{(bsz, self.num_heads, q_len, self.head_dim)}," - f" but is {attn_output.size()}") - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) - - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value -from ipex_llm.ggml.quantize import ggml_tensor_qtype -SYM_INT4 = ggml_tensor_qtype["sym_int4"] -FP8E5 = ggml_tensor_qtype["fp8_e5m2"] - - -def qwen2_attention_forward_origin( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: bool = False, - use_cache: bool = False, - **kwargs, -) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - - use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids) - - if "padding_mask" in kwargs: - warnings.warn( - "Passing `padding_mask` is deprecated and will be removed in v4.37. " - "Please make sure use `attention_mask` instead.`" - ) - bsz, q_len, _ = hidden_states.size() - device = hidden_states.device - - enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx) - decoding_fast_path = use_decoding_fast_path(self.q_proj, - use_fuse_rope, - enough_kv_room, - bsz * q_len) - if decoding_fast_path: - hidden_states = hidden_states.view(1, -1) - cache_k = past_key_value.key_cache[self.layer_idx] - cache_v = past_key_value.value_cache[self.layer_idx] - kv_seq_len = cache_k.shape[-2] - import linear_q4_0 - args = [hidden_states, self.q_proj.weight, self.k_proj.weight, self.v_proj.weight, - self.q_proj.bias, self.k_proj.bias, self.v_proj.bias, position_ids, cache_k, - cache_v, self.q_proj.weight.qtype, self.v_proj.weight.qtype, kv_seq_len, - self.head_dim, self.rotary_emb.base] - query_states, key_states, value_states = linear_q4_0.forward_qkv_bias(*args) - kv_seq_len += 1 - if self.layer_idx == 0: - past_key_value.seen_tokens = kv_seq_len - past_key_value.key_cache[self.layer_idx] = key_states - past_key_value.value_cache[self.layer_idx] = value_states - - else: - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = \ - key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = \ - value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - if self.layer_idx is None: - invalidInputError( - False, - "The cache structure has changed since version v4.36. " - f"If you are using {self.__class__.__name__} " - "for auto-regressive decoding with k/v caching, " - "please make sure to initialize the attention class with a layer index." - ) - kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - if use_fuse_rope: - query_states, key_states = apply_rotary_pos_emb_cache_freq_xpu(query_states, key_states, - sin, cos, "qwen2", - position_ids) - else: - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, - cos, sin, position_ids) - - if past_key_value is not None: - # update the number of seen tokens - if self.layer_idx == 0: - past_key_value.seen_tokens += key_states.shape[-2] - - if len(past_key_value.key_cache) <= self.layer_idx: - past_key_value.key_cache.append(key_states) - past_key_value.value_cache.append(value_states) - else: - cache_k = past_key_value.key_cache[self.layer_idx] - cache_v = past_key_value.value_cache[self.layer_idx] - - if not enough_kv_room: - # allocate new - new_c_k, new_c_v = extend_kv_cache(bsz, - self.num_key_value_heads, # Support GQA - self.head_dim, - cache_k.size(2), - kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH, - dtype=cache_k.dtype, - device=device) - - new_c_k[:] = cache_k - new_c_v[:] = cache_v - cache_k = new_c_k - cache_v = new_c_v - - key_states, value_states = append_kv_cache(cache_k, - cache_v, - key_states, - value_states) - - # update past_key_value - past_key_value.key_cache[self.layer_idx] = key_states - past_key_value.value_cache[self.layer_idx] = value_states - - # repeat k/v heads if n_kv_heads < n_heads - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - if not self.training and not hidden_states.requires_grad and \ - use_flash_attention(query_states, key_states, attention_mask): - attn_output = F.scaled_dot_product_attention(query_states.to(device, dtype=torch.float16), - key_states.to(device, dtype=torch.float16), - value_states.to(device, dtype=torch.float16), - is_causal=True) - attn_weights = None + attn_weights = None + if query_states.device.type == "cpu": + attn_output = sdpa(query_states, + key_states, + value_states, + attn_mask=attention_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + is_causal=self.is_causal and attention_mask is None and q_len > 1) elif not self.training and not hidden_states.requires_grad and \ - use_sdp(q_len, key_states.shape[2], self.head_dim, query_states): + use_flash_attention(query_states, key_states, attention_mask): + attn_output = sdpa(query_states.to(device, dtype=torch.float16), + key_states.to(device, dtype=torch.float16), + value_states.to(device, dtype=torch.float16), + is_causal=True).to(hidden_states.dtype) + elif use_sdp(q_len, kv_seq_len, self.head_dim, query_states): import linear_q4_0 - attn_output = linear_q4_0.sdp(query_states, key_states, value_states, attention_mask) - attn_output = attn_output.view(query_states.shape) - attn_weights = None - else: - if should_split_qkv_tensor(query_states, bsz, self.num_heads, - q_len, kv_seq_len, output_attentions): - attn_output, attn_weights = native_sdp_split_qkv_tensor(query_states, key_states, - value_states, attention_mask, - bsz, q_len, kv_seq_len, - self.head_dim, self.num_heads, - self.attention_dropout, - self.training) + if isinstance(past_key_value, DynamicFp8Cache): + attn_output = linear_q4_0.sdp_fp8(query_states, key_states, value_states, + attention_mask) else: - attn_weights = torch.matmul(query_states, - key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + attn_output = linear_q4_0.sdp(query_states, key_states, value_states, attention_mask) + elif use_sdp_causal(q_len, kv_seq_len, self.head_dim, query_states, self.training): + import linear_q4_0 + if isinstance(past_key_value, DynamicFp8Cache): + attn_output = linear_q4_0.sdp_fp8_causal(query_states, key_states, value_states) + else: + attn_output = linear_q4_0.sdp_causal(query_states, key_states, value_states) + else: + if isinstance(past_key_value, DynamicFp8Cache): + key_states, value_states = restore_fp8_kv_cache(key_states, value_states, + query_states.dtype) + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) - invalidInputError(attn_weights.size() == (bsz, self.num_heads, q_len, kv_seq_len), - ("Attention weights should be of size " - f"{(bsz, self.num_heads, q_len, kv_seq_len)}," - "but is {attn_weights.size()}")) - - if attention_mask is not None: - invalidInputError(attention_mask.size() == (bsz, 1, q_len, kv_seq_len), - (f"Attention mask should be of size " - f"{(bsz, 1, q_len, kv_seq_len)}," - f" but is {attention_mask.size()}")) - - attn_weights = attn_weights + attention_mask - - if kv_seq_len >= 2048 or bsz >= 64: - # for memory considerations, do not upcast attention to fp32 - # for long sequences or large batches - attn_weights = nn.functional.softmax(attn_weights, dim=-1) - else: - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, - dtype=torch.float32).to(query_states.dtype) - attn_weights = nn.functional.dropout(attn_weights, - p=self.attention_dropout, - training=self.training) - attn_output = torch.matmul(attn_weights, value_states) - - invalidInputError(attn_output.size() == (bsz, self.num_heads, q_len, self.head_dim), - "`attn_output` should be of size " - f"{(bsz, self.num_heads, q_len, self.head_dim)}," - f" but is {attn_output.size()}") + attn_weights = torch.matmul(query_states, + key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + # upcast attention to fp32 + attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, + dtype=torch.float32).to(query_states.dtype) + attn_weights = torch.nn.functional.dropout(attn_weights, p=self.attention_dropout, + training=self.training) + attn_output = torch.matmul(attn_weights, value_states) attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) @@ -621,183 +375,4 @@ def qwen2_attention_forward_origin( if not output_attentions: attn_weights = None - - return attn_output.to(hidden_states.dtype), attn_weights, past_key_value - - -def qwen2_sdpa_attention_forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: bool = False, - use_cache: bool = False, - **kwargs, -) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - - use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids) - - if "padding_mask" in kwargs: - warnings.warn( - "Passing `padding_mask` is deprecated and will be removed in v4.37. " - "Please make sure use `attention_mask` instead.`" - ) - bsz, q_len, _ = hidden_states.size() - device = hidden_states.device - - enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx) - decoding_fast_path = use_decoding_fast_path(self.q_proj, - use_fuse_rope, - enough_kv_room, - bsz * q_len) - if decoding_fast_path: - hidden_states = hidden_states.view(1, -1) - cache_k = past_key_value.key_cache[self.layer_idx] - cache_v = past_key_value.value_cache[self.layer_idx] - kv_seq_len = cache_k.shape[-2] - import linear_q4_0 - args = [hidden_states, self.q_proj.weight, self.k_proj.weight, self.v_proj.weight, - self.q_proj.bias, self.k_proj.bias, self.v_proj.bias, position_ids, cache_k, - cache_v, self.q_proj.weight.qtype, self.v_proj.weight.qtype, kv_seq_len, - self.head_dim, self.rotary_emb.base] - query_states, key_states, value_states = linear_q4_0.forward_qkv_bias(*args) - kv_seq_len += 1 - if self.layer_idx == 0: - past_key_value.seen_tokens = kv_seq_len - past_key_value.key_cache[self.layer_idx] = key_states - past_key_value.value_cache[self.layer_idx] = value_states - - else: - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = \ - key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = \ - value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - if self.layer_idx is None: - invalidInputError( - False, - "The cache structure has changed since version v4.36. " - f"If you are using {self.__class__.__name__} " - "for auto-regressive decoding with k/v caching, " - "please make sure to initialize the attention class with a layer index." - ) - kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - if use_fuse_rope: - query_states, key_states = apply_rotary_pos_emb_cache_freq_xpu(query_states, key_states, - sin, cos, "qwen2", - position_ids) - else: - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, - cos, sin, position_ids) - - if past_key_value is not None: - # update the number of seen tokens - if self.layer_idx == 0: - past_key_value.seen_tokens += key_states.shape[-2] - - if len(past_key_value.key_cache) <= self.layer_idx: - past_key_value.key_cache.append(key_states) - past_key_value.value_cache.append(value_states) - else: - cache_k = past_key_value.key_cache[self.layer_idx] - cache_v = past_key_value.value_cache[self.layer_idx] - - if not enough_kv_room: - # allocate new - new_c_k, new_c_v = extend_kv_cache(bsz, - self.num_key_value_heads, # Support GQA - self.head_dim, - cache_k.size(2), - kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH, - dtype=cache_k.dtype, - device=device) - - new_c_k[:] = cache_k - new_c_v[:] = cache_v - cache_k = new_c_k - cache_v = new_c_v - - key_states, value_states = append_kv_cache(cache_k, - cache_v, - key_states, - value_states) - - # update past_key_value - past_key_value.key_cache[self.layer_idx] = key_states - past_key_value.value_cache[self.layer_idx] = value_states - - # repeat k/v heads if n_kv_heads < n_heads - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) - - invalidInputError(attn_weights.size() == (bsz, self.num_heads, q_len, kv_seq_len), - ("Attention weights should be of size " - f"{(bsz, self.num_heads, q_len, kv_seq_len)}," - "but is {attn_weights.size()}")) - - if attention_mask is not None: - invalidInputError(attention_mask.size() == (bsz, 1, q_len, kv_seq_len), - (f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}" - f" but is {attention_mask.size()}")) - - attn_weights = attn_weights + attention_mask - - from torch.nn.functional import scaled_dot_product_attention as sdpa - attn_output = sdpa(query_states, - key_states, - value_states, - attn_mask=attention_mask, - dropout_p=self.attention_dropout if self.training else 0.0, - is_causal=self.is_causal and attention_mask is None and q_len > 1) - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.view(bsz, q_len, self.hidden_size) - - attn_output = self.o_proj(attn_output) - - return attn_output, None, past_key_value - - -def native_sdp_split_qkv_tensor(query, key, value, attention_mask, - bsz, q_len, kv_seq_len, head_dim, num_heads, - attention_dropout, training): - block_size = 8 - query_split = torch.split(query, block_size, dim=1) - key_split = torch.split(key.transpose(2, 3), block_size, dim=1) - value_split = torch.split(value, block_size, dim=1) - attn_outputs = [] - for q, k, v in zip(query_split, key_split, value_split): - attn_weights_split = torch.matmul(q, k) / math.sqrt(head_dim) - block_actual_size = attn_weights_split.size(1) - attn_weights_split_size = (bsz, block_actual_size, q_len, kv_seq_len) - if attn_weights_split.size() != attn_weights_split_size: - invalidInputError(False, - f"Splitted attention weights should be of size " - f"{attn_weights_split_size}, but is {attn_weights_split.size()}") - - if attention_mask is not None: - attn_mask_size = (bsz, 1, q_len, kv_seq_len) - if attention_mask.size() != attn_mask_size: - invalidInputError(False, - f"Attention mask should be of size {attn_mask_size}, " - f"but is {attention_mask.size()}") - attn_weights_split = attn_weights_split + attention_mask - attn_weights_split = nn.functional.softmax(attn_weights_split, dim=-1) - attn_weights_split = nn.functional.dropout(attn_weights_split, - p=attention_dropout, - training=training) - attn_outputs.append(torch.matmul(attn_weights_split, v)) - attn_output = torch.cat(attn_outputs, dim=1) - return attn_output, None + return attn_output, attn_weights, past_key_value