From 3a9ab8f1ae6f9908ff741a9af7333998f89fd5b7 Mon Sep 17 00:00:00 2001
From: Xin Qiu
Date: Wed, 3 Apr 2024 15:08:12 +0800
Subject: [PATCH] fix stablelm logits diff (#10636)

* fix logits diff

* Small fixes

---------

Co-authored-by: Yuwen Hu
---
 python/llm/src/ipex_llm/transformers/convert.py   |  2 ++
 .../src/ipex_llm/transformers/models/stablelm.py  | 14 ++++++++------
 .../llm/src/ipex_llm/transformers/models/utils.py |  4 ++--
 3 files changed, 12 insertions(+), 8 deletions(-)

diff --git a/python/llm/src/ipex_llm/transformers/convert.py b/python/llm/src/ipex_llm/transformers/convert.py
index e377c874..117f29ce 100644
--- a/python/llm/src/ipex_llm/transformers/convert.py
+++ b/python/llm/src/ipex_llm/transformers/convert.py
@@ -633,6 +633,7 @@ def _optimize_pre(model):
                 del module.c_attn
         model.apply(split_qkv_proj_func)
     if model.config.model_type == "stablelm":
+        # For stablelm-zephyr-3b
         from ipex_llm.transformers.models.stablelm import merge_qkv
         model.apply(merge_qkv)
 
@@ -1341,6 +1342,7 @@ def _optimize_post(model, lightweight_bmm=False):
                         module.BertEncoder,
                         encoder_forward)
     elif model.config.model_type == 'stablelm':
+        # For stablelm-zephyr-3b
         modeling_module_name = model.__class__.__module__
         module = importlib.import_module(modeling_module_name)
         from ipex_llm.transformers.models.stablelm import stablelm_attention_forward
diff --git a/python/llm/src/ipex_llm/transformers/models/stablelm.py b/python/llm/src/ipex_llm/transformers/models/stablelm.py
index b311c740..3e96e4bf 100644
--- a/python/llm/src/ipex_llm/transformers/models/stablelm.py
+++ b/python/llm/src/ipex_llm/transformers/models/stablelm.py
@@ -48,7 +48,7 @@ from transformers.models.stablelm.modeling_stablelm import StableLmAttention
 from ipex_llm.utils.common import invalidInputError
 from ipex_llm.transformers.models.utils import extend_kv_cache, append_kv_cache
 from ipex_llm.transformers.models.utils import apply_rotary_pos_emb, \
-    apply_rotary_pos_emb_no_cache_xpu
+    apply_rotary_pos_emb_cache_freq_xpu
 from ipex_llm.transformers.models.utils import is_enough_kv_cache_room_4_36
 from ipex_llm.transformers.models.utils import use_flash_attention, use_esimd_sdp
 from ipex_llm.transformers.models.mistral import should_use_fuse_rope, repeat_kv
@@ -124,13 +124,15 @@ def stablelm_attention_forward(
         key_states[..., : self.rotary_emb.dim],
         key_states[..., self.rotary_emb.dim:],
     )
+    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
     if use_fuse_rope:
-        query_rot, key_rot = apply_rotary_pos_emb_no_cache_xpu(query_rot,
-                                                               key_rot,
-                                                               position_ids,
-                                                               "stablelm")
+        query_rot, key_rot = apply_rotary_pos_emb_cache_freq_xpu(query_rot,
+                                                                 key_rot,
+                                                                 sin,
+                                                                 cos,
+                                                                 "stablelm",
+                                                                 position_ids)
     else:
-        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
         # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor]
         query_rot, key_rot = apply_rotary_pos_emb(query_rot,
                                                   key_rot,
diff --git a/python/llm/src/ipex_llm/transformers/models/utils.py b/python/llm/src/ipex_llm/transformers/models/utils.py
index 405fdfd0..24168693 100644
--- a/python/llm/src/ipex_llm/transformers/models/utils.py
+++ b/python/llm/src/ipex_llm/transformers/models/utils.py
@@ -208,7 +208,7 @@ def apply_rotary_pos_emb_no_cache_xpu(q, k, position_ids, model_family, rope_the
     q_embed = torch.empty(q.shape, dtype=q.dtype, device=q.device)
     k_embed = torch.empty(k.shape, dtype=k.dtype, device=k.device)
     if model_family in ["llama", "baichuan", "internlm", "aquila", "gpt_neox", "mistral",
-                        "mixtral", "stablelm"]:
+                        "mixtral"]:
         linear_q4_0.apply_rotary_embedding_half_q_and_k(q, k, position_ids,
                                                         q_embed, k_embed, rope_theta)
         return q_embed, k_embed
@@ -226,7 +226,7 @@ def apply_rotary_pos_emb_cache_freq_xpu(q, k, sin, cos, model_family, position_i
     k_embed = torch.empty(k.shape, dtype=k.dtype, device=k.device)
     if model_family in ["qwen", "mixtral"]:
         linear_q4_0.apply_rotary_embedding_half_q_and_k_cache_freq(q, k, sin, cos, q_embed, k_embed)
-    elif model_family in ["qwen2", "yuan"]:
+    elif model_family in ["qwen2", "yuan", "stablelm"]:
         cos = cos.to(q.dtype)
         sin = sin.to(q.dtype)
         cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
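
Reviewer note on the change: the patch moves StableLM's fused-rope path from apply_rotary_pos_emb_no_cache_xpu, which derives the rotation on the fly from position_ids and rope_theta, to apply_rotary_pos_emb_cache_freq_xpu, which reuses the cos/sin produced by the model's own rotary cache, so the fused XPU branch now sees the same frequencies as the unfused else branch. The sketch below is a minimal, self-contained illustration of that reference behaviour under standard rotary-embedding conventions; the helper names (rotate_half, apply_rotary_with_cached_freqs) are hypothetical and are not ipex-llm APIs.

# Illustrative sketch only: mirrors the unfused reference path that the fused
# XPU call is expected to match. Helper names are hypothetical, not ipex-llm APIs.
import torch


def rotate_half(x):
    # Standard rotary helper: negate the second half and swap the two halves.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_with_cached_freqs(q_rot, k_rot, cos, sin, position_ids):
    # cos/sin: [seq_len, rot_dim], precomputed once by the rotary cache
    # (the `self.rotary_emb(value_states, seq_len=kv_seq_len)` call in the patch).
    cos = cos[position_ids].unsqueeze(1)  # [bsz, 1, q_len, rot_dim]
    sin = sin[position_ids].unsqueeze(1)
    q_embed = (q_rot * cos) + (rotate_half(q_rot) * sin)
    k_embed = (k_rot * cos) + (rotate_half(k_rot) * sin)
    return q_embed, k_embed


# Toy usage: StableLM applies rope only to the first rotary_emb.dim channels
# (partial rotary), so q_rot/k_rot stand for that slice of the q/k states.
bsz, n_heads, q_len, rot_dim = 1, 32, 8, 20
q_rot = torch.randn(bsz, n_heads, q_len, rot_dim)
k_rot = torch.randn(bsz, n_heads, q_len, rot_dim)
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, rot_dim, 2).float() / rot_dim))
freqs = torch.outer(torch.arange(q_len).float(), inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)
cos, sin = emb.cos(), emb.sin()                  # [seq_len, rot_dim]
position_ids = torch.arange(q_len).unsqueeze(0)  # [bsz, q_len]
q_embed, k_embed = apply_rotary_with_cached_freqs(q_rot, k_rot, cos, sin, position_ids)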