From 1f6d5b9f304303213e4ed8fb9c5910d57bf3dfda Mon Sep 17 00:00:00 2001 From: Xin Qiu Date: Tue, 20 Feb 2024 08:33:09 +0800 Subject: [PATCH] enable fused rmsnorm and rope qwen2 (#10163) * qwen2 * change convert * cleanup --- python/llm/src/bigdl/llm/transformers/convert.py | 9 +++------ .../llm/src/bigdl/llm/transformers/models/qwen2.py | 9 +++++---- .../llm/src/bigdl/llm/transformers/models/utils.py | 12 ++++++++++-- 3 files changed, 18 insertions(+), 12 deletions(-) diff --git a/python/llm/src/bigdl/llm/transformers/convert.py b/python/llm/src/bigdl/llm/transformers/convert.py index bc886e3c..033ca934 100644 --- a/python/llm/src/bigdl/llm/transformers/convert.py +++ b/python/llm/src/bigdl/llm/transformers/convert.py @@ -901,9 +901,6 @@ def _optimize_post(model, lightweight_bmm=False): module = importlib.import_module(modeling_module_name) from bigdl.llm.transformers.models.qwen2 import qwen2_model_forward from bigdl.llm.transformers.models.qwen2 import qwen2_attention_forward - # TODO: add these optimization back - # RMSNorm and rotray embedding are disabled for now - # as they lead to obvious performance drop for Qwen 1.5 convert_forward(model, module.Qwen2Model, qwen2_model_forward) @@ -911,9 +908,9 @@ def _optimize_post(model, lightweight_bmm=False): module.Qwen2Attention, qwen2_attention_forward ) - # convert_forward(model, - # module.Qwen2RMSNorm, - # llama_rms_norm_forward) + convert_forward(model, + module.Qwen2RMSNorm, + llama_rms_norm_forward) convert_forward(model, module.Qwen2MLP, llama_mlp_forward) diff --git a/python/llm/src/bigdl/llm/transformers/models/qwen2.py b/python/llm/src/bigdl/llm/transformers/models/qwen2.py index e71a1df6..1b9265c9 100644 --- a/python/llm/src/bigdl/llm/transformers/models/qwen2.py +++ b/python/llm/src/bigdl/llm/transformers/models/qwen2.py @@ -48,6 +48,7 @@ from bigdl.llm.transformers.models.llama import repeat_kv from bigdl.llm.transformers.models.utils import extend_kv_cache, append_kv_cache from bigdl.llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache from bigdl.llm.transformers.models.utils import is_enough_kv_cache_room_4_36 +from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb_cache_freq_xpu from bigdl.llm.transformers.kv import DynamicFp8Cache from bigdl.llm.utils.common import invalidInputError from transformers.models.qwen2.modeling_qwen2 import Qwen2Model, apply_rotary_pos_emb @@ -157,8 +158,8 @@ def qwen2_attention_forward_quantized( "with a layer index.") kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, - cos, sin, position_ids) + query_states, key_states = apply_rotary_pos_emb_cache_freq_xpu(query_states, key_states, + sin, cos, "qwen2", position_ids) if past_key_value is not None: cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models @@ -262,8 +263,8 @@ def qwen2_attention_forward_origin( ) kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, - cos, sin, position_ids) + query_states, key_states = apply_rotary_pos_emb_cache_freq_xpu(query_states, key_states, + sin, cos, "qwen2", position_ids) if past_key_value is not None: # update the number of seen tokens diff --git a/python/llm/src/bigdl/llm/transformers/models/utils.py b/python/llm/src/bigdl/llm/transformers/models/utils.py index ec7b4ad2..6104b510 100644 --- a/python/llm/src/bigdl/llm/transformers/models/utils.py +++ b/python/llm/src/bigdl/llm/transformers/models/utils.py @@ -179,7 +179,7 @@ def apply_rotary_pos_emb_no_cache_xpu(q, k, position_ids, model_family): f"{model_family} is not supported.") -def apply_rotary_pos_emb_cache_freq_xpu(q, k, sin, cos, model_family): +def apply_rotary_pos_emb_cache_freq_xpu(q, k, sin, cos, model_family, position_ids=None): if q.device.type != "xpu": invalidInputError(False, f"only xpu is supported in this function") @@ -188,10 +188,18 @@ def apply_rotary_pos_emb_cache_freq_xpu(q, k, sin, cos, model_family): k_embed = torch.empty(k.shape, dtype=k.dtype, device=k.device) if model_family in ["qwen", "mixtral"]: linear_q4_0.apply_rotary_embedding_half_q_and_k_cache_freq(q, k, sin, cos, q_embed, k_embed) - return q_embed, k_embed + elif model_family in ["qwen2"]: + cos = cos.to(q.dtype) + sin = sin.to(q.dtype) + cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] + sin = sin.squeeze(1).squeeze(0) # [seq_len, dim] + cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] + sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] + linear_q4_0.apply_rotary_embedding_half_q_and_k_cache_freq(q, k, sin, cos, q_embed, k_embed) else: invalidInputError(False, f"{model_family} is not supported.") + return q_embed, k_embed def is_enough_kv_cache_room_4_36(past_key_value, idx, seq_len=1):