use new fuse rope in stablelm family (#11497)
This commit is contained in:
parent
18c973dc3e
commit
d97c2664ce
1 changed files with 10 additions and 20 deletions
|
|
@ -141,32 +141,22 @@ def stablelm_attention_forward(
|
||||||
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
|
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
|
||||||
|
|
||||||
# Partial rotary embedding
|
# Partial rotary embedding
|
||||||
query_rot, query_pass = (
|
# [batch_size, num_heads, seq_length, head_dim * config.partial_rotary_factor]
|
||||||
query_states[..., : self.rotary_emb.dim],
|
rot_dim = self.rotary_emb.dim
|
||||||
query_states[..., self.rotary_emb.dim:],
|
|
||||||
)
|
|
||||||
key_rot, key_pass = (
|
|
||||||
key_states[..., : self.rotary_emb.dim],
|
|
||||||
key_states[..., self.rotary_emb.dim:],
|
|
||||||
)
|
|
||||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
|
||||||
# [batch_size, num_heads, seq_length, head_dim // config.partial_rotary_factor]
|
|
||||||
if should_use_fuse_rope(hidden_states, position_ids, self.training):
|
if should_use_fuse_rope(hidden_states, position_ids, self.training):
|
||||||
query_rot, key_rot = apply_rotary_pos_emb_cache_freq_xpu(query_rot,
|
import xe_addons
|
||||||
key_rot,
|
xe_addons.rotary_half_inplaced(self.rotary_emb.inv_freq, position_ids,
|
||||||
sin,
|
query_states[..., :rot_dim], key_states[..., :rot_dim])
|
||||||
cos,
|
|
||||||
"stablelm",
|
|
||||||
position_ids)
|
|
||||||
else:
|
else:
|
||||||
|
query_rot, query_pass = query_states[..., :rot_dim], query_states[..., rot_dim:]
|
||||||
|
key_rot, key_pass = key_states[..., :rot_dim], key_states[..., rot_dim:]
|
||||||
|
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||||
query_rot, key_rot = apply_rotary_pos_emb(query_rot,
|
query_rot, key_rot = apply_rotary_pos_emb(query_rot,
|
||||||
key_rot,
|
key_rot,
|
||||||
cos,
|
cos,
|
||||||
sin,
|
sin,
|
||||||
position_ids,
|
position_ids,
|
||||||
"stablelm")
|
"stablelm")
|
||||||
|
|
||||||
# [batch_size, num_heads, seq_length, head_dim]
|
|
||||||
query_states = torch.cat((query_rot, query_pass), dim=-1)
|
query_states = torch.cat((query_rot, query_pass), dim=-1)
|
||||||
key_states = torch.cat((key_rot, key_pass), dim=-1)
|
key_states = torch.cat((key_rot, key_pass), dim=-1)
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue