[LLM] Enable kv_cache optimization for Qwen2 on transformers-v4.37.0 (#10131)

* add support for kv_cache optimization on transformers-v4.37.0

* enable attention forward

* style fix

* disable rotary for now
This commit is contained in:
SONG Ge 2024-02-08 14:20:26 +08:00 committed by GitHub
parent 063dc145ac
commit 3f79128ed7
2 changed files with 49 additions and 23 deletions

View file

@ -897,10 +897,10 @@ def _optimize_post(model, lightweight_bmm=False):
# TODO: add these optimization back
# RMSNorm and rotray embedding are disabled for now
# as they lead to obvious performance drop for Qwen 1.5
# convert_forward(model,
# module.Qwen2Attention,
# qwen2_attention_forward
# )
convert_forward(model,
module.Qwen2Attention,
qwen2_attention_forward
)
# convert_forward(model,
# module.Qwen2RMSNorm,
# llama_rms_norm_forward)

View file

@ -38,18 +38,22 @@
#
import math
from typing import TYPE_CHECKING, Optional, Tuple, Union, Callable, List
import warnings
from typing import TYPE_CHECKING, Optional, Tuple, Union, Callable, List
import torch
import torch.nn as nn
from bigdl.llm.transformers.models.llama import repeat_kv
from bigdl.llm.transformers.models.utils import extend_kv_cache, append_kv_cache
from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb, \
apply_rotary_pos_emb_no_cache_xpu
apply_rotary_pos_emb_no_cache_xpu, is_enough_kv_cache_room_4_36
from bigdl.llm.utils.common import invalidInputError
KV_CACHE_ALLOC_BLOCK_LENGTH = 256
def should_use_fuse_rope(self, query_states, position_ids):
use_fuse_rope = query_states.device.type == "xpu"
use_fuse_rope = use_fuse_rope and not (self.training and query_states.requires_grad)
@ -76,6 +80,9 @@ def qwen2_attention_forward(
"Please make sure use `attention_mask` instead.`"
)
bsz, q_len, _ = hidden_states.size()
device = hidden_states.device
enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx)
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
@ -98,25 +105,44 @@ def qwen2_attention_forward(
"please make sure to initialize the attention class with a layer index."
)
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
if use_fuse_rope:
query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
key_states,
position_ids,
"qwen2")
else:
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
cos, sin, position_ids, "qwen2")
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
if past_key_value is not None:
if use_fuse_rope:
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
key_states, value_states = past_key_value.update(key_states,
value_states,
self.layer_idx,
cache_kwargs)
# update the number of seen tokens
if self.layer_idx == 0:
past_key_value.seen_tokens += key_states.shape[-2]
if len(past_key_value.key_cache) <= self.layer_idx:
past_key_value.key_cache.append(key_states)
past_key_value.value_cache.append(value_states)
else:
cache_k = past_key_value.key_cache[self.layer_idx]
cache_v = past_key_value.value_cache[self.layer_idx]
if not enough_kv_room:
# allocate new
new_c_k, new_c_v = extend_kv_cache(bsz,
self.num_key_value_heads, # Support GQA
self.head_dim,
cache_k.size(2),
kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH,
dtype=cache_k.dtype,
device=device)
new_c_k[:] = cache_k
new_c_v[:] = cache_v
cache_k = new_c_k
cache_v = new_c_v
key_states, value_states = append_kv_cache(cache_k,
cache_v,
key_states,
value_states)
# update past_key_value
past_key_value.key_cache[self.layer_idx] = key_states
past_key_value.value_cache[self.layer_idx] = value_states
# repeat k/v heads if n_kv_heads < n_heads
key_states = repeat_kv(key_states, self.num_key_value_groups)