add fused rotary pos emb for qwen (#9956)

* add fused rotary pos emb for qwen

* update
Xin Qiu 2024-01-23 10:37:56 +08:00 committed by GitHub
parent 7b1d9ad7c0
commit dacf680294
2 changed files with 37 additions and 10 deletions
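For context, fusing rotary position embedding means computing the RoPE rotation for query and key in a single custom kernel launch instead of several eager elementwise ops. Below is a minimal sketch of the eager math the fused XPU kernel replaces, assuming the standard rotate-half formulation used by Qwen (the actual kernel lives in the linear_q4_0 extension and is not part of this diff):

    import torch

    def rotate_half(x):
        # rotate the last dimension: (x1, x2) -> (-x2, x1)
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)

    def apply_rotary_pos_emb_eager(q, k, sin, cos):
        # eager path: each multiply/add is its own kernel and materializes
        # a temporary; the fused kernel performs the same math in one launch
        q_embed = (q * cos) + (rotate_half(q) * sin)
        k_embed = (k * cos) + (rotate_half(k) * sin)
        return q_embed, k_embed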

bigdl/llm/transformers/models/qwen.py

@@ -41,6 +41,7 @@ from bigdl.llm.transformers.models.utils import init_fp8_kv_cache, extend_fp8_kv
     append_fp8_kv_cache, restore_fp8_kv_cache
 from bigdl.llm.transformers.models.utils import rotate_half, quantize_kv_cache
 from bigdl.llm.transformers.models.utils import mlp_fusion_check
+from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb_cache_freq_xpu
 from bigdl.llm.utils.common import invalidInputError, invalidOperationError
 from bigdl.llm.ggml.quantize import ggml_tensor_qtype
@@ -89,10 +90,15 @@ def qwen_attention_forward(
     # query, key, value's shape: [bs, seq_len, num_heads, head_dim]
     if rotary_pos_emb_list is not None:
+        use_fuse_rope = query.device.type == "xpu" and not (self.training and query.requires_grad)
         cur_len = query.shape[1]
         if len(rotary_pos_emb_list) == 1:
             rotary_pos_emb = rotary_pos_emb_list[0]
             rotary_pos_emb = [i[:, -cur_len:, :, :] for i in rotary_pos_emb]
-            rotary_pos_emb = (rotary_pos_emb,) * 2
-            q_pos_emb, k_pos_emb = rotary_pos_emb
-            # Slice the pos emb for current inference
+            if use_fuse_rope:
+                cos, sin = rotary_pos_emb
+                query, key = apply_rotary_pos_emb_cache_freq_xpu(query, key, sin, cos, "qwen")
+            else:
+                rotary_pos_emb = (rotary_pos_emb,) * 2
+                q_pos_emb, k_pos_emb = rotary_pos_emb
+                # Slice the pos emb for current inference
@@ -103,6 +109,12 @@ def qwen_attention_forward(
             key_list = []
             for i, rotary_pos_emb in enumerate(rotary_pos_emb_list):
                 rotary_pos_emb = [i[:, -cur_len:, :, :] for i in rotary_pos_emb]
-                rotary_pos_emb = (rotary_pos_emb,) * 2
-                q_pos_emb, k_pos_emb = rotary_pos_emb
-                # Slice the pos emb for current inference
+                if use_fuse_rope:
+                    cos, sin = rotary_pos_emb
+                    query, key = apply_rotary_pos_emb_cache_freq_xpu(query, key, sin, cos, "qwen")
+                    query_list += [query]
+                    key_list += [key]
+                else:
+                    rotary_pos_emb = (rotary_pos_emb,) * 2
+                    q_pos_emb, k_pos_emb = rotary_pos_emb
+                    # Slice the pos emb for current inference
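Taken together, the two patched branches gate on use_fuse_rope and hand the sliced cos/sin tables straight to the fused kernel. A hedged usage sketch of the new path follows; the shapes and the 2048-step cache length are illustrative only, and running it requires an Intel XPU with bigdl-llm's native extension loaded:

    import torch
    from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb_cache_freq_xpu

    # illustrative Qwen-style shapes: [bs, seq_len, num_heads, head_dim]
    query = torch.randn(1, 32, 32, 128, device="xpu", dtype=torch.float16)
    key = torch.randn(1, 32, 32, 128, device="xpu", dtype=torch.float16)
    cur_len = query.shape[1]

    # stand-in for the model's cached [cos, sin] pair, shaped [1, max_len, 1, head_dim]
    rotary_pos_emb = [torch.randn(1, 2048, 1, 128, device="xpu", dtype=torch.float16)
                      for _ in range(2)]
    rotary_pos_emb = [i[:, -cur_len:, :, :] for i in rotary_pos_emb]

    use_fuse_rope = query.device.type == "xpu" and not query.requires_grad
    if use_fuse_rope:
        cos, sin = rotary_pos_emb
        query, key = apply_rotary_pos_emb_cache_freq_xpu(query, key, sin, cos, "qwen")

The self.training / requires_grad guard in the diff keeps fusion inference-only: the kernel writes into preallocated outputs rather than building an autograd graph, so training falls back to the eager branch.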

bigdl/llm/transformers/models/utils.py

@@ -177,6 +177,21 @@ def apply_rotary_pos_emb_no_cache_xpu(q, k, position_ids, model_family):
                           f"{model_family} is not supported.")
 
 
+def apply_rotary_pos_emb_cache_freq_xpu(q, k, sin, cos, model_family):
+    if q.device.type != "xpu":
+        invalidInputError(False,
+                          f"only xpu is supported in this function")
+    import linear_q4_0
+    q_embed = torch.empty(q.shape, dtype=q.dtype, device=q.device)
+    k_embed = torch.empty(k.shape, dtype=k.dtype, device=k.device)
+    if model_family in ["qwen"]:
+        linear_q4_0.apply_rotary_embedding_half_q_and_k_cache_freq(q, k, sin, cos, q_embed, k_embed)
+        return q_embed, k_embed
+    else:
+        invalidInputError(False,
+                          f"{model_family} is not supported.")
+
+
 def is_enough_kv_cache_room_4_36(past_key_value, idx, seq_len=1):
     # to determine if there is enough kv cache room in transformers==4.36
     return past_key_value is not None and len(past_key_value.key_cache) > idx and \
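Since the kernel name apply_rotary_embedding_half_q_and_k_cache_freq suggests the rotate-half formulation, a parity check against the eager math is easy to write. A hypothetical sanity test (not part of the commit; it assumes an XPU device is available and that the kernel implements rotate-half RoPE):

    import torch
    from bigdl.llm.transformers.models.utils import (
        rotate_half, apply_rotary_pos_emb_cache_freq_xpu)

    q = torch.randn(1, 8, 32, 128, device="xpu")
    k = torch.randn(1, 8, 32, 128, device="xpu")
    # cos/sin broadcast over the head dimension
    cos = torch.randn(1, 8, 1, 128, device="xpu")
    sin = torch.randn(1, 8, 1, 128, device="xpu")

    q_fused, k_fused = apply_rotary_pos_emb_cache_freq_xpu(q, k, sin, cos, "qwen")
    q_ref = (q * cos) + (rotate_half(q) * sin)
    k_ref = (k * cos) + (rotate_half(k) * sin)
    assert torch.allclose(q_fused, q_ref, atol=1e-4)
    assert torch.allclose(k_fused, k_ref, atol=1e-4)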