fix stablelm logits diff (#10636)
* fix logits diff * Small fixes --------- Co-authored-by: Yuwen Hu <yuwen.hu@intel.com>
This commit is contained in:
parent
97c626d76f
commit
3a9ab8f1ae
3 changed files with 12 additions and 8 deletions
|
|
@ -633,6 +633,7 @@ def _optimize_pre(model):
|
|||
del module.c_attn
|
||||
model.apply(split_qkv_proj_func)
|
||||
if model.config.model_type == "stablelm":
|
||||
# For stablelm-zephyr-3b
|
||||
from ipex_llm.transformers.models.stablelm import merge_qkv
|
||||
model.apply(merge_qkv)
|
||||
|
||||
|
|
@ -1341,6 +1342,7 @@ def _optimize_post(model, lightweight_bmm=False):
|
|||
module.BertEncoder,
|
||||
encoder_forward)
|
||||
elif model.config.model_type == 'stablelm':
|
||||
# For stablelm-zephyr-3b
|
||||
modeling_module_name = model.__class__.__module__
|
||||
module = importlib.import_module(modeling_module_name)
|
||||
from ipex_llm.transformers.models.stablelm import stablelm_attention_forward
|
||||
|
|
|
|||
|
|
@ -48,7 +48,7 @@ from transformers.models.stablelm.modeling_stablelm import StableLmAttention
|
|||
from ipex_llm.utils.common import invalidInputError
|
||||
from ipex_llm.transformers.models.utils import extend_kv_cache, append_kv_cache
|
||||
from ipex_llm.transformers.models.utils import apply_rotary_pos_emb, \
|
||||
apply_rotary_pos_emb_no_cache_xpu
|
||||
apply_rotary_pos_emb_cache_freq_xpu
|
||||
from ipex_llm.transformers.models.utils import is_enough_kv_cache_room_4_36
|
||||
from ipex_llm.transformers.models.utils import use_flash_attention, use_esimd_sdp
|
||||
from ipex_llm.transformers.models.mistral import should_use_fuse_rope, repeat_kv
|
||||
|
|
@ -124,13 +124,15 @@ def stablelm_attention_forward(
|
|||
key_states[..., : self.rotary_emb.dim],
|
||||
key_states[..., self.rotary_emb.dim:],
|
||||
)
|
||||
if use_fuse_rope:
|
||||
query_rot, key_rot = apply_rotary_pos_emb_no_cache_xpu(query_rot,
|
||||
key_rot,
|
||||
position_ids,
|
||||
"stablelm")
|
||||
else:
|
||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||
if use_fuse_rope:
|
||||
query_rot, key_rot = apply_rotary_pos_emb_cache_freq_xpu(query_rot,
|
||||
key_rot,
|
||||
sin,
|
||||
cos,
|
||||
"stablelm",
|
||||
position_ids)
|
||||
else:
|
||||
# [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor]
|
||||
query_rot, key_rot = apply_rotary_pos_emb(query_rot,
|
||||
key_rot,
|
||||
|
|
|
|||
|
|
@ -208,7 +208,7 @@ def apply_rotary_pos_emb_no_cache_xpu(q, k, position_ids, model_family, rope_the
|
|||
q_embed = torch.empty(q.shape, dtype=q.dtype, device=q.device)
|
||||
k_embed = torch.empty(k.shape, dtype=k.dtype, device=k.device)
|
||||
if model_family in ["llama", "baichuan", "internlm", "aquila", "gpt_neox", "mistral",
|
||||
"mixtral", "stablelm"]:
|
||||
"mixtral"]:
|
||||
linear_q4_0.apply_rotary_embedding_half_q_and_k(q, k, position_ids,
|
||||
q_embed, k_embed, rope_theta)
|
||||
return q_embed, k_embed
|
||||
|
|
@ -226,7 +226,7 @@ def apply_rotary_pos_emb_cache_freq_xpu(q, k, sin, cos, model_family, position_i
|
|||
k_embed = torch.empty(k.shape, dtype=k.dtype, device=k.device)
|
||||
if model_family in ["qwen", "mixtral"]:
|
||||
linear_q4_0.apply_rotary_embedding_half_q_and_k_cache_freq(q, k, sin, cos, q_embed, k_embed)
|
||||
elif model_family in ["qwen2", "yuan"]:
|
||||
elif model_family in ["qwen2", "yuan", "stablelm"]:
|
||||
cos = cos.to(q.dtype)
|
||||
sin = sin.to(q.dtype)
|
||||
cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
|
||||
|
|
|
|||
Loading…
Reference in a new issue