enable fused rmsnorm and rope qwen2 (#10163)
* qwen2
* change convert
* cleanup
This commit is contained in:
parent e31210ba00
commit 1f6d5b9f30
3 changed files with 18 additions and 12 deletions
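
For context, these optimizations are applied automatically when a Qwen1.5 checkpoint is loaded through bigdl-llm's transformers wrapper and moved to an XPU device. A hedged usage sketch (the model id and device handling are illustrative assumptions, not part of this commit):

import intel_extension_for_pytorch as ipex  # noqa: F401  (registers the "xpu" device)
from bigdl.llm.transformers import AutoModelForCausalLM

# Illustrative Qwen1.5 checkpoint; any Qwen2-architecture model goes through
# the same _optimize_post path that this commit patches.
model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen1.5-7B-Chat",
    load_in_4bit=True,
)
model = model.to("xpu")  # the fused RMSNorm/RoPE paths target XPU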

@@ -901,9 +901,6 @@ def _optimize_post(model, lightweight_bmm=False):
         module = importlib.import_module(modeling_module_name)
         from bigdl.llm.transformers.models.qwen2 import qwen2_model_forward
         from bigdl.llm.transformers.models.qwen2 import qwen2_attention_forward
-        # TODO: add these optimization back
-        # RMSNorm and rotray embedding are disabled for now
-        # as they lead to obvious performance drop for Qwen 1.5
         convert_forward(model,
                         module.Qwen2Model,
                         qwen2_model_forward)

@@ -911,9 +908,9 @@ def _optimize_post(model, lightweight_bmm=False):
                         module.Qwen2Attention,
                         qwen2_attention_forward
                         )
-        # convert_forward(model,
-        #                 module.Qwen2RMSNorm,
-        #                 llama_rms_norm_forward)
+        convert_forward(model,
+                        module.Qwen2RMSNorm,
+                        llama_rms_norm_forward)
         convert_forward(model,
                         module.Qwen2MLP,
                         llama_mlp_forward)
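
For context, convert_forward is BigDL's existing helper for routing a Hugging Face module class through an optimized forward; the calls above are what switch the Qwen2 modules over at load time. A minimal sketch of the underlying forward-replacement pattern (the helper name and details here are illustrative, not the BigDL implementation):

import types

import torch


def replace_forward(model: torch.nn.Module, target_cls: type, new_forward) -> None:
    """Rebind ``forward`` on every submodule that is an instance of ``target_cls``.

    Sketch only: the real convert_forward in bigdl.llm may match modules and
    bind methods differently.
    """
    for module in model.modules():
        if isinstance(module, target_cls):
            module.forward = types.MethodType(new_forward, module)

With this pattern, the calls above send Qwen2Model, Qwen2Attention, Qwen2RMSNorm and Qwen2MLP through the optimized forwards without modifying the checkpoint or the upstream transformers code.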
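
The RMSNorm half of the change reuses llama_rms_norm_forward, BigDL's fused RMSNorm path, for Qwen2RMSNorm, which computes the same normalization as LLaMA's. For comparison, an eager RMSNorm matching the standard LLaMA/Qwen2 formulation (a reference sketch, not BigDL code):

import torch


class ReferenceRMSNorm(torch.nn.Module):
    """Eager RMSNorm matching the standard LLaMA/Qwen2 formulation (sketch)."""

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        # Normalize by the root mean square over the hidden dimension, then rescale.
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)

The fused path computes the same result in a single kernel rather than several elementwise ops, which is what makes it faster on XPU.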

@@ -48,6 +48,7 @@ from bigdl.llm.transformers.models.llama import repeat_kv
 from bigdl.llm.transformers.models.utils import extend_kv_cache, append_kv_cache
 from bigdl.llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache
 from bigdl.llm.transformers.models.utils import is_enough_kv_cache_room_4_36
+from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb_cache_freq_xpu
 from bigdl.llm.transformers.kv import DynamicFp8Cache
 from bigdl.llm.utils.common import invalidInputError
 from transformers.models.qwen2.modeling_qwen2 import Qwen2Model, apply_rotary_pos_emb

@@ -157,8 +158,8 @@ def qwen2_attention_forward_quantized(
                               "with a layer index.")
         kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
     cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
-    query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
-                                                    cos, sin, position_ids)
+    query_states, key_states = apply_rotary_pos_emb_cache_freq_xpu(query_states, key_states,
+                                                                   sin, cos, "qwen2", position_ids)

     if past_key_value is not None:
         cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models

@@ -262,8 +263,8 @@ def qwen2_attention_forward_origin(
                 )
         kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
     cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
-    query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
-                                                    cos, sin, position_ids)
+    query_states, key_states = apply_rotary_pos_emb_cache_freq_xpu(query_states, key_states,
+                                                                   sin, cos, "qwen2", position_ids)

     if past_key_value is not None:
         # update the number of seen tokens
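
Both attention paths now call the fused apply_rotary_pos_emb_cache_freq_xpu instead of the stock apply_rotary_pos_emb. For comparison, the unfused computation being replaced follows the standard Hugging Face rotate-half formulation; a sketch with an illustrative name, producing the same q_embed/k_embed that the fused XPU kernel returns in one call:

import torch


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Rotate the last dimension by half: (x1, x2) -> (-x2, x1).
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2:]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb_reference(q, k, cos, sin, position_ids, unsqueeze_dim=1):
    # q, k: [bs, n_heads, seq_len, head_dim]; cos, sin: [seq_len, head_dim].
    cos = cos[position_ids].unsqueeze(unsqueeze_dim)  # [bs, 1, seq_len, head_dim]
    sin = sin[position_ids].unsqueeze(unsqueeze_dim)  # [bs, 1, seq_len, head_dim]
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed

Note that the fused call takes sin before cos plus a model_family tag, so the argument order differs from the stock helper.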

@@ -179,7 +179,7 @@ def apply_rotary_pos_emb_no_cache_xpu(q, k, position_ids, model_family):
                           f"{model_family} is not supported.")


-def apply_rotary_pos_emb_cache_freq_xpu(q, k, sin, cos, model_family):
+def apply_rotary_pos_emb_cache_freq_xpu(q, k, sin, cos, model_family, position_ids=None):
     if q.device.type != "xpu":
         invalidInputError(False,
                           f"only xpu is supported in this function")

@@ -188,10 +188,18 @@ def apply_rotary_pos_emb_cache_freq_xpu(q, k, sin, cos, model_family):
     k_embed = torch.empty(k.shape, dtype=k.dtype, device=k.device)
     if model_family in ["qwen", "mixtral"]:
         linear_q4_0.apply_rotary_embedding_half_q_and_k_cache_freq(q, k, sin, cos, q_embed, k_embed)
-        return q_embed, k_embed
+    elif model_family in ["qwen2"]:
+        cos = cos.to(q.dtype)
+        sin = sin.to(q.dtype)
+        cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
+        sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
+        cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
+        sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
+        linear_q4_0.apply_rotary_embedding_half_q_and_k_cache_freq(q, k, sin, cos, q_embed, k_embed)
     else:
         invalidInputError(False,
                           f"{model_family} is not supported.")
+    return q_embed, k_embed


 def is_enough_kv_cache_room_4_36(past_key_value, idx, seq_len=1):
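
The new qwen2 branch mostly does shape bookkeeping before handing off to the existing linear_q4_0 kernel: it casts the cached cos/sin tables to the query dtype, flattens them to [seq_len, dim], and gathers the rows selected by position_ids. A small standalone check of that reshaping (dummy tensors, not BigDL code; the cached table is assumed here to use the older [1, 1, max_seq_len, head_dim] layout):

import torch

bs, seq_len, head_dim, max_seq_len = 2, 16, 64, 128
position_ids = torch.arange(seq_len).unsqueeze(0).expand(bs, -1)  # [bs, seq_len]

# Assumed cached rotary tables (random values, only the shapes matter here).
cos = torch.randn(1, 1, max_seq_len, head_dim)
sin = torch.randn(1, 1, max_seq_len, head_dim)

cos = cos.squeeze(1).squeeze(0)       # [max_seq_len, head_dim]
sin = sin.squeeze(1).squeeze(0)       # [max_seq_len, head_dim]
cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, head_dim]
sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, head_dim]

assert cos.shape == (bs, 1, seq_len, head_dim)
assert sin.shape == (bs, 1, seq_len, head_dim)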