use new fuse rope in stablelm family (#11497)

This commit is contained in:
Yishuo Wang 2024-07-03 11:08:26 +08:00 committed by GitHub
parent 18c973dc3e
commit d97c2664ce
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -141,34 +141,24 @@ def stablelm_attention_forward(
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
# Partial rotary embedding
query_rot, query_pass = (
query_states[..., : self.rotary_emb.dim],
query_states[..., self.rotary_emb.dim:],
)
key_rot, key_pass = (
key_states[..., : self.rotary_emb.dim],
key_states[..., self.rotary_emb.dim:],
)
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
# [batch_size, num_heads, seq_length, head_dim // config.partial_rotary_factor]
# [batch_size, num_heads, seq_length, head_dim * config.partial_rotary_factor]
rot_dim = self.rotary_emb.dim
if should_use_fuse_rope(hidden_states, position_ids, self.training):
query_rot, key_rot = apply_rotary_pos_emb_cache_freq_xpu(query_rot,
key_rot,
sin,
cos,
"stablelm",
position_ids)
import xe_addons
xe_addons.rotary_half_inplaced(self.rotary_emb.inv_freq, position_ids,
query_states[..., :rot_dim], key_states[..., :rot_dim])
else:
query_rot, query_pass = query_states[..., :rot_dim], query_states[..., rot_dim:]
key_rot, key_pass = key_states[..., :rot_dim], key_states[..., rot_dim:]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_rot, key_rot = apply_rotary_pos_emb(query_rot,
key_rot,
cos,
sin,
position_ids,
"stablelm")
# [batch_size, num_heads, seq_length, head_dim]
query_states = torch.cat((query_rot, query_pass), dim=-1)
key_states = torch.cat((key_rot, key_pass), dim=-1)
query_states = torch.cat((query_rot, query_pass), dim=-1)
key_states = torch.cat((key_rot, key_pass), dim=-1)
key_states, value_states = past_key_value.update(key_states, value_states,
self.layer_idx, None)