From d830a63bb7a3a133a2dd08702331c282d4c7b1be Mon Sep 17 00:00:00 2001 From: Yishuo Wang Date: Mon, 20 May 2024 18:08:37 +0800 Subject: [PATCH] refactor qwen (#11074) --- .../llm/src/ipex_llm/transformers/convert.py | 34 -- .../src/ipex_llm/transformers/models/qwen.py | 551 ++++-------------- .../test_transformers_api_attention.py | 2 +- 3 files changed, 104 insertions(+), 483 deletions(-) diff --git a/python/llm/src/ipex_llm/transformers/convert.py b/python/llm/src/ipex_llm/transformers/convert.py index 639a2154..f2877bb8 100644 --- a/python/llm/src/ipex_llm/transformers/convert.py +++ b/python/llm/src/ipex_llm/transformers/convert.py @@ -717,40 +717,6 @@ def _optimize_pre(model): # baichuan2-7B from ipex_llm.transformers.models.baichuan2 import pre_compute_inv_freq model.apply(pre_compute_inv_freq) - if model.config.model_type == "qwen": - rope_base = model.config.rotary_emb_base - from accelerate.big_modeling import init_empty_weights - - def split_qkv_proj_func(module): - if "QWenAttention" in module.__class__.__name__: - c_attn_weight = module.c_attn.weight.data - c_attn_bias = module.c_attn.bias.data - # Compatible with AutoTP case - projection_size = c_attn_weight.shape[0] // 3 - hid_size = module.hidden_size - with init_empty_weights(): - q_proj = torch.nn.Linear(hid_size, projection_size) - k_proj = torch.nn.Linear(hid_size, projection_size) - v_proj = torch.nn.Linear(hid_size, projection_size) - if not model.config.to_dict().get("bigdl_transformers_low_bit", False): - q_proj.weight = torch.nn.Parameter( - c_attn_weight[:projection_size, :], requires_grad=False) - q_proj.bias = torch.nn.Parameter( - c_attn_bias[:projection_size], requires_grad=False) - k_proj.weight = torch.nn.Parameter( - c_attn_weight[projection_size: 2 * projection_size, :], requires_grad=False) - k_proj.bias = torch.nn.Parameter( - c_attn_bias[projection_size: 2 * projection_size], requires_grad=False) - v_proj.weight = torch.nn.Parameter( - c_attn_weight[2 * projection_size:, :], requires_grad=False) - v_proj.bias = torch.nn.Parameter( - c_attn_bias[2 * projection_size:], requires_grad=False) - module.q_proj = q_proj - module.k_proj = k_proj - module.v_proj = v_proj - module.rope_base = rope_base - del module.c_attn - model.apply(split_qkv_proj_func) if model.config.model_type == "stablelm": # For stablelm-zephyr-3b and stablelm-2-zephyr-1_6b from ipex_llm.transformers.models.stablelm import merge_qkv diff --git a/python/llm/src/ipex_llm/transformers/models/qwen.py b/python/llm/src/ipex_llm/transformers/models/qwen.py index ffc95552..6aad5208 100644 --- a/python/llm/src/ipex_llm/transformers/models/qwen.py +++ b/python/llm/src/ipex_llm/transformers/models/qwen.py @@ -22,43 +22,24 @@ # LICENSE file in the root directory of this source tree. # -import importlib import math -from typing import TYPE_CHECKING, Optional, Tuple, Union, Callable, List +from typing import Optional, Tuple, Union, Callable, List import torch import torch.nn.functional as F import torch.utils.checkpoint from transformers.utils import logging - -try: - from einops import rearrange -except ImportError: - rearrange = None - -from ipex_llm.transformers.models.utils import extend_kv_cache, init_kv_cache, append_kv_cache -from ipex_llm.transformers.models.utils import init_fp8_kv_cache, append_fp8_kv_cache, \ - restore_fp8_kv_cache, use_quantize_kv_cache +from ipex_llm.transformers.models.utils import update_past_key_value, should_use_fuse_rope +from ipex_llm.transformers.models.utils import restore_fp8_kv_cache, use_quantize_kv_cache from ipex_llm.transformers.models.utils import rotate_half, SILU from ipex_llm.transformers.models.utils import mlp_fusion_check -from ipex_llm.transformers.models.utils import apply_rotary_pos_emb_cache_freq_xpu -from ipex_llm.transformers.models.utils import use_flash_attention, use_sdp, use_sdp_fp8 -from ipex_llm.transformers.models.utils import use_decoding_fast_path +from ipex_llm.transformers.models.utils import use_flash_attention, use_sdp, use_sdp_causal from ipex_llm.utils.common import invalidInputError, invalidOperationError -from ipex_llm.ggml.quantize import ggml_tensor_qtype from transformers.modeling_outputs import BaseModelOutputWithPast -apply_rotary_emb_func = None - -flash_attn_unpadded_func = None logger = logging.get_logger(__name__) -import os - -KV_CACHE_ALLOC_BLOCK_LENGTH = int(os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256)) -SUPPORT_TORCH2 = hasattr(torch, '__version__') and int(torch.__version__.split(".")[0]) >= 2 - def apply_rotary_pos_emb(t, freqs): cos, sin = freqs @@ -71,56 +52,7 @@ def apply_rotary_pos_emb(t, freqs): return torch.cat((t_, t_pass_), dim=-1).type_as(t) -def should_use_fuse_rope(self, query_states): - use_fuse_rope = query_states.device.type == "xpu" - use_fuse_rope = use_fuse_rope and not (self.training and query_states.requires_grad) - return use_fuse_rope - - -def is_enough_kv_cache_room(layer_past, kv_seq_len=1): - # to determinate if is enough kv cache room in transformers between 4.31 and 4.35 - # seq_len for current seq len - # For llama like kv cache, i.e., [bs, n_head, seq_len, head_dim] - if layer_past is None: - return False - else: - cache_k, cache_v = layer_past[0], layer_past[1] - cache_k = cache_k.transpose(1, 2) - cache_v = cache_v.transpose(1, 2) - return cache_k.stride(1) < (kv_seq_len + 1) * cache_k.size(3) - - def qwen_attention_forward( - self, - hidden_states: Optional[Tuple[torch.FloatTensor]], - rotary_pos_emb_list: Optional[List[torch.Tensor]] = None, - layer_past: Optional[Tuple[torch.Tensor]] = None, - attention_mask: Optional[torch.FloatTensor] = None, - head_mask: Optional[torch.FloatTensor] = None, - encoder_hidden_states: Optional[torch.Tensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, - output_attentions: Optional[bool] = False, - use_cache: Optional[bool] = False, -) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - if use_quantize_kv_cache(self.q_proj, hidden_states): - forward_function = qwen_attention_forward_quantized - else: - forward_function = qwen_attention_forward_original - return forward_function( - self, - hidden_states, - rotary_pos_emb_list, - layer_past, - attention_mask, - head_mask, - encoder_hidden_states, - encoder_attention_mask, - output_attentions, - use_cache, - ) - - -def qwen_attention_forward_original( self, hidden_states: Optional[Tuple[torch.FloatTensor]], rotary_pos_emb_list: Optional[List[torch.Tensor]] = None, @@ -131,400 +63,121 @@ def qwen_attention_forward_original( encoder_attention_mask: Optional[torch.FloatTensor] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, -): +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: invalidInputError(not self.use_flash_attn and not self.use_cache_quantization, "flash attn and kv_cache quantization are not supported") bsz, q_len, _ = hidden_states.size() device = hidden_states.device - # for flash attention - original_dtype = hidden_states.dtype + past_key_value = (None if layer_past is None + else (layer_past[0].transpose(1, 2), layer_past[1].transpose(1, 2))) + + qkv = self.c_attn(hidden_states) + qkv = qkv.view(bsz, q_len, self.num_heads * 3, self.head_dim) + qkv = qkv.transpose(1, 2) + query_states, key_states, value_states = qkv.split([self.num_heads, + self.num_heads, + self.num_heads], dim=1) + + kv_seq_len = key_states.shape[2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[2] + + # IPEX-LLM OPT: fuse rope position_ids = rotary_pos_emb_list[-1] # the last one is posisiton_ids - rotary_pos_emb_list = rotary_pos_emb_list[:-1] - - use_fuse_rope = should_use_fuse_rope(self, hidden_states) - decoding_fast_path = use_decoding_fast_path(self.q_proj, - use_fuse_rope, - True, - bsz * q_len) - if decoding_fast_path: - hidden_states = hidden_states.view(1, -1) - cache_k, cache_v = layer_past[0], layer_past[1] - cache_k = cache_k.transpose(1, 2) - cache_v = cache_v.transpose(1, 2) - - kv_seq_len = cache_k.shape[-2] - base = self.rope_base - if is_enough_kv_cache_room(layer_past, kv_seq_len): - new_cache_k, new_cache_v = extend_kv_cache(bsz, - self.num_heads, - self.head_dim, - cache_k.size(2), - kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH, - dtype=cache_k.dtype, - device=hidden_states.device) - new_cache_k[:] = cache_k - new_cache_v[:] = cache_v - cache_k = new_cache_k - cache_v = new_cache_v - - args = [hidden_states, self.q_proj.weight.data, self.k_proj.weight.data, - self.v_proj.weight.data, self.q_proj.bias.data, self.k_proj.bias.data, - self.v_proj.bias.data, position_ids, cache_k, cache_v, self.q_proj.weight.qtype, - self.v_proj.weight.qtype, kv_seq_len, self.head_dim, base] + inv_freq = rotary_pos_emb_list[-2] + rotary_pos_emb_list = rotary_pos_emb_list[:-2] + invalidInputError(len(rotary_pos_emb_list) == 1, + "rotary_pos_emb_list's length cannot be larger than 1") + use_fuse_rope = should_use_fuse_rope(hidden_states, position_ids, self.training) + rotary_pos_emb = rotary_pos_emb_list[0] + if use_fuse_rope: + rot_dim = rotary_pos_emb[0].size(-1) import linear_q4_0 - query, key, value = linear_q4_0.forward_qkv_bias(*args) - kv_seq_len += 1 - query_size, key_size = 1, 1 + linear_q4_0.rotary_half_inplaced(inv_freq, position_ids, + query_states[..., :rot_dim], key_states[..., :rot_dim]) else: - query = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim) - key = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim) - value = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim) - # TODO: speed up - # mixed_x_layer = self.c_attn(hidden_states) - # query, key, value = mixed_x_layer.split(self.split_size, dim=2) - - # query = self._split_heads(query, self.num_heads, self.head_dim) - # key = self._split_heads(key, self.num_heads, self.head_dim) - # value = self._split_heads(value, self.num_heads, self.head_dim) - if len(rotary_pos_emb_list) != 0: - cur_len = query.shape[1] - if len(rotary_pos_emb_list) == 1: - rotary_pos_emb = rotary_pos_emb_list[0] - rotary_pos_emb = [i[:, -cur_len:, :, :] for i in rotary_pos_emb] - if use_fuse_rope: - cos, sin = rotary_pos_emb - cos = cos.to(query.dtype) - sin = sin.to(query.dtype) - query, key = apply_rotary_pos_emb_cache_freq_xpu(query, key, sin, cos, "qwen") - else: - rotary_pos_emb = (rotary_pos_emb,) * 2 - q_pos_emb, k_pos_emb = rotary_pos_emb - # Slice the pos emb for current inference - query = apply_rotary_pos_emb(query, q_pos_emb) - key = apply_rotary_pos_emb(key, k_pos_emb) - else: - query_list = [] - key_list = [] - for i, rotary_pos_emb in enumerate(rotary_pos_emb_list): - rotary_pos_emb = [i[:, -cur_len:, :, :] for i in rotary_pos_emb] - if use_fuse_rope: - cos, sin = rotary_pos_emb - cos = cos.to(query.dtype) - sin = sin.to(query.dtype) - query, key = apply_rotary_pos_emb_cache_freq_xpu(query, key, - sin, cos, "qwen") - query_list += [query] - key_list += [key] - else: - rotary_pos_emb = (rotary_pos_emb,) * 2 - q_pos_emb, k_pos_emb = rotary_pos_emb - # Slice the pos emb for current inference - query_list += [apply_rotary_pos_emb(query[i:i+1, :, :], q_pos_emb)] - key_list += [apply_rotary_pos_emb(key[i:i+1, :, :], k_pos_emb)] - query = torch.cat(query_list, dim=0) - key = torch.cat(key_list, dim=0) - query_size, key_size = query.size(1), key.size(1) - kv_seq_len = key_size if layer_past is None else key_size + layer_past[0].size(1) + rotary_pos_emb = [i[:, -q_len:, :, :].transpose(1, 2) for i in rotary_pos_emb] + query_states = apply_rotary_pos_emb(query_states, rotary_pos_emb) + key_states = apply_rotary_pos_emb(key_states, rotary_pos_emb) if kv_seq_len > self.seq_length and self.use_logn_attn and not self.training: - seq_start = kv_seq_len - query_size + seq_start = kv_seq_len - q_len seq_end = kv_seq_len - logn_tensor = self.logn_tensor[:, seq_start:seq_end, :, :].type_as(query) - query = query * logn_tensor.expand_as(query) + logn_tensor = self.logn_tensor[:, seq_start:seq_end, :, :].transpose(1, 2) + query_states = query_states * logn_tensor.type_as(query_states).expand_as(query_states) - if query_size > 1: - causal_mask = torch.tril( - torch.ones((kv_seq_len, kv_seq_len), dtype=torch.bool, device=query.device) - ).view(1, 1, kv_seq_len, kv_seq_len) - causal_mask = causal_mask[ - :, :, kv_seq_len - query_size:kv_seq_len, :kv_seq_len - ] - else: - causal_mask = None - - if layer_past is not None: - if not decoding_fast_path: - cache_k, cache_v = layer_past[0], layer_past[1] - cache_k = cache_k.transpose(1, 2) - cache_v = cache_v.transpose(1, 2) - if cache_k.stride(1) < kv_seq_len * cache_k.size(3): - new_cache_k, new_cache_v = extend_kv_cache(bsz, - self.num_heads, - self.head_dim, - cache_k.size(2), - kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH, - dtype=cache_k.dtype, - device=hidden_states.device) - new_cache_k[:] = cache_k - new_cache_v[:] = cache_v - cache_k = new_cache_k - cache_v = new_cache_v - key_states, value_states = append_kv_cache(cache_k, cache_v, - key.transpose(1, 2), value.transpose(1, 2)) - key = key_states - value = value_states - elif use_cache: - max_cache_length = kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH - new_key_states, new_value_states = init_kv_cache(bsz, - self.num_heads, - self.head_dim, - kv_seq_len, - max_cache_length, - dtype=key.dtype, - device=hidden_states.device) - new_key_states[:] = key.transpose(1, 2) - new_value_states[:] = value.transpose(1, 2) - key = new_key_states - value = new_value_states - - if not decoding_fast_path: - query = query.transpose(1, 2) + # IPEX-LLM OPT: kv cache and quantzie kv cache + use_quantize_kv = use_quantize_kv_cache(self.c_attn, hidden_states) + key_states, value_states = update_past_key_value( + past_key_value, key_states, value_states, + kv_seq_len, use_quantize_kv, device + ) + past_key_value = (key_states.transpose(1, 2), + value_states.transpose(1, 2)) if use_cache else None + # IPEX-LLM OPT: sdp + attn_weights = None if not self.training and not hidden_states.requires_grad and \ - use_flash_attention(query, key): - attn_output = F.scaled_dot_product_attention(query.to(device, dtype=torch.float16), - key.to(device, dtype=torch.float16), - value.to(device, dtype=torch.float16), - is_causal=True) - attn_output = attn_output.view(query.shape) - attn_output = attn_output.transpose(1, 2) - attn_weights = None - elif not self.training and not hidden_states.requires_grad and \ - use_sdp(q_len, key.shape[2], self.head_dim, query): + use_flash_attention(query_states, key_states, attention_mask): + attn_output = F.scaled_dot_product_attention(query_states.to(dtype=torch.float16), + key_states.to(dtype=torch.float16), + value_states.to(dtype=torch.float16), + is_causal=True).to(hidden_states.dtype) + elif use_sdp_causal(q_len, kv_seq_len, self.head_dim, query_states, self.training): import linear_q4_0 - attn_output = linear_q4_0.sdp(query, key, value, attention_mask) - attn_output = attn_output.view(query.shape) - attn_output = attn_output.transpose(1, 2) - attn_weight = None - else: - attn_output, attn_weight = self._attn( - query.to(key.dtype), key, value, causal_mask, attention_mask, head_mask - ) - - context_layer = self._merge_heads( - attn_output, self.num_heads, self.head_dim - ) - - attn_output = self.c_proj(context_layer).to(original_dtype) - - if use_cache: - outputs = (attn_output, (key.transpose(1, 2), value.transpose(1, 2))) - else: - outputs = (attn_output, None) - if output_attentions: - outputs += (attn_weight,) - - return outputs - - -def qwen_attention_forward_quantized( - self, - hidden_states: Optional[Tuple[torch.FloatTensor]], - rotary_pos_emb_list: Optional[List[torch.Tensor]] = None, - layer_past: Optional[Tuple[torch.Tensor]] = None, - attention_mask: Optional[torch.FloatTensor] = None, - head_mask: Optional[torch.FloatTensor] = None, - encoder_hidden_states: Optional[torch.Tensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, - output_attentions: Optional[bool] = False, - use_cache: Optional[bool] = False, -): - invalidInputError(not self.use_flash_attn and not self.use_cache_quantization, - "flash attn and kv_cache quantization are not supported") - - bsz, q_len, _ = hidden_states.size() - device = hidden_states.device - position_ids = rotary_pos_emb_list[-1] # the last one is posisiton_ids - rotary_pos_emb_list = rotary_pos_emb_list[:-1] - - use_fuse_rope = should_use_fuse_rope(self, hidden_states) - # qtype_check = decoding_fast_path_qtype_check(self.q_proj) - # TODO: use when decoding_fast_path = (qtype_check and use_fuse_rope and bsz * q_len == 1) - decoding_fast_path = False - if decoding_fast_path: - hidden_states = hidden_states.view(1, -1) - tmp_cache_k, tmp_cache_v = init_kv_cache( - bsz, - self.num_heads, - self.head_dim, - 0, - 1, - dtype=hidden_states.dtype, - device=device - ) - - base = self.rope_base - - args = [hidden_states, self.q_proj.weight.data, self.k_proj.weight.data, - self.v_proj.weight.data, self.q_proj.bias.data, self.k_proj.bias.data, - self.v_proj.bias.data, position_ids, tmp_cache_k, tmp_cache_v, - self.q_proj.weight.qtype, self.v_proj.weight.qtype, 0, self.head_dim, base] - import linear_q4_0 - query, key, value = linear_q4_0.forward_qkv_bias(*args) - self.kv_seq_len += 1 - kv_seq_len = self.kv_seq_len - query_size, key_size = 1, 1 - else: - query = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim) - key = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim) - value = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim) - # TODO: speed up - # mixed_x_layer = self.c_attn(hidden_states) - # query, key, value = mixed_x_layer.split(self.split_size, dim=2) - - # query = self._split_heads(query, self.num_heads, self.head_dim) - # key = self._split_heads(key, self.num_heads, self.head_dim) - # value = self._split_heads(value, self.num_heads, self.head_dim) - if rotary_pos_emb_list is not None: - cur_len = query.shape[1] - if len(rotary_pos_emb_list) == 1: - rotary_pos_emb = rotary_pos_emb_list[0] - rotary_pos_emb = [i[:, -cur_len:, :, :] for i in rotary_pos_emb] - if use_fuse_rope: - cos, sin = rotary_pos_emb - cos = cos.to(query.dtype) - sin = sin.to(query.dtype) - query, key = apply_rotary_pos_emb_cache_freq_xpu(query, key, sin, cos, "qwen") - else: - rotary_pos_emb = (rotary_pos_emb,) * 2 - q_pos_emb, k_pos_emb = rotary_pos_emb - # Slice the pos emb for current inference - query = apply_rotary_pos_emb(query, q_pos_emb) - key = apply_rotary_pos_emb(key, k_pos_emb) - else: - query_list = [] - key_list = [] - for i, rotary_pos_emb in enumerate(rotary_pos_emb_list): - rotary_pos_emb = [i[:, -cur_len:, :, :] for i in rotary_pos_emb] - if use_fuse_rope: - cos, sin = rotary_pos_emb - cos = cos.to(query.dtype) - sin = sin.to(query.dtype) - query, key = apply_rotary_pos_emb_cache_freq_xpu(query, key, - sin, cos, "qwen") - query_list += [query] - key_list += [key] - else: - rotary_pos_emb = (rotary_pos_emb,) * 2 - q_pos_emb, k_pos_emb = rotary_pos_emb - # Slice the pos emb for current inference - query_list += [apply_rotary_pos_emb(query[i:i+1, :, :], q_pos_emb)] - key_list += [apply_rotary_pos_emb(key[i:i+1, :, :], k_pos_emb)] - query = torch.cat(query_list, dim=0) - key = torch.cat(key_list, dim=0) - query_size, key_size = query.size(1), key.size(1) - kv_seq_len = key_size if layer_past is None else key_size + layer_past[0].size(1) - - if kv_seq_len > self.seq_length and self.use_logn_attn and not self.training: - seq_start = kv_seq_len - query_size - seq_end = kv_seq_len - logn_tensor = self.logn_tensor[:, seq_start:seq_end, :, :].type_as(query) - query = query * logn_tensor.expand_as(query) - - if query_size > 1: - causal_mask = torch.tril( - torch.ones((kv_seq_len, kv_seq_len), dtype=torch.bool, device=query.device) - ).view(1, 1, kv_seq_len, kv_seq_len) - causal_mask = causal_mask[ - :, :, kv_seq_len - query_size:kv_seq_len, :kv_seq_len - ] - else: - causal_mask = None - - if layer_past is None: - query, key, value = query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2) - # query, key, value's shape: [bs, num_heads, seq_len, head_dim] - - # save kv seq len for decoding_fast_path - self.kv_seq_len = key.shape[-2] - # For first token, use original attn - attn_output, attn_weight = self._attn( - query, key, value, causal_mask, attention_mask, head_mask - ) - if use_cache: - max_cache_length = kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH - k_cache, v_cache = init_fp8_kv_cache( - query.size(0), self.num_heads, kv_seq_len, self.head_dim, - device=query.device - ) - key, value = append_fp8_kv_cache(k_cache, v_cache, key, value) - else: - if decoding_fast_path: - k_cache, v_cache = layer_past[0], layer_past[1] - # k_cache and v_cache's shape: [bs, num_heads, context_length, head_dim] + if use_quantize_kv: + attn_output = linear_q4_0.sdp_fp8_causal(query_states, key_states, value_states) else: - query, key, value = query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2) - k_cache, v_cache = layer_past[0], layer_past[1] - - k_cache = k_cache.transpose(1, 2) - v_cache = v_cache.transpose(1, 2) - # k_cache and v_cache's shape: [bs, num_heads, context_length, head_dim] - - key, value = append_fp8_kv_cache(k_cache, v_cache, key, value) - - attn_output, attn_weight = core_attn( - self, query, key, value, causal_mask, attention_mask, head_mask - ) - - context_layer = self._merge_heads( - attn_output, self.num_heads, self.head_dim - ) - - attn_output = self.c_proj(context_layer) - - if use_cache: - outputs = (attn_output, (key.transpose(1, 2), value.transpose(1, 2))) + attn_output = linear_q4_0.sdp_causal(query_states, key_states, value_states) else: - outputs = (attn_output, None) - if output_attentions: - outputs += (attn_weight,) - - return outputs - - -def core_attn(self, query, key, value, causal_mask=None, attention_mask=None, head_mask=None): - if not use_sdp_fp8(query.size(2), key.size(2), query): - # We have no CPU fp8 matmul implementation for now, so just upscale to fp32 - key, value = restore_fp8_kv_cache(key, value, query.dtype) - attn_weights = torch.matmul(query, key.transpose(-1, -2)) - - if self.scale_attn_weights: - if self.use_cache_quantization: - size_temp = value[0].size(-1) - else: - size_temp = value.size(-1) - attn_weights = attn_weights / (size_temp ** 0.5) - - mask_value = torch.finfo(attn_weights.dtype).min - if causal_mask is not None: - attn_weights = torch.where( - causal_mask, attn_weights.to(attn_weights.dtype), mask_value - ) - - if attention_mask is not None: - attn_weights = attn_weights + attention_mask - - if self.softmax_in_fp32: - attn_weights = torch.nn.functional.softmax(attn_weights.float(), dim=-1) + if q_len > 1: + causal_mask = torch.tril( + torch.ones((kv_seq_len, kv_seq_len), dtype=torch.bool, device=query_states.device) + ).view(1, 1, kv_seq_len, kv_seq_len) + causal_mask = causal_mask[ + :, :, kv_seq_len - q_len:kv_seq_len, :kv_seq_len + ] + attention_mask = torch.zeros(causal_mask.shape, dtype=query_states.dtype, + device=query_states.device) + attention_mask.masked_fill_(causal_mask.logical_not(), + torch.finfo(attention_mask.dtype).min) + attention_mask = attention_mask.expand([bsz, -1, -1, -1]) else: - attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1) + attention_mask = None - attn_weights = attn_weights.type(query.dtype) - attn_weights = self.attn_dropout(attn_weights) + if use_sdp(q_len, kv_seq_len, self.head_dim, query_states): + import linear_q4_0 + if use_quantize_kv: + attn_output = linear_q4_0.sdp_fp8(query_states, key_states, value_states, + attention_mask) + else: + attn_output = linear_q4_0.sdp(query_states, key_states, value_states, + attention_mask) + else: + if use_quantize_kv: + key_states, value_states = restore_fp8_kv_cache(key_states, value_states, + query_states.dtype) + attn_weights = torch.matmul(query_states, + key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + if self.softmax_in_fp32: + attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, + dtype=torch.float32).to( + value_states.dtype) + else: + attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1) + attn_output = torch.matmul(attn_weights, value_states) - if head_mask is not None: - attn_weights = attn_weights * head_mask + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(bsz, q_len, self.hidden_size) - # We have no CPU fp8 matmul implementation for now, so just upscale to fp32 - attn_output = torch.matmul(attn_weights, value) + attn_output = self.c_proj(attn_output) + + if output_attentions: + return attn_output, past_key_value, attn_weights else: - import linear_q4_0 - attn_output = linear_q4_0.sdp_fp8(query, key, value, - attention_mask) - attn_weights = None - attn_output = attn_output.transpose(1, 2) - - return attn_output, attn_weights + return attn_output, past_key_value def qwen_mlp_forward(self, x: torch.Tensor) -> torch.Tensor: @@ -652,9 +305,11 @@ def qwen_model_forward( ntk_alpha = self.get_ntk_alpha(kv_seq_len) ntk_alpha_list.append(ntk_alpha) self.rotary_emb._ntk_alpha_cached_list = ntk_alpha_list + # ipex-llm changes rotary_pos_emb_list = [ self.rotary_emb(kv_seq_len, ntk_alpha=ntk_alpha) for ntk_alpha in ntk_alpha_list - ] + [position_ids] + ] + [self.rotary_emb.inv_freq.to(self.dtype), position_ids] + # ipex-llm changes ends hidden_states = self.drop(hidden_states) output_shape = input_shape + (hidden_states.size(-1),) @@ -695,7 +350,7 @@ def qwen_model_forward( encoder_attention_mask, ) else: - # bigdl-llm changes + # ipex-llm changes curr_device = block.ln_1.weight.device from accelerate.utils.operations import send_to_device if rotary_pos_emb_list is not None: @@ -709,7 +364,7 @@ def qwen_model_forward( if encoder_attention_mask is not None: encoder_attention_mask = send_to_device(encoder_attention_mask, curr_device) - # bigdl-llm changes ends + # ipex-llm changes ends outputs = block( hidden_states, diff --git a/python/llm/test/inference_gpu/test_transformers_api_attention.py b/python/llm/test/inference_gpu/test_transformers_api_attention.py index 0990f8ad..b03ddaf9 100644 --- a/python/llm/test/inference_gpu/test_transformers_api_attention.py +++ b/python/llm/test/inference_gpu/test_transformers_api_attention.py @@ -188,5 +188,5 @@ class Test_Optimize_Gpu_Model: # currently only need to compare the output of one self-attention layer. layer_norm = "transformer.h.31.ln_1" self_attn = "transformer.h.31.attn" - lower_bound = 8e-3 + lower_bound = 2e-2 self.run_optimize_gpu_model(Name, Model, Tokenizer, model_path, self_attn, layer_norm, lower_bound) \ No newline at end of file