fix stablelm logits diff (#10636)

* fix logits diff
* Small fixes

Co-authored-by: Yuwen Hu <yuwen.hu@intel.com>
parent 97c626d76f, commit 3a9ab8f1ae
3 changed files with 12 additions and 8 deletions
@@ -633,6 +633,7 @@ def _optimize_pre(model):
                 del module.c_attn
         model.apply(split_qkv_proj_func)
     if model.config.model_type == "stablelm":
+        # For stablelm-zephyr-3b
         from ipex_llm.transformers.models.stablelm import merge_qkv
         model.apply(merge_qkv)
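For context, the `merge_qkv` applied above follows the common pattern of fusing an attention module's separate q/k/v projections into a single linear layer so one matmul produces all three. The sketch below is illustrative only, assuming the standard `StableLmAttention` layout from transformers; the fused attribute name `qkv_proj` is an assumption, not necessarily what ipex_llm uses.

import torch
from torch import nn

def merge_qkv_sketch(module: nn.Module):
    # Illustrative sketch: fuse separate q/k/v projections into one linear layer.
    # The class-name check and the `qkv_proj` attribute name are assumptions.
    if module.__class__.__name__ == "StableLmAttention":
        q, k, v = module.q_proj, module.k_proj, module.v_proj
        qkv = nn.Linear(q.in_features,
                        q.out_features + k.out_features + v.out_features,
                        bias=q.bias is not None)
        qkv.weight.data = torch.cat([q.weight.data, k.weight.data, v.weight.data], dim=0)
        if q.bias is not None:
            qkv.bias.data = torch.cat([q.bias.data, k.bias.data, v.bias.data], dim=0)
        module.qkv_proj = qkv
        del module.q_proj, module.k_proj, module.v_proj

# Applied recursively, mirroring the hunk above: model.apply(merge_qkv_sketch)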
@@ -1341,6 +1342,7 @@ def _optimize_post(model, lightweight_bmm=False):
                             module.BertEncoder,
                             encoder_forward)
     elif model.config.model_type == 'stablelm':
+        # For stablelm-zephyr-3b
         modeling_module_name = model.__class__.__module__
         module = importlib.import_module(modeling_module_name)
         from ipex_llm.transformers.models.stablelm import stablelm_attention_forward
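The `_optimize_post` branch above imports the module that defines the loaded model's classes and swaps in `stablelm_attention_forward`. A minimal sketch of that forward-replacement pattern is below; `replace_forward` is an illustrative helper, not the actual ipex_llm API.

import importlib

def replace_forward(model, class_name, new_forward):
    # Find the modeling module the model class was loaded from and
    # rebind `forward` on the named attention class (illustrative helper).
    modeling_module_name = model.__class__.__module__
    module = importlib.import_module(modeling_module_name)
    getattr(module, class_name).forward = new_forward

# e.g. replace_forward(model, "StableLmAttention", stablelm_attention_forward)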
@@ -48,7 +48,7 @@ from transformers.models.stablelm.modeling_stablelm import StableLmAttention
 from ipex_llm.utils.common import invalidInputError
 from ipex_llm.transformers.models.utils import extend_kv_cache, append_kv_cache
 from ipex_llm.transformers.models.utils import apply_rotary_pos_emb, \
-    apply_rotary_pos_emb_no_cache_xpu
+    apply_rotary_pos_emb_cache_freq_xpu
 from ipex_llm.transformers.models.utils import is_enough_kv_cache_room_4_36
 from ipex_llm.transformers.models.utils import use_flash_attention, use_esimd_sdp
 from ipex_llm.transformers.models.mistral import should_use_fuse_rope, repeat_kv
@@ -124,13 +124,15 @@ def stablelm_attention_forward(
             key_states[..., : self.rotary_emb.dim],
             key_states[..., self.rotary_emb.dim:],
         )
+        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
         if use_fuse_rope:
-            query_rot, key_rot = apply_rotary_pos_emb_no_cache_xpu(query_rot,
-                                                                   key_rot,
-                                                                   position_ids,
-                                                                   "stablelm")
+            query_rot, key_rot = apply_rotary_pos_emb_cache_freq_xpu(query_rot,
+                                                                     key_rot,
+                                                                     sin,
+                                                                     cos,
+                                                                     "stablelm",
+                                                                     position_ids)
         else:
-            cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
             # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor]
             query_rot, key_rot = apply_rotary_pos_emb(query_rot,
                                                       key_rot,
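Conceptually, the change above stops recomputing rotary frequencies inside the XPU kernel and instead feeds it the cos/sin tables produced by the model's own `rotary_emb`, presumably so the rotary math on the partial-rotary slice of q/k matches the reference model exactly (the source of the logits difference). A minimal rotate-half reference sketch of that computation is below, assuming cos/sin of shape [seq_len, rotary_dim] indexed by position_ids; it is not the fused kernel itself.

import torch

def rotate_half(x):
    # Split the last dimension in half and rotate: (x1, x2) -> (-x2, x1).
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def rotary_reference(query_rot, key_rot, cos, sin, position_ids):
    # cos/sin: [seq_len, rotary_dim]; index by position and broadcast over heads.
    cos = cos[position_ids].unsqueeze(1)  # [bsz, 1, q_len, rotary_dim]
    sin = sin[position_ids].unsqueeze(1)
    q_embed = (query_rot * cos) + (rotate_half(query_rot) * sin)
    k_embed = (key_rot * cos) + (rotate_half(key_rot) * sin)
    return q_embed, k_embed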
@@ -208,7 +208,7 @@ def apply_rotary_pos_emb_no_cache_xpu(q, k, position_ids, model_family, rope_the
     q_embed = torch.empty(q.shape, dtype=q.dtype, device=q.device)
     k_embed = torch.empty(k.shape, dtype=k.dtype, device=k.device)
     if model_family in ["llama", "baichuan", "internlm", "aquila", "gpt_neox", "mistral",
-                        "mixtral", "stablelm"]:
+                        "mixtral"]:
         linear_q4_0.apply_rotary_embedding_half_q_and_k(q, k, position_ids,
                                                         q_embed, k_embed, rope_theta)
         return q_embed, k_embed
@@ -226,7 +226,7 @@ def apply_rotary_pos_emb_cache_freq_xpu(q, k, sin, cos, model_family, position_i
     k_embed = torch.empty(k.shape, dtype=k.dtype, device=k.device)
     if model_family in ["qwen", "mixtral"]:
         linear_q4_0.apply_rotary_embedding_half_q_and_k_cache_freq(q, k, sin, cos, q_embed, k_embed)
-    elif model_family in ["qwen2", "yuan"]:
+    elif model_family in ["qwen2", "yuan", "stablelm"]:
         cos = cos.to(q.dtype)
         sin = sin.to(q.dtype)
         cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
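Since the point of the commit is logits parity with the reference implementation, one hedged way to sanity-check the fix is to compare the optimized model's logits against the stock transformers model on the same prompt. The checkpoint id, loading options, and comparison below are illustrative, not taken from the commit; on an Intel GPU the optimized model would be moved to 'xpu' to exercise the fused-rope path.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "stabilityai/stablelm-zephyr-3b"  # illustrative checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
inputs = tokenizer("What is AI?", return_tensors="pt")

# Reference logits from the stock transformers model.
ref = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float32)
with torch.no_grad():
    ref_logits = ref(**inputs).logits

# Optimized model via ipex_llm (low-bit quantization changes numerics,
# so only a loose comparison is meaningful here).
from ipex_llm.transformers import AutoModelForCausalLM as OptAutoModel
opt = OptAutoModel.from_pretrained(model_id, load_in_4bit=True)
with torch.no_grad():
    opt_logits = opt(**inputs).logits

print("max abs logits diff:", (ref_logits - opt_logits).abs().max().item())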