Qwen fused qkv (#10368)

* fused qkv + rope for qwen

* quantized kv cache

* fix

* update qwen

* fixed quantized qkv

* fix

* meet code review

* update split

* convert.py

* extend when no enough kv

* fix
This commit is contained in:
Xin Qiu 2024-03-12 17:39:00 +08:00 committed by GitHub
parent 741c2bf1df
commit 28c4a8cf5c
2 changed files with 346 additions and 84 deletions

View file

@ -587,6 +587,41 @@ def _optimize_pre(model):
):
from bigdl.llm.transformers.models.bert import merge_qkv
model.apply(merge_qkv)
if model.config.model_type == "qwen":
position_ids = torch.arange(0, model.config.max_position_embeddings)
rope_base = model.config.rotary_emb_base
from accelerate.big_modeling import init_empty_weights
def split_qkv_proj_func(module):
if "QWenAttention" in module.__class__.__name__:
c_attn_weight = module.c_attn.weight.data
c_attn_bias = module.c_attn.bias.data
projection_size = module.projection_size
hid_size = module.hidden_size
with init_empty_weights():
q_proj = torch.nn.Linear(hid_size, projection_size)
k_proj = torch.nn.Linear(hid_size, projection_size)
v_proj = torch.nn.Linear(hid_size, projection_size)
if not model.config.to_dict().get("bigdl_transformers_low_bit", False):
q_proj.weight = torch.nn.Parameter(
c_attn_weight[:projection_size, :], requires_grad=False)
q_proj.bias = torch.nn.Parameter(
c_attn_bias[:projection_size], requires_grad=False)
k_proj.weight = torch.nn.Parameter(
c_attn_weight[projection_size: 2 * projection_size, :], requires_grad=False)
k_proj.bias = torch.nn.Parameter(
c_attn_bias[projection_size: 2 * projection_size], requires_grad=False)
v_proj.weight = torch.nn.Parameter(
c_attn_weight[2 * projection_size:, :], requires_grad=False)
v_proj.bias = torch.nn.Parameter(
c_attn_bias[2 * projection_size:], requires_grad=False)
module.q_proj = q_proj
module.k_proj = k_proj
module.v_proj = v_proj
module.position_ids = position_ids
module.rope_base = rope_base
del module.c_attn
model.apply(split_qkv_proj_func)
return model

View file

@ -66,6 +66,25 @@ def apply_rotary_pos_emb(t, freqs):
return torch.cat((t_, t_pass_), dim=-1).type_as(t)
def should_use_fuse_rope(self, query_states):
use_fuse_rope = query_states.device.type == "xpu"
use_fuse_rope = use_fuse_rope and not (self.training and query_states.requires_grad)
return use_fuse_rope
def is_enough_kv_cache_room(layer_past, kv_seq_len=1):
# to determinate if is enough kv cache room in transformers between 4.31 and 4.35
# seq_len for current seq len
# For llama like kv cache, i.e., [bs, n_head, seq_len, head_dim]
if layer_past is None:
return False
else:
cache_k, cache_v = layer_past[0], layer_past[1]
cache_k = cache_k.transpose(1, 2)
cache_v = cache_v.transpose(1, 2)
return cache_k.stride(1) < (kv_seq_len + 1) * cache_k.size(3)
def qwen_attention_forward(
self,
hidden_states: Optional[Tuple[torch.FloatTensor]],
@ -77,20 +96,87 @@ def qwen_attention_forward(
encoder_attention_mask: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if use_quantize_kv_cache(self.q_proj, hidden_states):
forward_function = qwen_attention_forward_quantized
else:
forward_function = qwen_attention_forward_original
return forward_function(
self,
hidden_states,
rotary_pos_emb_list,
layer_past,
attention_mask,
head_mask,
encoder_hidden_states,
encoder_attention_mask,
output_attentions,
use_cache,
)
def qwen_attention_forward_original(
self,
hidden_states: Optional[Tuple[torch.FloatTensor]],
rotary_pos_emb_list: Optional[List[torch.Tensor]] = None,
layer_past: Optional[Tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
):
invalidInputError(not self.use_flash_attn and not self.use_cache_quantization,
"flash attn and kv_cache quantization are not supported")
bsz, q_len, _ = hidden_states.size()
device = hidden_states.device
mixed_x_layer = self.c_attn(hidden_states)
query, key, value = mixed_x_layer.split(self.split_size, dim=2)
use_fuse_rope = should_use_fuse_rope(self, hidden_states)
decoding_fast_path = (use_fuse_rope and bsz * q_len == 1)
if decoding_fast_path:
hidden_states = hidden_states.view(1, -1)
cache_k, cache_v = layer_past[0], layer_past[1]
cache_k = cache_k.transpose(1, 2)
cache_v = cache_v.transpose(1, 2)
query = self._split_heads(query, self.num_heads, self.head_dim)
key = self._split_heads(key, self.num_heads, self.head_dim)
value = self._split_heads(value, self.num_heads, self.head_dim)
# query, key, value's shape: [bs, seq_len, num_heads, head_dim]
kv_seq_len = cache_k.shape[-2]
self.position_ids = self.position_ids.to(device)
position_ids = self.position_ids[kv_seq_len]
base = self.rope_base
if is_enough_kv_cache_room(layer_past, kv_seq_len):
new_cache_k, new_cache_v = extend_kv_cache(bsz,
self.num_heads,
self.head_dim,
cache_k.size(2),
kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH,
dtype=cache_k.dtype,
device=hidden_states.device)
new_cache_k[:] = cache_k
new_cache_v[:] = cache_v
cache_k = new_cache_k
cache_v = new_cache_v
args = [hidden_states, self.q_proj.weight.data, self.k_proj.weight.data,
self.v_proj.weight.data, self.q_proj.bias.data, self.k_proj.bias.data,
self.v_proj.bias.data, position_ids, cache_k, cache_v, self.q_proj.weight.qtype,
self.v_proj.weight.qtype, kv_seq_len, self.head_dim, base]
import linear_q4_0
query, key, value = linear_q4_0.forward_qkv_bias(*args)
kv_seq_len += 1
query_size, key_size = 1, 1
else:
query = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)
key = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)
value = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)
# TODO: speed up
# mixed_x_layer = self.c_attn(hidden_states)
# query, key, value = mixed_x_layer.split(self.split_size, dim=2)
# query = self._split_heads(query, self.num_heads, self.head_dim)
# key = self._split_heads(key, self.num_heads, self.head_dim)
# value = self._split_heads(value, self.num_heads, self.head_dim)
if rotary_pos_emb_list is not None:
use_fuse_rope = query.device.type == "xpu" and not (self.training and query.requires_grad)
cur_len = query.shape[1]
if len(rotary_pos_emb_list) == 1:
rotary_pos_emb = rotary_pos_emb_list[0]
@ -115,7 +201,8 @@ def qwen_attention_forward(
cos, sin = rotary_pos_emb
cos = cos.to(query.dtype)
sin = sin.to(query.dtype)
query, key = apply_rotary_pos_emb_cache_freq_xpu(query, key, sin, cos, "qwen")
query, key = apply_rotary_pos_emb_cache_freq_xpu(query, key,
sin, cos, "qwen")
query_list += [query]
key_list += [key]
else:
@ -126,7 +213,6 @@ def qwen_attention_forward(
key_list += [apply_rotary_pos_emb(key[i:i+1, :, :], k_pos_emb)]
query = torch.cat(query_list, dim=0)
key = torch.cat(key_list, dim=0)
query_size, key_size = query.size(1), key.size(1)
kv_seq_len = key_size if layer_past is None else key_size + layer_past[0].size(1)
@ -146,42 +232,12 @@ def qwen_attention_forward(
else:
causal_mask = None
if use_quantize_kv_cache(self.c_attn, hidden_states):
query, key, value = query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2)
# query, key, value's shape: [bs, num_heads, seq_len, head_dim]
if layer_past is None:
# For first token, use original attn
attn_output, attn_weight = self._attn(
query, key, value, causal_mask, attention_mask, head_mask
)
if use_cache:
max_cache_length = kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH
k_cache, v_cache = init_fp8_kv_cache(
query.size(0), self.num_heads, kv_seq_len, self.head_dim,
device=query.device,
)
key, value = append_fp8_kv_cache(k_cache, v_cache, key, value)
else:
k_cache, v_cache = layer_past[0], layer_past[1]
k_cache = k_cache.transpose(1, 2)
v_cache = v_cache.transpose(1, 2)
# k_cache and v_cache's shape: [bs, num_heads, context_length, head_dim]
key, value = append_fp8_kv_cache(k_cache, v_cache, key, value)
attn_output, attn_weight = core_attn(
self, query, key, value, causal_mask, attention_mask, head_mask
)
else:
bsz = key.size(0)
if layer_past is not None:
if not decoding_fast_path:
cache_k, cache_v = layer_past[0], layer_past[1]
cache_k = cache_k.transpose(1, 2)
cache_v = cache_v.transpose(1, 2)
if cache_k.stride(1) < kv_seq_len * cache_k.size(3):
# allocate new
new_cache_k, new_cache_v = extend_kv_cache(bsz,
self.num_heads,
self.head_dim,
@ -193,7 +249,6 @@ def qwen_attention_forward(
new_cache_v[:] = cache_v
cache_k = new_cache_k
cache_v = new_cache_v
key_states, value_states = append_kv_cache(cache_k, cache_v,
key.transpose(1, 2), value.transpose(1, 2))
key = key_states
@ -212,6 +267,7 @@ def qwen_attention_forward(
key = new_key_states
value = new_value_states
if not decoding_fast_path:
query = query.transpose(1, 2)
attn_output, attn_weight = self._attn(
@ -234,6 +290,177 @@ def qwen_attention_forward(
return outputs
def qwen_attention_forward_quantized(
self,
hidden_states: Optional[Tuple[torch.FloatTensor]],
rotary_pos_emb_list: Optional[List[torch.Tensor]] = None,
layer_past: Optional[Tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
):
invalidInputError(not self.use_flash_attn and not self.use_cache_quantization,
"flash attn and kv_cache quantization are not supported")
bsz, q_len, _ = hidden_states.size()
device = hidden_states.device
use_fuse_rope = should_use_fuse_rope(self, hidden_states)
# TODO: use when decoding_fast_path = (use_fuse_rope and bsz * q_len == 1)
decoding_fast_path = False
if decoding_fast_path:
hidden_states = hidden_states.view(1, -1)
tmp_cache_k, tmp_cache_v = init_kv_cache(
bsz,
self.num_heads,
self.head_dim,
0,
1,
dtype=hidden_states.dtype,
device=device
)
position_ids = self.position_ids[self.kv_seq_len].to(device)
base = self.rope_base
args = [hidden_states, self.q_proj.weight.data, self.k_proj.weight.data,
self.v_proj.weight.data, self.q_proj.bias.data, self.k_proj.bias.data,
self.v_proj.bias.data, position_ids, tmp_cache_k, tmp_cache_v,
self.q_proj.weight.qtype, self.v_proj.weight.qtype, 0, self.head_dim, base]
import linear_q4_0
query, key, value = linear_q4_0.forward_qkv_bias(*args)
self.kv_seq_len += 1
kv_seq_len = self.kv_seq_len
query_size, key_size = 1, 1
else:
query = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)
key = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)
value = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)
# TODO: speed up
# mixed_x_layer = self.c_attn(hidden_states)
# query, key, value = mixed_x_layer.split(self.split_size, dim=2)
# query = self._split_heads(query, self.num_heads, self.head_dim)
# key = self._split_heads(key, self.num_heads, self.head_dim)
# value = self._split_heads(value, self.num_heads, self.head_dim)
if rotary_pos_emb_list is not None:
cur_len = query.shape[1]
if len(rotary_pos_emb_list) == 1:
rotary_pos_emb = rotary_pos_emb_list[0]
rotary_pos_emb = [i[:, -cur_len:, :, :] for i in rotary_pos_emb]
if use_fuse_rope:
cos, sin = rotary_pos_emb
cos = cos.to(query.dtype)
sin = sin.to(query.dtype)
query, key = apply_rotary_pos_emb_cache_freq_xpu(query, key, sin, cos, "qwen")
else:
rotary_pos_emb = (rotary_pos_emb,) * 2
q_pos_emb, k_pos_emb = rotary_pos_emb
# Slice the pos emb for current inference
query = apply_rotary_pos_emb(query, q_pos_emb)
key = apply_rotary_pos_emb(key, k_pos_emb)
else:
query_list = []
key_list = []
for i, rotary_pos_emb in enumerate(rotary_pos_emb_list):
rotary_pos_emb = [i[:, -cur_len:, :, :] for i in rotary_pos_emb]
if use_fuse_rope:
cos, sin = rotary_pos_emb
cos = cos.to(query.dtype)
sin = sin.to(query.dtype)
query, key = apply_rotary_pos_emb_cache_freq_xpu(query, key,
sin, cos, "qwen")
query_list += [query]
key_list += [key]
else:
rotary_pos_emb = (rotary_pos_emb,) * 2
q_pos_emb, k_pos_emb = rotary_pos_emb
# Slice the pos emb for current inference
query_list += [apply_rotary_pos_emb(query[i:i+1, :, :], q_pos_emb)]
key_list += [apply_rotary_pos_emb(key[i:i+1, :, :], k_pos_emb)]
query = torch.cat(query_list, dim=0)
key = torch.cat(key_list, dim=0)
query_size, key_size = query.size(1), key.size(1)
kv_seq_len = key_size if layer_past is None else key_size + layer_past[0].size(1)
if kv_seq_len > self.seq_length and self.use_logn_attn and not self.training:
seq_start = kv_seq_len - query_size
seq_end = kv_seq_len
logn_tensor = self.logn_tensor[:, seq_start:seq_end, :, :].type_as(query)
query = query * logn_tensor.expand_as(query)
if query_size > 1:
causal_mask = torch.tril(
torch.ones((kv_seq_len, kv_seq_len), dtype=torch.bool, device=query.device)
).view(1, 1, kv_seq_len, kv_seq_len)
causal_mask = causal_mask[
:, :, kv_seq_len - query_size:kv_seq_len, :kv_seq_len
]
else:
causal_mask = None
if layer_past is None:
query, key, value = query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2)
# query, key, value's shape: [bs, num_heads, seq_len, head_dim]
# save kv seq len for decoding_fast_path
self.kv_seq_len = key.shape[-2]
# For first token, use original attn
attn_output, attn_weight = self._attn(
query, key, value, causal_mask, attention_mask, head_mask
)
if use_cache:
max_cache_length = kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH
k_cache, v_cache = init_fp8_kv_cache(
query.size(0), self.num_heads, kv_seq_len, self.head_dim,
device=query.device,
)
key, value = append_fp8_kv_cache(k_cache, v_cache, key, value)
else:
if decoding_fast_path:
k_cache, v_cache = layer_past[0], layer_past[1]
k_cache = k_cache.transpose(1, 2)
v_cache = v_cache.transpose(1, 2)
# k_cache and v_cache's shape: [bs, num_heads, context_length, head_dim]
key, value = append_fp8_kv_cache(k_cache, v_cache, key, value)
attn_output, attn_weight = core_attn(
self, query, key, value, causal_mask, attention_mask, head_mask
)
else:
query, key, value = query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2)
k_cache, v_cache = layer_past[0], layer_past[1]
k_cache = k_cache.transpose(1, 2)
v_cache = v_cache.transpose(1, 2)
# k_cache and v_cache's shape: [bs, num_heads, context_length, head_dim]
key, value = append_fp8_kv_cache(k_cache, v_cache, key, value)
attn_output, attn_weight = core_attn(
self, query, key, value, causal_mask, attention_mask, head_mask
)
context_layer = self._merge_heads(
attn_output, self.num_heads, self.head_dim
)
attn_output = self.c_proj(context_layer)
if use_cache:
outputs = (attn_output, (key.transpose(1, 2), value.transpose(1, 2)))
else:
outputs = (attn_output, None)
if output_attentions:
outputs += (attn_weight,)
return outputs
def core_attn(self, query, key, value, causal_mask=None, attention_mask=None, head_mask=None):
if query.size(2) != 1 or query.device.type != 'xpu':
# We have no CPU fp8 matmul implementation for now, so just upscale to fp32