use new fuse rope in stablelm family (#11497)
This commit is contained in:
parent
18c973dc3e
commit
d97c2664ce
1 changed files with 10 additions and 20 deletions
|
|
@ -141,34 +141,24 @@ def stablelm_attention_forward(
|
|||
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
|
||||
|
||||
# Partial rotary embedding
|
||||
query_rot, query_pass = (
|
||||
query_states[..., : self.rotary_emb.dim],
|
||||
query_states[..., self.rotary_emb.dim:],
|
||||
)
|
||||
key_rot, key_pass = (
|
||||
key_states[..., : self.rotary_emb.dim],
|
||||
key_states[..., self.rotary_emb.dim:],
|
||||
)
|
||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||
# [batch_size, num_heads, seq_length, head_dim // config.partial_rotary_factor]
|
||||
# [batch_size, num_heads, seq_length, head_dim * config.partial_rotary_factor]
|
||||
rot_dim = self.rotary_emb.dim
|
||||
if should_use_fuse_rope(hidden_states, position_ids, self.training):
|
||||
query_rot, key_rot = apply_rotary_pos_emb_cache_freq_xpu(query_rot,
|
||||
key_rot,
|
||||
sin,
|
||||
cos,
|
||||
"stablelm",
|
||||
position_ids)
|
||||
import xe_addons
|
||||
xe_addons.rotary_half_inplaced(self.rotary_emb.inv_freq, position_ids,
|
||||
query_states[..., :rot_dim], key_states[..., :rot_dim])
|
||||
else:
|
||||
query_rot, query_pass = query_states[..., :rot_dim], query_states[..., rot_dim:]
|
||||
key_rot, key_pass = key_states[..., :rot_dim], key_states[..., rot_dim:]
|
||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||
query_rot, key_rot = apply_rotary_pos_emb(query_rot,
|
||||
key_rot,
|
||||
cos,
|
||||
sin,
|
||||
position_ids,
|
||||
"stablelm")
|
||||
|
||||
# [batch_size, num_heads, seq_length, head_dim]
|
||||
query_states = torch.cat((query_rot, query_pass), dim=-1)
|
||||
key_states = torch.cat((key_rot, key_pass), dim=-1)
|
||||
query_states = torch.cat((query_rot, query_pass), dim=-1)
|
||||
key_states = torch.cat((key_rot, key_pass), dim=-1)
|
||||
|
||||
key_states, value_states = past_key_value.update(key_states, value_states,
|
||||
self.layer_idx, None)
|
||||
|
|
|
|||
Loading…
Reference in a new issue