fix stablelm logits diff (#10636)

* fix logits diff

* Small fixes

---------

Co-authored-by: Yuwen Hu <yuwen.hu@intel.com>
Author: Xin Qiu, 2024-04-03 15:08:12 +08:00, committed by GitHub
parent 97c626d76f
commit 3a9ab8f1ae
3 changed files with 12 additions and 8 deletions

Changed file 1 of 3: _optimize_pre / _optimize_post (ipex_llm.transformers model conversion)

@@ -633,6 +633,7 @@ def _optimize_pre(model):
             del module.c_attn
         model.apply(split_qkv_proj_func)
     if model.config.model_type == "stablelm":
+        # For stablelm-zephyr-3b
         from ipex_llm.transformers.models.stablelm import merge_qkv
         model.apply(merge_qkv)
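
merge_qkv fuses StableLM's separate q/k/v projection layers into one linear so the optimized attention can issue a single matmul for all three tensors. A minimal sketch of the idea, assuming the upstream StableLmAttention attribute names (q_proj/k_proj/v_proj); the real implementation lives in ipex_llm.transformers.models.stablelm and may differ:

```python
# Illustrative sketch only: the real merge_qkv is in
# ipex_llm.transformers.models.stablelm; names below are assumptions.
import torch
from torch import nn

def merge_qkv_sketch(module: nn.Module):
    # Fuse separate q/k/v projections into one linear so a single matmul
    # produces all three tensors; the attention forward splits them later.
    if all(hasattr(module, name) for name in ("q_proj", "k_proj", "v_proj")):
        q, k, v = module.q_proj, module.k_proj, module.v_proj
        qkv = nn.Linear(q.in_features,
                        q.out_features + k.out_features + v.out_features,
                        bias=q.bias is not None)
        with torch.no_grad():
            qkv.weight.copy_(torch.cat([q.weight, k.weight, v.weight], dim=0))
            if q.bias is not None:
                qkv.bias.copy_(torch.cat([q.bias, k.bias, v.bias], dim=0))
        module.qkv_proj = qkv
        del module.q_proj, module.k_proj, module.v_proj

# model.apply(merge_qkv) above visits every submodule the same way:
# model.apply(merge_qkv_sketch)
```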
@@ -1341,6 +1342,7 @@ def _optimize_post(model, lightweight_bmm=False):
                         module.BertEncoder,
                         encoder_forward)
     elif model.config.model_type == 'stablelm':
+        # For stablelm-zephyr-3b
         modeling_module_name = model.__class__.__module__
         module = importlib.import_module(modeling_module_name)
         from ipex_llm.transformers.models.stablelm import stablelm_attention_forward
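
_optimize_post then swaps the attention forward: it resolves the model's own modeling module and rebinds StableLmAttention.forward to stablelm_attention_forward. A hedged sketch of that patching pattern; convert_forward_sketch and the class-level rebinding are illustrative, not ipex_llm's actual helper:

```python
# Sketch of the forward-replacement pattern; convert_forward_sketch is a
# hypothetical helper, not ipex_llm's API.
import importlib

def convert_forward_sketch(model, class_name, new_forward):
    modeling_module_name = model.__class__.__module__   # e.g. transformers.models.stablelm.modeling_stablelm
    modeling_module = importlib.import_module(modeling_module_name)
    target_cls = getattr(modeling_module, class_name)
    target_cls.forward = new_forward                    # all instances now dispatch to the patched forward

# e.g. convert_forward_sketch(model, "StableLmAttention", stablelm_attention_forward)
```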

Changed file 2 of 3: ipex_llm.transformers.models.stablelm

@@ -48,7 +48,7 @@ from transformers.models.stablelm.modeling_stablelm import StableLmAttention
 from ipex_llm.utils.common import invalidInputError
 from ipex_llm.transformers.models.utils import extend_kv_cache, append_kv_cache
 from ipex_llm.transformers.models.utils import apply_rotary_pos_emb, \
-    apply_rotary_pos_emb_no_cache_xpu
+    apply_rotary_pos_emb_cache_freq_xpu
 from ipex_llm.transformers.models.utils import is_enough_kv_cache_room_4_36
 from ipex_llm.transformers.models.utils import use_flash_attention, use_esimd_sdp
 from ipex_llm.transformers.models.mistral import should_use_fuse_rope, repeat_kv
@@ -124,13 +124,15 @@ def stablelm_attention_forward(
         key_states[..., : self.rotary_emb.dim],
         key_states[..., self.rotary_emb.dim:],
     )
+    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
     if use_fuse_rope:
-        query_rot, key_rot = apply_rotary_pos_emb_no_cache_xpu(query_rot,
-                                                               key_rot,
-                                                               position_ids,
-                                                               "stablelm")
+        query_rot, key_rot = apply_rotary_pos_emb_cache_freq_xpu(query_rot,
+                                                                 key_rot,
+                                                                 sin,
+                                                                 cos,
+                                                                 "stablelm",
+                                                                 position_ids)
     else:
-        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
         # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor]
         query_rot, key_rot = apply_rotary_pos_emb(query_rot,
                                                   key_rot,
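
The slicing above is StableLM's partial rotary embedding: only the first rotary_emb.dim channels of each head get RoPE, and the untouched remainder is concatenated back afterwards. The fix computes cos/sin once from the model's own rotary_emb and feeds them to the fused XPU branch too, so both branches presumably work from the same cached frequencies instead of frequencies rebuilt inside the kernel, which is the likely source of the logits diff named in the commit title. A rough sketch of the surrounding flow, assumed from the visible slicing rather than copied from the file:

```python
# Assumed shape of the partial-rotary flow around the hunk above (illustrative).
import torch

def partial_rotary_sketch(query_states, key_states, rotary_dim, apply_rope):
    # Only the first rotary_dim channels receive RoPE; the rest pass through.
    query_rot, query_pass = query_states[..., :rotary_dim], query_states[..., rotary_dim:]
    key_rot, key_pass = key_states[..., :rotary_dim], key_states[..., rotary_dim:]
    # apply_rope stands in for either branch: the fused
    # apply_rotary_pos_emb_cache_freq_xpu call or plain apply_rotary_pos_emb.
    query_rot, key_rot = apply_rope(query_rot, key_rot)
    query_states = torch.cat((query_rot, query_pass), dim=-1)
    key_states = torch.cat((key_rot, key_pass), dim=-1)
    return query_states, key_states
```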

Changed file 3 of 3: ipex_llm.transformers.models.utils

@@ -208,7 +208,7 @@ def apply_rotary_pos_emb_no_cache_xpu(q, k, position_ids, model_family, rope_theta
     q_embed = torch.empty(q.shape, dtype=q.dtype, device=q.device)
     k_embed = torch.empty(k.shape, dtype=k.dtype, device=k.device)
     if model_family in ["llama", "baichuan", "internlm", "aquila", "gpt_neox", "mistral",
-                        "mixtral", "stablelm"]:
+                        "mixtral"]:
         linear_q4_0.apply_rotary_embedding_half_q_and_k(q, k, position_ids,
                                                         q_embed, k_embed, rope_theta)
     return q_embed, k_embed
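
Because apply_rotary_pos_emb_no_cache_xpu receives no sin/cos, its kernel has to rebuild the rotary frequencies from rope_theta and position_ids, which is why "stablelm" is dropped from this family list. A reference-only reconstruction of what that rebuild amounts to, inferred from the signature rather than from the linear_q4_0 kernel source:

```python
# Reference-only: what a "no cache" rotary path must compute when it is handed
# only position_ids and rope_theta (hypothetical reconstruction).
import torch

def no_cache_rope_reference(q, k, position_ids, rope_theta):
    dim = q.shape[-1]
    inv_freq = 1.0 / (rope_theta ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
    freqs = position_ids[..., None].float() * inv_freq   # [bsz, q_len, dim // 2]
    emb = torch.cat((freqs, freqs), dim=-1)               # [bsz, q_len, dim]
    cos = emb.cos().unsqueeze(1)                          # broadcast over heads
    sin = emb.sin().unsqueeze(1)

    def rotate_half(x):
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)

    q_embed = (q * cos + rotate_half(q) * sin).to(q.dtype)
    k_embed = (k * cos + rotate_half(k) * sin).to(q.dtype)
    return q_embed, k_embed
```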
@@ -226,7 +226,7 @@ def apply_rotary_pos_emb_cache_freq_xpu(q, k, sin, cos, model_family, position_ids
     k_embed = torch.empty(k.shape, dtype=k.dtype, device=k.device)
     if model_family in ["qwen", "mixtral"]:
         linear_q4_0.apply_rotary_embedding_half_q_and_k_cache_freq(q, k, sin, cos, q_embed, k_embed)
-    elif model_family in ["qwen2", "yuan"]:
+    elif model_family in ["qwen2", "yuan", "stablelm"]:
         cos = cos.to(q.dtype)
         sin = sin.to(q.dtype)
         cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
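
Adding "stablelm" to this branch routes StableLM through the path that reuses the caller-provided sin/cos: cast to the activation dtype, squeeze to [seq_len, dim], and hand the cached frequencies to the fused kernel. A plain-PyTorch stand-in for the rotate-half combine that kernel presumably performs; the position_ids gather is my assumption about how the cached tables are indexed:

```python
# Plain-PyTorch stand-in for the "qwen2"/"yuan"/"stablelm" branch above; the real
# code passes the normalized tensors to a fused linear_q4_0 XPU kernel instead.
import torch

def cache_freq_rope_reference(q, k, sin, cos, position_ids):
    cos = cos.to(q.dtype).squeeze(1).squeeze(0)    # [seq_len, dim]
    sin = sin.to(q.dtype).squeeze(1).squeeze(0)
    cos = cos[position_ids].unsqueeze(1)           # [bsz, 1, q_len, dim]
    sin = sin[position_ids].unsqueeze(1)

    def rotate_half(x):
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)

    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
```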