diff --git a/python/llm/src/bigdl/llm/transformers/models/qwen.py b/python/llm/src/bigdl/llm/transformers/models/qwen.py
index 4de49ae3..c0d642e2 100644
--- a/python/llm/src/bigdl/llm/transformers/models/qwen.py
+++ b/python/llm/src/bigdl/llm/transformers/models/qwen.py
@@ -41,6 +41,7 @@ from bigdl.llm.transformers.models.utils import init_fp8_kv_cache, extend_fp8_kv
     append_fp8_kv_cache, restore_fp8_kv_cache
 from bigdl.llm.transformers.models.utils import rotate_half, quantize_kv_cache
 from bigdl.llm.transformers.models.utils import mlp_fusion_check
+from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb_cache_freq_xpu
 from bigdl.llm.utils.common import invalidInputError, invalidOperationError
 from bigdl.llm.ggml.quantize import ggml_tensor_qtype
 
@@ -89,25 +90,36 @@ def qwen_attention_forward(
     # query, key, value's shape: [bs, seq_len, num_heads, head_dim]
 
     if rotary_pos_emb_list is not None:
+        use_fuse_rope = query.device.type == "xpu" and not (self.training and query.requires_grad)
         cur_len = query.shape[1]
         if len(rotary_pos_emb_list) == 1:
             rotary_pos_emb = rotary_pos_emb_list[0]
             rotary_pos_emb = [i[:, -cur_len:, :, :] for i in rotary_pos_emb]
-            rotary_pos_emb = (rotary_pos_emb,) * 2
-            q_pos_emb, k_pos_emb = rotary_pos_emb
-            # Slice the pos emb for current inference
-            query = apply_rotary_pos_emb(query, q_pos_emb)
-            key = apply_rotary_pos_emb(key, k_pos_emb)
+            if use_fuse_rope:
+                cos, sin = rotary_pos_emb
+                query, key = apply_rotary_pos_emb_cache_freq_xpu(query, key, sin, cos, "qwen")
+            else:
+                rotary_pos_emb = (rotary_pos_emb,) * 2
+                q_pos_emb, k_pos_emb = rotary_pos_emb
+                # Slice the pos emb for current inference
+                query = apply_rotary_pos_emb(query, q_pos_emb)
+                key = apply_rotary_pos_emb(key, k_pos_emb)
         else:
             query_list = []
             key_list = []
             for i, rotary_pos_emb in enumerate(rotary_pos_emb_list):
                 rotary_pos_emb = [i[:, -cur_len:, :, :] for i in rotary_pos_emb]
-                rotary_pos_emb = (rotary_pos_emb,) * 2
-                q_pos_emb, k_pos_emb = rotary_pos_emb
-                # Slice the pos emb for current inference
-                query_list += [apply_rotary_pos_emb(query[i:i+1, :, :], q_pos_emb)]
-                key_list += [apply_rotary_pos_emb(key[i:i+1, :, :], k_pos_emb)]
+                if use_fuse_rope:
+                    cos, sin = rotary_pos_emb
+                    query, key = apply_rotary_pos_emb_cache_freq_xpu(query, key, sin, cos, "qwen")
+                    query_list += [query]
+                    key_list += [key]
+                else:
+                    rotary_pos_emb = (rotary_pos_emb,) * 2
+                    q_pos_emb, k_pos_emb = rotary_pos_emb
+                    # Slice the pos emb for current inference
+                    query_list += [apply_rotary_pos_emb(query[i:i+1, :, :], q_pos_emb)]
+                    key_list += [apply_rotary_pos_emb(key[i:i+1, :, :], k_pos_emb)]
             query = torch.cat(query_list, dim=0)
             key = torch.cat(key_list, dim=0)
 
diff --git a/python/llm/src/bigdl/llm/transformers/models/utils.py b/python/llm/src/bigdl/llm/transformers/models/utils.py
index 5abd1345..ca49bb8d 100644
--- a/python/llm/src/bigdl/llm/transformers/models/utils.py
+++ b/python/llm/src/bigdl/llm/transformers/models/utils.py
@@ -177,6 +177,21 @@ def apply_rotary_pos_emb_no_cache_xpu(q, k, position_ids, model_family):
                           f"{model_family} is not supported.")
 
 
+def apply_rotary_pos_emb_cache_freq_xpu(q, k, sin, cos, model_family):
+    if q.device.type != "xpu":
+        invalidInputError(False,
+                          f"only xpu is supported in this function")
+    import linear_q4_0
+    q_embed = torch.empty(q.shape, dtype=q.dtype, device=q.device)
+    k_embed = torch.empty(k.shape, dtype=k.dtype, device=k.device)
+    if model_family in ["qwen"]:
+        linear_q4_0.apply_rotary_embedding_half_q_and_k_cache_freq(q, k, sin, cos, q_embed, k_embed)
+        return q_embed, k_embed
+    else:
+        invalidInputError(False,
+                          f"{model_family} is not supported.")
+
+
 def is_enough_kv_cache_room_4_36(past_key_value, idx, seq_len=1):
     # to determinate if is enough kv cache room in transformers==4.36
     return past_key_value is not None and len(past_key_value.key_cache) > idx and \
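
Note (illustrative reference, not part of the patch): the fused apply_rotary_pos_emb_cache_freq_xpu path is assumed to produce the same result as the eager apply_rotary_pos_emb fallback, i.e. the standard "half" rotary embedding, ignoring any partial-rotary (rot_dim) split. A minimal PyTorch sketch of that computation:

    import torch

    def rotate_half(x):
        # split the last dim in half and rotate: (x1, x2) -> (-x2, x1)
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)

    def rope_reference(q, k, cos, sin):
        # q, k: [bs, seq_len, num_heads, head_dim]; cos/sin broadcastable to q/k
        q_embed = q * cos + rotate_half(q) * sin
        k_embed = k * cos + rotate_half(k) * sin
        return q_embed, k_embed

The fused kernel receives preallocated q_embed/k_embed output buffers and applies the rotation to query and key in a single XPU call, instead of running the separate elementwise ops of the eager path.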