add fused rotary pos emb for qwen (#9956)
* add fused rotary pos emb for qwen
* update
parent 7b1d9ad7c0
commit dacf680294
2 changed files with 37 additions and 10 deletions
bigdl/llm/transformers/models/qwen.py
@@ -41,6 +41,7 @@ from bigdl.llm.transformers.models.utils import init_fp8_kv_cache, extend_fp8_kv
     append_fp8_kv_cache, restore_fp8_kv_cache
 from bigdl.llm.transformers.models.utils import rotate_half, quantize_kv_cache
 from bigdl.llm.transformers.models.utils import mlp_fusion_check
+from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb_cache_freq_xpu
 from bigdl.llm.utils.common import invalidInputError, invalidOperationError
 from bigdl.llm.ggml.quantize import ggml_tensor_qtype

@@ -89,10 +90,15 @@ def qwen_attention_forward(
     # query, key, value's shape: [bs, seq_len, num_heads, head_dim]

     if rotary_pos_emb_list is not None:
+        use_fuse_rope = query.device.type == "xpu" and not (self.training and query.requires_grad)
         cur_len = query.shape[1]
         if len(rotary_pos_emb_list) == 1:
             rotary_pos_emb = rotary_pos_emb_list[0]
             rotary_pos_emb = [i[:, -cur_len:, :, :] for i in rotary_pos_emb]
-            rotary_pos_emb = (rotary_pos_emb,) * 2
-            q_pos_emb, k_pos_emb = rotary_pos_emb
-            # Slice the pos emb for current inference
+            if use_fuse_rope:
+                cos, sin = rotary_pos_emb
+                query, key = apply_rotary_pos_emb_cache_freq_xpu(query, key, sin, cos, "qwen")
+            else:
+                rotary_pos_emb = (rotary_pos_emb,) * 2
+                q_pos_emb, k_pos_emb = rotary_pos_emb
+                # Slice the pos emb for current inference
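Note the argument order at the call site: the helper takes sin before cos. For readers without an XPU build, here is a minimal eager-mode sketch of the transform the fused kernel is replacing; the rope_reference name, the query/key layout [bs, seq_len, num_heads, head_dim], and the broadcastable cos/sin shapes are assumptions for illustration, not part of the patch (models/utils.py already ships its own rotate_half).

import torch

def rotate_half(x):
    # rotate the last dimension by halves: (x1, x2) -> (-x2, x1)
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def rope_reference(q, k, cos, sin):
    # eager equivalent of the fused call, assuming the rotary section
    # spans the first cos.shape[-1] channels of head_dim
    rot_dim = cos.shape[-1]
    q_rot, q_pass = q[..., :rot_dim], q[..., rot_dim:]
    k_rot, k_pass = k[..., :rot_dim], k[..., rot_dim:]
    q_rot = q_rot * cos + rotate_half(q_rot) * sin
    k_rot = k_rot * cos + rotate_half(k_rot) * sin
    return (torch.cat((q_rot, q_pass), dim=-1),
            torch.cat((k_rot, k_pass), dim=-1))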
@@ -103,6 +109,12 @@ def qwen_attention_forward(
             key_list = []
             for i, rotary_pos_emb in enumerate(rotary_pos_emb_list):
                 rotary_pos_emb = [i[:, -cur_len:, :, :] for i in rotary_pos_emb]
-                rotary_pos_emb = (rotary_pos_emb,) * 2
-                q_pos_emb, k_pos_emb = rotary_pos_emb
-                # Slice the pos emb for current inference
+                if use_fuse_rope:
+                    cos, sin = rotary_pos_emb
+                    query, key = apply_rotary_pos_emb_cache_freq_xpu(query, key, sin, cos, "qwen")
+                    query_list += [query]
+                    key_list += [key]
+                else:
+                    rotary_pos_emb = (rotary_pos_emb,) * 2
+                    q_pos_emb, k_pos_emb = rotary_pos_emb
+                    # Slice the pos emb for current inference
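The same use_fuse_rope flag gates both the single-embedding branch and the batched loop above. A tiny self-contained sketch of the condition (should_fuse_rope is a hypothetical name; the expression mirrors the patch):

import torch

def should_fuse_rope(t: torch.Tensor, training: bool) -> bool:
    # fuse only for XPU inference; training tensors that track
    # gradients fall back to the eager rotary implementation
    return t.device.type == "xpu" and not (training and t.requires_grad)

q = torch.randn(1, 8, 32, 128)               # CPU tensor
print(should_fuse_rope(q, training=False))   # False: not on an XPU device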
bigdl/llm/transformers/models/utils.py
@@ -177,6 +177,21 @@ def apply_rotary_pos_emb_no_cache_xpu(q, k, position_ids, model_family):
                           f"{model_family} is not supported.")


+def apply_rotary_pos_emb_cache_freq_xpu(q, k, sin, cos, model_family):
+    if q.device.type != "xpu":
+        invalidInputError(False,
+                          f"only xpu is supported in this function")
+    import linear_q4_0
+    q_embed = torch.empty(q.shape, dtype=q.dtype, device=q.device)
+    k_embed = torch.empty(k.shape, dtype=k.dtype, device=k.device)
+    if model_family in ["qwen"]:
+        linear_q4_0.apply_rotary_embedding_half_q_and_k_cache_freq(q, k, sin, cos, q_embed, k_embed)
+        return q_embed, k_embed
+    else:
+        invalidInputError(False,
+                          f"{model_family} is not supported.")
+
+
 def is_enough_kv_cache_room_4_36(past_key_value, idx, seq_len=1):
     # to determine whether there is enough kv cache room in transformers==4.36
     return past_key_value is not None and len(past_key_value.key_cache) > idx and \
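A hedged usage sketch for the new helper (requires a bigdl-llm XPU build that ships the linear_q4_0 extension; the Qwen-7B-like shapes, the float16 dtype, and the [1, seq_len, 1, rot_dim] cos/sin layout are assumptions, not from the patch):

import torch
from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb_cache_freq_xpu

seq_len, num_heads, head_dim = 16, 32, 128
q = torch.randn(1, seq_len, num_heads, head_dim, dtype=torch.float16, device="xpu")
k = torch.randn(1, seq_len, num_heads, head_dim, dtype=torch.float16, device="xpu")
cos = torch.randn(1, seq_len, 1, head_dim, dtype=torch.float16, device="xpu")
sin = torch.randn(1, seq_len, 1, head_dim, dtype=torch.float16, device="xpu")
# note the sin-before-cos argument order, matching the call sites in qwen.py
q_out, k_out = apply_rotary_pos_emb_cache_freq_xpu(q, k, sin, cos, "qwen")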