From 42fab480eaa4cc1c4ad87077a3451860a37c363c Mon Sep 17 00:00:00 2001
From: Yishuo Wang
Date: Fri, 7 Jun 2024 15:46:00 +0800
Subject: [PATCH] support stablelm2 12b (#11265)

---
 python/llm/src/ipex_llm/transformers/models/stablelm.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/python/llm/src/ipex_llm/transformers/models/stablelm.py b/python/llm/src/ipex_llm/transformers/models/stablelm.py
index 441b49cf..9bef4c29 100644
--- a/python/llm/src/ipex_llm/transformers/models/stablelm.py
+++ b/python/llm/src/ipex_llm/transformers/models/stablelm.py
@@ -131,6 +131,10 @@ def stablelm_attention_forward(
     query_states, key_states, value_states = qkv.split([self.num_heads,
                                                         self.num_key_value_heads,
                                                         self.num_key_value_heads], dim=1)
+    # For stablelm-2-12b's qk per-head norm
+    if getattr(self, "qk_layernorm", False):
+        query_states = self.q_layernorm(query_states)
+        key_states = self.k_layernorm(key_states)
 
     kv_seq_len = key_states.shape[-2]
     if past_key_value is not None:
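
Note for context: qk_layernorm is the config flag that StableLM-2-12B enables to normalize the
query/key heads right after the qkv split, before rotary embeddings are applied, which is why the
guarded branch is inserted at this exact point. The sketch below is plain PyTorch, not ipex-llm
code; all names and sizes are illustrative. For brevity it uses one shared LayerNorm for the query
heads and one for the key heads, whereas the upstream HF StableLM implementation keeps a separate
per-head norm; the tensor layout and the effect on shapes are the same.

    import torch
    import torch.nn as nn

    # Illustrative sizes, not taken from any real config.
    bsz, num_heads, num_kv_heads, seq_len, head_dim = 1, 32, 8, 16, 128

    # LayerNorm over head_dim: applied to a 4-D tensor it normalizes the
    # trailing dimension, i.e. each head's head_dim vector independently.
    q_layernorm = nn.LayerNorm(head_dim)
    k_layernorm = nn.LayerNorm(head_dim)

    # Same (batch, heads, seq_len, head_dim) layout the qkv.split produces.
    query_states = torch.randn(bsz, num_heads, seq_len, head_dim)
    key_states = torch.randn(bsz, num_kv_heads, seq_len, head_dim)

    # Mirrors the guarded branch added by this patch.
    qk_layernorm = True
    if qk_layernorm:
        query_states = q_layernorm(query_states)
        key_states = k_layernorm(key_states)

    print(query_states.shape, key_states.shape)  # shapes are unchanged

Because the norm only rescales each head's head_dim vector, the rest of the attention path
(RoPE, KV cache, softmax) is untouched, which keeps this a four-line change.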