support stablm2 12b (#11265)
This commit is contained in:
parent
dbc3c2d72d
commit
42fab480ea
1 changed files with 4 additions and 0 deletions
|
|
@ -131,6 +131,10 @@ def stablelm_attention_forward(
|
|||
query_states, key_states, value_states = qkv.split([self.num_heads,
|
||||
self.num_key_value_heads,
|
||||
self.num_key_value_heads], dim=1)
|
||||
# For stablelm-2-12b's qk per-head norm
|
||||
if getattr(self, "qk_layernorm", False):
|
||||
query_states = self.q_layernorm(query_states)
|
||||
key_states = self.k_layernorm(key_states)
|
||||
|
||||
kv_seq_len = key_states.shape[-2]
|
||||
if past_key_value is not None:
|
||||
|
|
|
|||
Loading…
Reference in a new issue