refactor qwen (#11074)
This commit is contained in:
parent
74950a152a
commit
d830a63bb7
3 changed files with 104 additions and 483 deletions
|
|
@ -717,40 +717,6 @@ def _optimize_pre(model):
|
||||||
# baichuan2-7B
|
# baichuan2-7B
|
||||||
from ipex_llm.transformers.models.baichuan2 import pre_compute_inv_freq
|
from ipex_llm.transformers.models.baichuan2 import pre_compute_inv_freq
|
||||||
model.apply(pre_compute_inv_freq)
|
model.apply(pre_compute_inv_freq)
|
||||||
if model.config.model_type == "qwen":
|
|
||||||
rope_base = model.config.rotary_emb_base
|
|
||||||
from accelerate.big_modeling import init_empty_weights
|
|
||||||
|
|
||||||
def split_qkv_proj_func(module):
|
|
||||||
if "QWenAttention" in module.__class__.__name__:
|
|
||||||
c_attn_weight = module.c_attn.weight.data
|
|
||||||
c_attn_bias = module.c_attn.bias.data
|
|
||||||
# Compatible with AutoTP case
|
|
||||||
projection_size = c_attn_weight.shape[0] // 3
|
|
||||||
hid_size = module.hidden_size
|
|
||||||
with init_empty_weights():
|
|
||||||
q_proj = torch.nn.Linear(hid_size, projection_size)
|
|
||||||
k_proj = torch.nn.Linear(hid_size, projection_size)
|
|
||||||
v_proj = torch.nn.Linear(hid_size, projection_size)
|
|
||||||
if not model.config.to_dict().get("bigdl_transformers_low_bit", False):
|
|
||||||
q_proj.weight = torch.nn.Parameter(
|
|
||||||
c_attn_weight[:projection_size, :], requires_grad=False)
|
|
||||||
q_proj.bias = torch.nn.Parameter(
|
|
||||||
c_attn_bias[:projection_size], requires_grad=False)
|
|
||||||
k_proj.weight = torch.nn.Parameter(
|
|
||||||
c_attn_weight[projection_size: 2 * projection_size, :], requires_grad=False)
|
|
||||||
k_proj.bias = torch.nn.Parameter(
|
|
||||||
c_attn_bias[projection_size: 2 * projection_size], requires_grad=False)
|
|
||||||
v_proj.weight = torch.nn.Parameter(
|
|
||||||
c_attn_weight[2 * projection_size:, :], requires_grad=False)
|
|
||||||
v_proj.bias = torch.nn.Parameter(
|
|
||||||
c_attn_bias[2 * projection_size:], requires_grad=False)
|
|
||||||
module.q_proj = q_proj
|
|
||||||
module.k_proj = k_proj
|
|
||||||
module.v_proj = v_proj
|
|
||||||
module.rope_base = rope_base
|
|
||||||
del module.c_attn
|
|
||||||
model.apply(split_qkv_proj_func)
|
|
||||||
if model.config.model_type == "stablelm":
|
if model.config.model_type == "stablelm":
|
||||||
# For stablelm-zephyr-3b and stablelm-2-zephyr-1_6b
|
# For stablelm-zephyr-3b and stablelm-2-zephyr-1_6b
|
||||||
from ipex_llm.transformers.models.stablelm import merge_qkv
|
from ipex_llm.transformers.models.stablelm import merge_qkv
|
||||||
|
|
|
||||||
|
|
@ -22,43 +22,24 @@
|
||||||
# LICENSE file in the root directory of this source tree.
|
# LICENSE file in the root directory of this source tree.
|
||||||
#
|
#
|
||||||
|
|
||||||
import importlib
|
|
||||||
import math
|
import math
|
||||||
from typing import TYPE_CHECKING, Optional, Tuple, Union, Callable, List
|
from typing import Optional, Tuple, Union, Callable, List
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import torch.utils.checkpoint
|
import torch.utils.checkpoint
|
||||||
from transformers.utils import logging
|
from transformers.utils import logging
|
||||||
|
from ipex_llm.transformers.models.utils import update_past_key_value, should_use_fuse_rope
|
||||||
try:
|
from ipex_llm.transformers.models.utils import restore_fp8_kv_cache, use_quantize_kv_cache
|
||||||
from einops import rearrange
|
|
||||||
except ImportError:
|
|
||||||
rearrange = None
|
|
||||||
|
|
||||||
from ipex_llm.transformers.models.utils import extend_kv_cache, init_kv_cache, append_kv_cache
|
|
||||||
from ipex_llm.transformers.models.utils import init_fp8_kv_cache, append_fp8_kv_cache, \
|
|
||||||
restore_fp8_kv_cache, use_quantize_kv_cache
|
|
||||||
from ipex_llm.transformers.models.utils import rotate_half, SILU
|
from ipex_llm.transformers.models.utils import rotate_half, SILU
|
||||||
from ipex_llm.transformers.models.utils import mlp_fusion_check
|
from ipex_llm.transformers.models.utils import mlp_fusion_check
|
||||||
from ipex_llm.transformers.models.utils import apply_rotary_pos_emb_cache_freq_xpu
|
from ipex_llm.transformers.models.utils import use_flash_attention, use_sdp, use_sdp_causal
|
||||||
from ipex_llm.transformers.models.utils import use_flash_attention, use_sdp, use_sdp_fp8
|
|
||||||
from ipex_llm.transformers.models.utils import use_decoding_fast_path
|
|
||||||
from ipex_llm.utils.common import invalidInputError, invalidOperationError
|
from ipex_llm.utils.common import invalidInputError, invalidOperationError
|
||||||
from ipex_llm.ggml.quantize import ggml_tensor_qtype
|
|
||||||
from transformers.modeling_outputs import BaseModelOutputWithPast
|
from transformers.modeling_outputs import BaseModelOutputWithPast
|
||||||
|
|
||||||
apply_rotary_emb_func = None
|
|
||||||
|
|
||||||
flash_attn_unpadded_func = None
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
import os
|
|
||||||
|
|
||||||
KV_CACHE_ALLOC_BLOCK_LENGTH = int(os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256))
|
|
||||||
SUPPORT_TORCH2 = hasattr(torch, '__version__') and int(torch.__version__.split(".")[0]) >= 2
|
|
||||||
|
|
||||||
|
|
||||||
def apply_rotary_pos_emb(t, freqs):
|
def apply_rotary_pos_emb(t, freqs):
|
||||||
cos, sin = freqs
|
cos, sin = freqs
|
||||||
|
|
@ -71,56 +52,7 @@ def apply_rotary_pos_emb(t, freqs):
|
||||||
return torch.cat((t_, t_pass_), dim=-1).type_as(t)
|
return torch.cat((t_, t_pass_), dim=-1).type_as(t)
|
||||||
|
|
||||||
|
|
||||||
def should_use_fuse_rope(self, query_states):
|
|
||||||
use_fuse_rope = query_states.device.type == "xpu"
|
|
||||||
use_fuse_rope = use_fuse_rope and not (self.training and query_states.requires_grad)
|
|
||||||
return use_fuse_rope
|
|
||||||
|
|
||||||
|
|
||||||
def is_enough_kv_cache_room(layer_past, kv_seq_len=1):
|
|
||||||
# to determinate if is enough kv cache room in transformers between 4.31 and 4.35
|
|
||||||
# seq_len for current seq len
|
|
||||||
# For llama like kv cache, i.e., [bs, n_head, seq_len, head_dim]
|
|
||||||
if layer_past is None:
|
|
||||||
return False
|
|
||||||
else:
|
|
||||||
cache_k, cache_v = layer_past[0], layer_past[1]
|
|
||||||
cache_k = cache_k.transpose(1, 2)
|
|
||||||
cache_v = cache_v.transpose(1, 2)
|
|
||||||
return cache_k.stride(1) < (kv_seq_len + 1) * cache_k.size(3)
|
|
||||||
|
|
||||||
|
|
||||||
def qwen_attention_forward(
|
def qwen_attention_forward(
|
||||||
self,
|
|
||||||
hidden_states: Optional[Tuple[torch.FloatTensor]],
|
|
||||||
rotary_pos_emb_list: Optional[List[torch.Tensor]] = None,
|
|
||||||
layer_past: Optional[Tuple[torch.Tensor]] = None,
|
|
||||||
attention_mask: Optional[torch.FloatTensor] = None,
|
|
||||||
head_mask: Optional[torch.FloatTensor] = None,
|
|
||||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
|
||||||
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
|
||||||
output_attentions: Optional[bool] = False,
|
|
||||||
use_cache: Optional[bool] = False,
|
|
||||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
|
||||||
if use_quantize_kv_cache(self.q_proj, hidden_states):
|
|
||||||
forward_function = qwen_attention_forward_quantized
|
|
||||||
else:
|
|
||||||
forward_function = qwen_attention_forward_original
|
|
||||||
return forward_function(
|
|
||||||
self,
|
|
||||||
hidden_states,
|
|
||||||
rotary_pos_emb_list,
|
|
||||||
layer_past,
|
|
||||||
attention_mask,
|
|
||||||
head_mask,
|
|
||||||
encoder_hidden_states,
|
|
||||||
encoder_attention_mask,
|
|
||||||
output_attentions,
|
|
||||||
use_cache,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def qwen_attention_forward_original(
|
|
||||||
self,
|
self,
|
||||||
hidden_states: Optional[Tuple[torch.FloatTensor]],
|
hidden_states: Optional[Tuple[torch.FloatTensor]],
|
||||||
rotary_pos_emb_list: Optional[List[torch.Tensor]] = None,
|
rotary_pos_emb_list: Optional[List[torch.Tensor]] = None,
|
||||||
|
|
@ -131,400 +63,121 @@ def qwen_attention_forward_original(
|
||||||
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
output_attentions: Optional[bool] = False,
|
output_attentions: Optional[bool] = False,
|
||||||
use_cache: Optional[bool] = False,
|
use_cache: Optional[bool] = False,
|
||||||
):
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||||
invalidInputError(not self.use_flash_attn and not self.use_cache_quantization,
|
invalidInputError(not self.use_flash_attn and not self.use_cache_quantization,
|
||||||
"flash attn and kv_cache quantization are not supported")
|
"flash attn and kv_cache quantization are not supported")
|
||||||
bsz, q_len, _ = hidden_states.size()
|
bsz, q_len, _ = hidden_states.size()
|
||||||
device = hidden_states.device
|
device = hidden_states.device
|
||||||
# for flash attention
|
past_key_value = (None if layer_past is None
|
||||||
original_dtype = hidden_states.dtype
|
else (layer_past[0].transpose(1, 2), layer_past[1].transpose(1, 2)))
|
||||||
|
|
||||||
|
qkv = self.c_attn(hidden_states)
|
||||||
|
qkv = qkv.view(bsz, q_len, self.num_heads * 3, self.head_dim)
|
||||||
|
qkv = qkv.transpose(1, 2)
|
||||||
|
query_states, key_states, value_states = qkv.split([self.num_heads,
|
||||||
|
self.num_heads,
|
||||||
|
self.num_heads], dim=1)
|
||||||
|
|
||||||
|
kv_seq_len = key_states.shape[2]
|
||||||
|
if past_key_value is not None:
|
||||||
|
kv_seq_len += past_key_value[0].shape[2]
|
||||||
|
|
||||||
|
# IPEX-LLM OPT: fuse rope
|
||||||
position_ids = rotary_pos_emb_list[-1] # the last one is posisiton_ids
|
position_ids = rotary_pos_emb_list[-1] # the last one is posisiton_ids
|
||||||
rotary_pos_emb_list = rotary_pos_emb_list[:-1]
|
inv_freq = rotary_pos_emb_list[-2]
|
||||||
|
rotary_pos_emb_list = rotary_pos_emb_list[:-2]
|
||||||
use_fuse_rope = should_use_fuse_rope(self, hidden_states)
|
invalidInputError(len(rotary_pos_emb_list) == 1,
|
||||||
decoding_fast_path = use_decoding_fast_path(self.q_proj,
|
"rotary_pos_emb_list's length cannot be larger than 1")
|
||||||
use_fuse_rope,
|
use_fuse_rope = should_use_fuse_rope(hidden_states, position_ids, self.training)
|
||||||
True,
|
rotary_pos_emb = rotary_pos_emb_list[0]
|
||||||
bsz * q_len)
|
if use_fuse_rope:
|
||||||
if decoding_fast_path:
|
rot_dim = rotary_pos_emb[0].size(-1)
|
||||||
hidden_states = hidden_states.view(1, -1)
|
|
||||||
cache_k, cache_v = layer_past[0], layer_past[1]
|
|
||||||
cache_k = cache_k.transpose(1, 2)
|
|
||||||
cache_v = cache_v.transpose(1, 2)
|
|
||||||
|
|
||||||
kv_seq_len = cache_k.shape[-2]
|
|
||||||
base = self.rope_base
|
|
||||||
if is_enough_kv_cache_room(layer_past, kv_seq_len):
|
|
||||||
new_cache_k, new_cache_v = extend_kv_cache(bsz,
|
|
||||||
self.num_heads,
|
|
||||||
self.head_dim,
|
|
||||||
cache_k.size(2),
|
|
||||||
kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH,
|
|
||||||
dtype=cache_k.dtype,
|
|
||||||
device=hidden_states.device)
|
|
||||||
new_cache_k[:] = cache_k
|
|
||||||
new_cache_v[:] = cache_v
|
|
||||||
cache_k = new_cache_k
|
|
||||||
cache_v = new_cache_v
|
|
||||||
|
|
||||||
args = [hidden_states, self.q_proj.weight.data, self.k_proj.weight.data,
|
|
||||||
self.v_proj.weight.data, self.q_proj.bias.data, self.k_proj.bias.data,
|
|
||||||
self.v_proj.bias.data, position_ids, cache_k, cache_v, self.q_proj.weight.qtype,
|
|
||||||
self.v_proj.weight.qtype, kv_seq_len, self.head_dim, base]
|
|
||||||
import linear_q4_0
|
import linear_q4_0
|
||||||
query, key, value = linear_q4_0.forward_qkv_bias(*args)
|
linear_q4_0.rotary_half_inplaced(inv_freq, position_ids,
|
||||||
kv_seq_len += 1
|
query_states[..., :rot_dim], key_states[..., :rot_dim])
|
||||||
query_size, key_size = 1, 1
|
|
||||||
else:
|
else:
|
||||||
query = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)
|
rotary_pos_emb = [i[:, -q_len:, :, :].transpose(1, 2) for i in rotary_pos_emb]
|
||||||
key = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)
|
query_states = apply_rotary_pos_emb(query_states, rotary_pos_emb)
|
||||||
value = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)
|
key_states = apply_rotary_pos_emb(key_states, rotary_pos_emb)
|
||||||
# TODO: speed up
|
|
||||||
# mixed_x_layer = self.c_attn(hidden_states)
|
|
||||||
# query, key, value = mixed_x_layer.split(self.split_size, dim=2)
|
|
||||||
|
|
||||||
# query = self._split_heads(query, self.num_heads, self.head_dim)
|
|
||||||
# key = self._split_heads(key, self.num_heads, self.head_dim)
|
|
||||||
# value = self._split_heads(value, self.num_heads, self.head_dim)
|
|
||||||
if len(rotary_pos_emb_list) != 0:
|
|
||||||
cur_len = query.shape[1]
|
|
||||||
if len(rotary_pos_emb_list) == 1:
|
|
||||||
rotary_pos_emb = rotary_pos_emb_list[0]
|
|
||||||
rotary_pos_emb = [i[:, -cur_len:, :, :] for i in rotary_pos_emb]
|
|
||||||
if use_fuse_rope:
|
|
||||||
cos, sin = rotary_pos_emb
|
|
||||||
cos = cos.to(query.dtype)
|
|
||||||
sin = sin.to(query.dtype)
|
|
||||||
query, key = apply_rotary_pos_emb_cache_freq_xpu(query, key, sin, cos, "qwen")
|
|
||||||
else:
|
|
||||||
rotary_pos_emb = (rotary_pos_emb,) * 2
|
|
||||||
q_pos_emb, k_pos_emb = rotary_pos_emb
|
|
||||||
# Slice the pos emb for current inference
|
|
||||||
query = apply_rotary_pos_emb(query, q_pos_emb)
|
|
||||||
key = apply_rotary_pos_emb(key, k_pos_emb)
|
|
||||||
else:
|
|
||||||
query_list = []
|
|
||||||
key_list = []
|
|
||||||
for i, rotary_pos_emb in enumerate(rotary_pos_emb_list):
|
|
||||||
rotary_pos_emb = [i[:, -cur_len:, :, :] for i in rotary_pos_emb]
|
|
||||||
if use_fuse_rope:
|
|
||||||
cos, sin = rotary_pos_emb
|
|
||||||
cos = cos.to(query.dtype)
|
|
||||||
sin = sin.to(query.dtype)
|
|
||||||
query, key = apply_rotary_pos_emb_cache_freq_xpu(query, key,
|
|
||||||
sin, cos, "qwen")
|
|
||||||
query_list += [query]
|
|
||||||
key_list += [key]
|
|
||||||
else:
|
|
||||||
rotary_pos_emb = (rotary_pos_emb,) * 2
|
|
||||||
q_pos_emb, k_pos_emb = rotary_pos_emb
|
|
||||||
# Slice the pos emb for current inference
|
|
||||||
query_list += [apply_rotary_pos_emb(query[i:i+1, :, :], q_pos_emb)]
|
|
||||||
key_list += [apply_rotary_pos_emb(key[i:i+1, :, :], k_pos_emb)]
|
|
||||||
query = torch.cat(query_list, dim=0)
|
|
||||||
key = torch.cat(key_list, dim=0)
|
|
||||||
query_size, key_size = query.size(1), key.size(1)
|
|
||||||
kv_seq_len = key_size if layer_past is None else key_size + layer_past[0].size(1)
|
|
||||||
|
|
||||||
if kv_seq_len > self.seq_length and self.use_logn_attn and not self.training:
|
if kv_seq_len > self.seq_length and self.use_logn_attn and not self.training:
|
||||||
seq_start = kv_seq_len - query_size
|
seq_start = kv_seq_len - q_len
|
||||||
seq_end = kv_seq_len
|
seq_end = kv_seq_len
|
||||||
logn_tensor = self.logn_tensor[:, seq_start:seq_end, :, :].type_as(query)
|
logn_tensor = self.logn_tensor[:, seq_start:seq_end, :, :].transpose(1, 2)
|
||||||
query = query * logn_tensor.expand_as(query)
|
query_states = query_states * logn_tensor.type_as(query_states).expand_as(query_states)
|
||||||
|
|
||||||
if query_size > 1:
|
# IPEX-LLM OPT: kv cache and quantzie kv cache
|
||||||
causal_mask = torch.tril(
|
use_quantize_kv = use_quantize_kv_cache(self.c_attn, hidden_states)
|
||||||
torch.ones((kv_seq_len, kv_seq_len), dtype=torch.bool, device=query.device)
|
key_states, value_states = update_past_key_value(
|
||||||
).view(1, 1, kv_seq_len, kv_seq_len)
|
past_key_value, key_states, value_states,
|
||||||
causal_mask = causal_mask[
|
kv_seq_len, use_quantize_kv, device
|
||||||
:, :, kv_seq_len - query_size:kv_seq_len, :kv_seq_len
|
)
|
||||||
]
|
past_key_value = (key_states.transpose(1, 2),
|
||||||
else:
|
value_states.transpose(1, 2)) if use_cache else None
|
||||||
causal_mask = None
|
|
||||||
|
|
||||||
if layer_past is not None:
|
|
||||||
if not decoding_fast_path:
|
|
||||||
cache_k, cache_v = layer_past[0], layer_past[1]
|
|
||||||
cache_k = cache_k.transpose(1, 2)
|
|
||||||
cache_v = cache_v.transpose(1, 2)
|
|
||||||
if cache_k.stride(1) < kv_seq_len * cache_k.size(3):
|
|
||||||
new_cache_k, new_cache_v = extend_kv_cache(bsz,
|
|
||||||
self.num_heads,
|
|
||||||
self.head_dim,
|
|
||||||
cache_k.size(2),
|
|
||||||
kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH,
|
|
||||||
dtype=cache_k.dtype,
|
|
||||||
device=hidden_states.device)
|
|
||||||
new_cache_k[:] = cache_k
|
|
||||||
new_cache_v[:] = cache_v
|
|
||||||
cache_k = new_cache_k
|
|
||||||
cache_v = new_cache_v
|
|
||||||
key_states, value_states = append_kv_cache(cache_k, cache_v,
|
|
||||||
key.transpose(1, 2), value.transpose(1, 2))
|
|
||||||
key = key_states
|
|
||||||
value = value_states
|
|
||||||
elif use_cache:
|
|
||||||
max_cache_length = kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH
|
|
||||||
new_key_states, new_value_states = init_kv_cache(bsz,
|
|
||||||
self.num_heads,
|
|
||||||
self.head_dim,
|
|
||||||
kv_seq_len,
|
|
||||||
max_cache_length,
|
|
||||||
dtype=key.dtype,
|
|
||||||
device=hidden_states.device)
|
|
||||||
new_key_states[:] = key.transpose(1, 2)
|
|
||||||
new_value_states[:] = value.transpose(1, 2)
|
|
||||||
key = new_key_states
|
|
||||||
value = new_value_states
|
|
||||||
|
|
||||||
if not decoding_fast_path:
|
|
||||||
query = query.transpose(1, 2)
|
|
||||||
|
|
||||||
|
# IPEX-LLM OPT: sdp
|
||||||
|
attn_weights = None
|
||||||
if not self.training and not hidden_states.requires_grad and \
|
if not self.training and not hidden_states.requires_grad and \
|
||||||
use_flash_attention(query, key):
|
use_flash_attention(query_states, key_states, attention_mask):
|
||||||
attn_output = F.scaled_dot_product_attention(query.to(device, dtype=torch.float16),
|
attn_output = F.scaled_dot_product_attention(query_states.to(dtype=torch.float16),
|
||||||
key.to(device, dtype=torch.float16),
|
key_states.to(dtype=torch.float16),
|
||||||
value.to(device, dtype=torch.float16),
|
value_states.to(dtype=torch.float16),
|
||||||
is_causal=True)
|
is_causal=True).to(hidden_states.dtype)
|
||||||
attn_output = attn_output.view(query.shape)
|
elif use_sdp_causal(q_len, kv_seq_len, self.head_dim, query_states, self.training):
|
||||||
attn_output = attn_output.transpose(1, 2)
|
|
||||||
attn_weights = None
|
|
||||||
elif not self.training and not hidden_states.requires_grad and \
|
|
||||||
use_sdp(q_len, key.shape[2], self.head_dim, query):
|
|
||||||
import linear_q4_0
|
import linear_q4_0
|
||||||
attn_output = linear_q4_0.sdp(query, key, value, attention_mask)
|
if use_quantize_kv:
|
||||||
attn_output = attn_output.view(query.shape)
|
attn_output = linear_q4_0.sdp_fp8_causal(query_states, key_states, value_states)
|
||||||
attn_output = attn_output.transpose(1, 2)
|
|
||||||
attn_weight = None
|
|
||||||
else:
|
|
||||||
attn_output, attn_weight = self._attn(
|
|
||||||
query.to(key.dtype), key, value, causal_mask, attention_mask, head_mask
|
|
||||||
)
|
|
||||||
|
|
||||||
context_layer = self._merge_heads(
|
|
||||||
attn_output, self.num_heads, self.head_dim
|
|
||||||
)
|
|
||||||
|
|
||||||
attn_output = self.c_proj(context_layer).to(original_dtype)
|
|
||||||
|
|
||||||
if use_cache:
|
|
||||||
outputs = (attn_output, (key.transpose(1, 2), value.transpose(1, 2)))
|
|
||||||
else:
|
|
||||||
outputs = (attn_output, None)
|
|
||||||
if output_attentions:
|
|
||||||
outputs += (attn_weight,)
|
|
||||||
|
|
||||||
return outputs
|
|
||||||
|
|
||||||
|
|
||||||
def qwen_attention_forward_quantized(
|
|
||||||
self,
|
|
||||||
hidden_states: Optional[Tuple[torch.FloatTensor]],
|
|
||||||
rotary_pos_emb_list: Optional[List[torch.Tensor]] = None,
|
|
||||||
layer_past: Optional[Tuple[torch.Tensor]] = None,
|
|
||||||
attention_mask: Optional[torch.FloatTensor] = None,
|
|
||||||
head_mask: Optional[torch.FloatTensor] = None,
|
|
||||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
|
||||||
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
|
||||||
output_attentions: Optional[bool] = False,
|
|
||||||
use_cache: Optional[bool] = False,
|
|
||||||
):
|
|
||||||
invalidInputError(not self.use_flash_attn and not self.use_cache_quantization,
|
|
||||||
"flash attn and kv_cache quantization are not supported")
|
|
||||||
|
|
||||||
bsz, q_len, _ = hidden_states.size()
|
|
||||||
device = hidden_states.device
|
|
||||||
position_ids = rotary_pos_emb_list[-1] # the last one is posisiton_ids
|
|
||||||
rotary_pos_emb_list = rotary_pos_emb_list[:-1]
|
|
||||||
|
|
||||||
use_fuse_rope = should_use_fuse_rope(self, hidden_states)
|
|
||||||
# qtype_check = decoding_fast_path_qtype_check(self.q_proj)
|
|
||||||
# TODO: use when decoding_fast_path = (qtype_check and use_fuse_rope and bsz * q_len == 1)
|
|
||||||
decoding_fast_path = False
|
|
||||||
if decoding_fast_path:
|
|
||||||
hidden_states = hidden_states.view(1, -1)
|
|
||||||
tmp_cache_k, tmp_cache_v = init_kv_cache(
|
|
||||||
bsz,
|
|
||||||
self.num_heads,
|
|
||||||
self.head_dim,
|
|
||||||
0,
|
|
||||||
1,
|
|
||||||
dtype=hidden_states.dtype,
|
|
||||||
device=device
|
|
||||||
)
|
|
||||||
|
|
||||||
base = self.rope_base
|
|
||||||
|
|
||||||
args = [hidden_states, self.q_proj.weight.data, self.k_proj.weight.data,
|
|
||||||
self.v_proj.weight.data, self.q_proj.bias.data, self.k_proj.bias.data,
|
|
||||||
self.v_proj.bias.data, position_ids, tmp_cache_k, tmp_cache_v,
|
|
||||||
self.q_proj.weight.qtype, self.v_proj.weight.qtype, 0, self.head_dim, base]
|
|
||||||
import linear_q4_0
|
|
||||||
query, key, value = linear_q4_0.forward_qkv_bias(*args)
|
|
||||||
self.kv_seq_len += 1
|
|
||||||
kv_seq_len = self.kv_seq_len
|
|
||||||
query_size, key_size = 1, 1
|
|
||||||
else:
|
|
||||||
query = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)
|
|
||||||
key = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)
|
|
||||||
value = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)
|
|
||||||
# TODO: speed up
|
|
||||||
# mixed_x_layer = self.c_attn(hidden_states)
|
|
||||||
# query, key, value = mixed_x_layer.split(self.split_size, dim=2)
|
|
||||||
|
|
||||||
# query = self._split_heads(query, self.num_heads, self.head_dim)
|
|
||||||
# key = self._split_heads(key, self.num_heads, self.head_dim)
|
|
||||||
# value = self._split_heads(value, self.num_heads, self.head_dim)
|
|
||||||
if rotary_pos_emb_list is not None:
|
|
||||||
cur_len = query.shape[1]
|
|
||||||
if len(rotary_pos_emb_list) == 1:
|
|
||||||
rotary_pos_emb = rotary_pos_emb_list[0]
|
|
||||||
rotary_pos_emb = [i[:, -cur_len:, :, :] for i in rotary_pos_emb]
|
|
||||||
if use_fuse_rope:
|
|
||||||
cos, sin = rotary_pos_emb
|
|
||||||
cos = cos.to(query.dtype)
|
|
||||||
sin = sin.to(query.dtype)
|
|
||||||
query, key = apply_rotary_pos_emb_cache_freq_xpu(query, key, sin, cos, "qwen")
|
|
||||||
else:
|
|
||||||
rotary_pos_emb = (rotary_pos_emb,) * 2
|
|
||||||
q_pos_emb, k_pos_emb = rotary_pos_emb
|
|
||||||
# Slice the pos emb for current inference
|
|
||||||
query = apply_rotary_pos_emb(query, q_pos_emb)
|
|
||||||
key = apply_rotary_pos_emb(key, k_pos_emb)
|
|
||||||
else:
|
|
||||||
query_list = []
|
|
||||||
key_list = []
|
|
||||||
for i, rotary_pos_emb in enumerate(rotary_pos_emb_list):
|
|
||||||
rotary_pos_emb = [i[:, -cur_len:, :, :] for i in rotary_pos_emb]
|
|
||||||
if use_fuse_rope:
|
|
||||||
cos, sin = rotary_pos_emb
|
|
||||||
cos = cos.to(query.dtype)
|
|
||||||
sin = sin.to(query.dtype)
|
|
||||||
query, key = apply_rotary_pos_emb_cache_freq_xpu(query, key,
|
|
||||||
sin, cos, "qwen")
|
|
||||||
query_list += [query]
|
|
||||||
key_list += [key]
|
|
||||||
else:
|
|
||||||
rotary_pos_emb = (rotary_pos_emb,) * 2
|
|
||||||
q_pos_emb, k_pos_emb = rotary_pos_emb
|
|
||||||
# Slice the pos emb for current inference
|
|
||||||
query_list += [apply_rotary_pos_emb(query[i:i+1, :, :], q_pos_emb)]
|
|
||||||
key_list += [apply_rotary_pos_emb(key[i:i+1, :, :], k_pos_emb)]
|
|
||||||
query = torch.cat(query_list, dim=0)
|
|
||||||
key = torch.cat(key_list, dim=0)
|
|
||||||
query_size, key_size = query.size(1), key.size(1)
|
|
||||||
kv_seq_len = key_size if layer_past is None else key_size + layer_past[0].size(1)
|
|
||||||
|
|
||||||
if kv_seq_len > self.seq_length and self.use_logn_attn and not self.training:
|
|
||||||
seq_start = kv_seq_len - query_size
|
|
||||||
seq_end = kv_seq_len
|
|
||||||
logn_tensor = self.logn_tensor[:, seq_start:seq_end, :, :].type_as(query)
|
|
||||||
query = query * logn_tensor.expand_as(query)
|
|
||||||
|
|
||||||
if query_size > 1:
|
|
||||||
causal_mask = torch.tril(
|
|
||||||
torch.ones((kv_seq_len, kv_seq_len), dtype=torch.bool, device=query.device)
|
|
||||||
).view(1, 1, kv_seq_len, kv_seq_len)
|
|
||||||
causal_mask = causal_mask[
|
|
||||||
:, :, kv_seq_len - query_size:kv_seq_len, :kv_seq_len
|
|
||||||
]
|
|
||||||
else:
|
|
||||||
causal_mask = None
|
|
||||||
|
|
||||||
if layer_past is None:
|
|
||||||
query, key, value = query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2)
|
|
||||||
# query, key, value's shape: [bs, num_heads, seq_len, head_dim]
|
|
||||||
|
|
||||||
# save kv seq len for decoding_fast_path
|
|
||||||
self.kv_seq_len = key.shape[-2]
|
|
||||||
# For first token, use original attn
|
|
||||||
attn_output, attn_weight = self._attn(
|
|
||||||
query, key, value, causal_mask, attention_mask, head_mask
|
|
||||||
)
|
|
||||||
if use_cache:
|
|
||||||
max_cache_length = kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH
|
|
||||||
k_cache, v_cache = init_fp8_kv_cache(
|
|
||||||
query.size(0), self.num_heads, kv_seq_len, self.head_dim,
|
|
||||||
device=query.device
|
|
||||||
)
|
|
||||||
key, value = append_fp8_kv_cache(k_cache, v_cache, key, value)
|
|
||||||
else:
|
|
||||||
if decoding_fast_path:
|
|
||||||
k_cache, v_cache = layer_past[0], layer_past[1]
|
|
||||||
# k_cache and v_cache's shape: [bs, num_heads, context_length, head_dim]
|
|
||||||
else:
|
else:
|
||||||
query, key, value = query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2)
|
attn_output = linear_q4_0.sdp_causal(query_states, key_states, value_states)
|
||||||
k_cache, v_cache = layer_past[0], layer_past[1]
|
|
||||||
|
|
||||||
k_cache = k_cache.transpose(1, 2)
|
|
||||||
v_cache = v_cache.transpose(1, 2)
|
|
||||||
# k_cache and v_cache's shape: [bs, num_heads, context_length, head_dim]
|
|
||||||
|
|
||||||
key, value = append_fp8_kv_cache(k_cache, v_cache, key, value)
|
|
||||||
|
|
||||||
attn_output, attn_weight = core_attn(
|
|
||||||
self, query, key, value, causal_mask, attention_mask, head_mask
|
|
||||||
)
|
|
||||||
|
|
||||||
context_layer = self._merge_heads(
|
|
||||||
attn_output, self.num_heads, self.head_dim
|
|
||||||
)
|
|
||||||
|
|
||||||
attn_output = self.c_proj(context_layer)
|
|
||||||
|
|
||||||
if use_cache:
|
|
||||||
outputs = (attn_output, (key.transpose(1, 2), value.transpose(1, 2)))
|
|
||||||
else:
|
else:
|
||||||
outputs = (attn_output, None)
|
if q_len > 1:
|
||||||
if output_attentions:
|
causal_mask = torch.tril(
|
||||||
outputs += (attn_weight,)
|
torch.ones((kv_seq_len, kv_seq_len), dtype=torch.bool, device=query_states.device)
|
||||||
|
).view(1, 1, kv_seq_len, kv_seq_len)
|
||||||
return outputs
|
causal_mask = causal_mask[
|
||||||
|
:, :, kv_seq_len - q_len:kv_seq_len, :kv_seq_len
|
||||||
|
]
|
||||||
def core_attn(self, query, key, value, causal_mask=None, attention_mask=None, head_mask=None):
|
attention_mask = torch.zeros(causal_mask.shape, dtype=query_states.dtype,
|
||||||
if not use_sdp_fp8(query.size(2), key.size(2), query):
|
device=query_states.device)
|
||||||
# We have no CPU fp8 matmul implementation for now, so just upscale to fp32
|
attention_mask.masked_fill_(causal_mask.logical_not(),
|
||||||
key, value = restore_fp8_kv_cache(key, value, query.dtype)
|
torch.finfo(attention_mask.dtype).min)
|
||||||
attn_weights = torch.matmul(query, key.transpose(-1, -2))
|
attention_mask = attention_mask.expand([bsz, -1, -1, -1])
|
||||||
|
|
||||||
if self.scale_attn_weights:
|
|
||||||
if self.use_cache_quantization:
|
|
||||||
size_temp = value[0].size(-1)
|
|
||||||
else:
|
|
||||||
size_temp = value.size(-1)
|
|
||||||
attn_weights = attn_weights / (size_temp ** 0.5)
|
|
||||||
|
|
||||||
mask_value = torch.finfo(attn_weights.dtype).min
|
|
||||||
if causal_mask is not None:
|
|
||||||
attn_weights = torch.where(
|
|
||||||
causal_mask, attn_weights.to(attn_weights.dtype), mask_value
|
|
||||||
)
|
|
||||||
|
|
||||||
if attention_mask is not None:
|
|
||||||
attn_weights = attn_weights + attention_mask
|
|
||||||
|
|
||||||
if self.softmax_in_fp32:
|
|
||||||
attn_weights = torch.nn.functional.softmax(attn_weights.float(), dim=-1)
|
|
||||||
else:
|
else:
|
||||||
attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
|
attention_mask = None
|
||||||
|
|
||||||
attn_weights = attn_weights.type(query.dtype)
|
if use_sdp(q_len, kv_seq_len, self.head_dim, query_states):
|
||||||
attn_weights = self.attn_dropout(attn_weights)
|
import linear_q4_0
|
||||||
|
if use_quantize_kv:
|
||||||
|
attn_output = linear_q4_0.sdp_fp8(query_states, key_states, value_states,
|
||||||
|
attention_mask)
|
||||||
|
else:
|
||||||
|
attn_output = linear_q4_0.sdp(query_states, key_states, value_states,
|
||||||
|
attention_mask)
|
||||||
|
else:
|
||||||
|
if use_quantize_kv:
|
||||||
|
key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
|
||||||
|
query_states.dtype)
|
||||||
|
attn_weights = torch.matmul(query_states,
|
||||||
|
key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
|
||||||
|
if attention_mask is not None:
|
||||||
|
attn_weights = attn_weights + attention_mask
|
||||||
|
if self.softmax_in_fp32:
|
||||||
|
attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1,
|
||||||
|
dtype=torch.float32).to(
|
||||||
|
value_states.dtype)
|
||||||
|
else:
|
||||||
|
attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
|
||||||
|
attn_output = torch.matmul(attn_weights, value_states)
|
||||||
|
|
||||||
if head_mask is not None:
|
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||||
attn_weights = attn_weights * head_mask
|
attn_output = attn_output.view(bsz, q_len, self.hidden_size)
|
||||||
|
|
||||||
# We have no CPU fp8 matmul implementation for now, so just upscale to fp32
|
attn_output = self.c_proj(attn_output)
|
||||||
attn_output = torch.matmul(attn_weights, value)
|
|
||||||
|
if output_attentions:
|
||||||
|
return attn_output, past_key_value, attn_weights
|
||||||
else:
|
else:
|
||||||
import linear_q4_0
|
return attn_output, past_key_value
|
||||||
attn_output = linear_q4_0.sdp_fp8(query, key, value,
|
|
||||||
attention_mask)
|
|
||||||
attn_weights = None
|
|
||||||
attn_output = attn_output.transpose(1, 2)
|
|
||||||
|
|
||||||
return attn_output, attn_weights
|
|
||||||
|
|
||||||
|
|
||||||
def qwen_mlp_forward(self, x: torch.Tensor) -> torch.Tensor:
|
def qwen_mlp_forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
|
@ -652,9 +305,11 @@ def qwen_model_forward(
|
||||||
ntk_alpha = self.get_ntk_alpha(kv_seq_len)
|
ntk_alpha = self.get_ntk_alpha(kv_seq_len)
|
||||||
ntk_alpha_list.append(ntk_alpha)
|
ntk_alpha_list.append(ntk_alpha)
|
||||||
self.rotary_emb._ntk_alpha_cached_list = ntk_alpha_list
|
self.rotary_emb._ntk_alpha_cached_list = ntk_alpha_list
|
||||||
|
# ipex-llm changes
|
||||||
rotary_pos_emb_list = [
|
rotary_pos_emb_list = [
|
||||||
self.rotary_emb(kv_seq_len, ntk_alpha=ntk_alpha) for ntk_alpha in ntk_alpha_list
|
self.rotary_emb(kv_seq_len, ntk_alpha=ntk_alpha) for ntk_alpha in ntk_alpha_list
|
||||||
] + [position_ids]
|
] + [self.rotary_emb.inv_freq.to(self.dtype), position_ids]
|
||||||
|
# ipex-llm changes ends
|
||||||
|
|
||||||
hidden_states = self.drop(hidden_states)
|
hidden_states = self.drop(hidden_states)
|
||||||
output_shape = input_shape + (hidden_states.size(-1),)
|
output_shape = input_shape + (hidden_states.size(-1),)
|
||||||
|
|
@ -695,7 +350,7 @@ def qwen_model_forward(
|
||||||
encoder_attention_mask,
|
encoder_attention_mask,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# bigdl-llm changes
|
# ipex-llm changes
|
||||||
curr_device = block.ln_1.weight.device
|
curr_device = block.ln_1.weight.device
|
||||||
from accelerate.utils.operations import send_to_device
|
from accelerate.utils.operations import send_to_device
|
||||||
if rotary_pos_emb_list is not None:
|
if rotary_pos_emb_list is not None:
|
||||||
|
|
@ -709,7 +364,7 @@ def qwen_model_forward(
|
||||||
if encoder_attention_mask is not None:
|
if encoder_attention_mask is not None:
|
||||||
encoder_attention_mask = send_to_device(encoder_attention_mask,
|
encoder_attention_mask = send_to_device(encoder_attention_mask,
|
||||||
curr_device)
|
curr_device)
|
||||||
# bigdl-llm changes ends
|
# ipex-llm changes ends
|
||||||
|
|
||||||
outputs = block(
|
outputs = block(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
|
|
|
||||||
|
|
@ -188,5 +188,5 @@ class Test_Optimize_Gpu_Model:
|
||||||
# currently only need to compare the output of one self-attention layer.
|
# currently only need to compare the output of one self-attention layer.
|
||||||
layer_norm = "transformer.h.31.ln_1"
|
layer_norm = "transformer.h.31.ln_1"
|
||||||
self_attn = "transformer.h.31.attn"
|
self_attn = "transformer.h.31.attn"
|
||||||
lower_bound = 8e-3
|
lower_bound = 2e-2
|
||||||
self.run_optimize_gpu_model(Name, Model, Tokenizer, model_path, self_attn, layer_norm, lower_bound)
|
self.run_optimize_gpu_model(Name, Model, Tokenizer, model_path, self_attn, layer_norm, lower_bound)
|
||||||
Loading…
Reference in a new issue