enable fused rmsnorm and rope qwen2 (#10163)

* qwen2

* change convert

* cleanup
This commit is contained in:
Xin Qiu 2024-02-20 08:33:09 +08:00 committed by GitHub
parent e31210ba00
commit 1f6d5b9f30
3 changed files with 18 additions and 12 deletions

View file

@ -901,9 +901,6 @@ def _optimize_post(model, lightweight_bmm=False):
module = importlib.import_module(modeling_module_name)
from bigdl.llm.transformers.models.qwen2 import qwen2_model_forward
from bigdl.llm.transformers.models.qwen2 import qwen2_attention_forward
# TODO: add these optimization back
# RMSNorm and rotray embedding are disabled for now
# as they lead to obvious performance drop for Qwen 1.5
convert_forward(model,
module.Qwen2Model,
qwen2_model_forward)
@ -911,9 +908,9 @@ def _optimize_post(model, lightweight_bmm=False):
module.Qwen2Attention,
qwen2_attention_forward
)
# convert_forward(model,
# module.Qwen2RMSNorm,
# llama_rms_norm_forward)
convert_forward(model,
module.Qwen2RMSNorm,
llama_rms_norm_forward)
convert_forward(model,
module.Qwen2MLP,
llama_mlp_forward)

View file

@ -48,6 +48,7 @@ from bigdl.llm.transformers.models.llama import repeat_kv
from bigdl.llm.transformers.models.utils import extend_kv_cache, append_kv_cache
from bigdl.llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache
from bigdl.llm.transformers.models.utils import is_enough_kv_cache_room_4_36
from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb_cache_freq_xpu
from bigdl.llm.transformers.kv import DynamicFp8Cache
from bigdl.llm.utils.common import invalidInputError
from transformers.models.qwen2.modeling_qwen2 import Qwen2Model, apply_rotary_pos_emb
@ -157,8 +158,8 @@ def qwen2_attention_forward_quantized(
"with a layer index.")
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
cos, sin, position_ids)
query_states, key_states = apply_rotary_pos_emb_cache_freq_xpu(query_states, key_states,
sin, cos, "qwen2", position_ids)
if past_key_value is not None:
cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
@ -262,8 +263,8 @@ def qwen2_attention_forward_origin(
)
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
cos, sin, position_ids)
query_states, key_states = apply_rotary_pos_emb_cache_freq_xpu(query_states, key_states,
sin, cos, "qwen2", position_ids)
if past_key_value is not None:
# update the number of seen tokens

View file

@ -179,7 +179,7 @@ def apply_rotary_pos_emb_no_cache_xpu(q, k, position_ids, model_family):
f"{model_family} is not supported.")
def apply_rotary_pos_emb_cache_freq_xpu(q, k, sin, cos, model_family):
def apply_rotary_pos_emb_cache_freq_xpu(q, k, sin, cos, model_family, position_ids=None):
if q.device.type != "xpu":
invalidInputError(False,
f"only xpu is supported in this function")
@ -188,10 +188,18 @@ def apply_rotary_pos_emb_cache_freq_xpu(q, k, sin, cos, model_family):
k_embed = torch.empty(k.shape, dtype=k.dtype, device=k.device)
if model_family in ["qwen", "mixtral"]:
linear_q4_0.apply_rotary_embedding_half_q_and_k_cache_freq(q, k, sin, cos, q_embed, k_embed)
return q_embed, k_embed
elif model_family in ["qwen2"]:
cos = cos.to(q.dtype)
sin = sin.to(q.dtype)
cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
linear_q4_0.apply_rotary_embedding_half_q_and_k_cache_freq(q, k, sin, cos, q_embed, k_embed)
else:
invalidInputError(False,
f"{model_family} is not supported.")
return q_embed, k_embed
def is_enough_kv_cache_room_4_36(past_key_value, idx, seq_len=1):