diff --git a/python/llm/src/ipex_llm/transformers/models/stablelm.py b/python/llm/src/ipex_llm/transformers/models/stablelm.py
index 441b49cf..9bef4c29 100644
--- a/python/llm/src/ipex_llm/transformers/models/stablelm.py
+++ b/python/llm/src/ipex_llm/transformers/models/stablelm.py
@@ -131,6 +131,10 @@ def stablelm_attention_forward(
     query_states, key_states, value_states = qkv.split([self.num_heads,
                                                          self.num_key_value_heads,
                                                          self.num_key_value_heads], dim=1)
+    # For stablelm-2-12b's qk per-head norm
+    if getattr(self, "qk_layernorm", False):
+        query_states = self.q_layernorm(query_states)
+        key_states = self.k_layernorm(key_states)
 
     kv_seq_len = key_states.shape[-2]
     if past_key_value is not None:
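
For context on the `qk_layernorm` branch: when the flag is set (as in stablelm-2-12b), the query and key states are normalized per attention head right after the qkv split. The sketch below illustrates what such a per-head norm amounts to; the `LayerNormPerHead` class, shapes, and hyperparameters here are illustrative assumptions, not the actual modules behind `self.q_layernorm` / `self.k_layernorm`.

```python
import torch
import torch.nn as nn


class LayerNormPerHead(nn.Module):
    # Hypothetical stand-in for a per-head q/k norm: one LayerNorm over
    # head_dim per attention head, applied to states shaped
    # (batch, num_heads, seq_len, head_dim).
    def __init__(self, head_dim: int, num_heads: int, eps: float = 1e-5):
        super().__init__()
        self.norms = nn.ModuleList(nn.LayerNorm(head_dim, eps=eps) for _ in range(num_heads))

    def forward(self, states: torch.Tensor) -> torch.Tensor:
        # Split into per-head slices, normalize each slice with its own
        # LayerNorm, then concatenate back along the head dimension.
        per_head = torch.split(states, 1, dim=1)
        return torch.cat([norm(s) for norm, s in zip(self.norms, per_head)], dim=1)


# Usage sketch with made-up shapes
batch, num_heads, seq_len, head_dim = 1, 32, 8, 128
q_layernorm = LayerNormPerHead(head_dim, num_heads)
query_states = torch.randn(batch, num_heads, seq_len, head_dim)
query_states = q_layernorm(query_states)  # same shape, normalized per head
```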