Support for MPT rotary embedding (#10208)

Author: Heyang Sun
Date:   2024-02-22 15:16:31 +08:00 (committed by GitHub)
Parent: 5e1fee5e05
Commit: c876d9b5ca

@@ -30,7 +30,9 @@ KV_CACHE_ALLOC_BLOCK_LENGTH = 256


 def mpt_multihead_attention_forward(self, x, past_key_value=None, attn_bias=None,
-                                    attention_mask=None, is_causal=True, needs_weights=False):
+                                    attention_mask=None, is_causal=True,
+                                    needs_weights=False, rotary_emb_w_meta_info=None,
+                                    **kwargs):
     qkv = self.Wqkv(x)
     if self.clip_qkv:
         qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
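The new rotary_emb_w_meta_info keyword carries everything the rotary path needs. As a minimal sketch, assuming the caller assembles the dictionary itself, it could be built with a hypothetical helper like the one below for the 'hf' path; the key names ('impl', 'rotary_emb', 'seq_len', 'offset_info') come from the second hunk, while the helper name and the way positions are derived from the KV cache length are illustrative only and not part of this commit.

```python
import torch

# Hypothetical helper, not part of this commit: packs the fields that the
# patched forward reads out of rotary_emb_w_meta_info.
def build_rotary_emb_w_meta_info(rotary_emb, past_len, cur_len, impl='hf'):
    # For the 'hf' path, the offset is a tensor of position ids starting
    # right after the tokens already held in the KV cache.
    position_ids = torch.arange(past_len, past_len + cur_len).unsqueeze(0)
    return {
        'impl': impl,              # 'hf' or 'dail', dispatched on in the hunk below
        'rotary_emb': rotary_emb,  # module returning (cos, sin) when impl == 'hf'
        'seq_len': past_len + cur_len,
        'offset_info': position_ids,
    }
```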
@@ -40,6 +42,43 @@ def mpt_multihead_attention_forward(self, x, past_key_value=None, attn_bias=None
         dtype = query.dtype
         query = self.q_ln(query).to(dtype)
         key = self.k_ln(key).to(dtype)
+    if rotary_emb_w_meta_info is not None:
+        rotary_emb = rotary_emb_w_meta_info['rotary_emb']
+        seq_len = rotary_emb_w_meta_info['seq_len']
+        offset_info = rotary_emb_w_meta_info['offset_info']
+        bsz, seqlen = query.shape[:2]
+        query = query.view(bsz, seqlen, -1, self.head_dim)
+        key = key.view(bsz, seqlen, -1, self.head_dim)
+
+        if rotary_emb_w_meta_info['impl'] == 'dail':
+            value = value.view(bsz, seqlen, -1, self.head_dim)
+
+            kv = torch.stack([key, value], dim=2)
+            query, kv = rotary_emb(query,
+                                   kv,
+                                   seqlen_offset=offset_info,
+                                   max_seqlen=seq_len)
+            [key, value] = torch.unbind(kv, dim=2)
+            value = value.view(bsz, seqlen, self.kv_n_heads * self.head_dim)
+        elif rotary_emb_w_meta_info['impl'] == 'hf':
+            (cos, sin) = rotary_emb(value, seq_len)
+            if is_transformers_version_gte('4.36'):
+                query, key = apply_rotary_pos_emb(query,
+                                                  key,
+                                                  cos,
+                                                  sin,
+                                                  offset_info,
+                                                  unsqueeze_dim=2)
+            else:
+                query = query.transpose(1, 2)
+                key = key.transpose(1, 2)
+                query, key = apply_rotary_pos_emb(query, key, cos, sin,
+                                                  offset_info)
+                query = query.transpose(1, 2)
+                key = key.transpose(1, 2)
+
+        query = query.view(bsz, seqlen, self.d_model)
+        key = key.view(bsz, seqlen, self.kv_n_heads * self.head_dim)
     (context, attn_weights, past_key_value) = \
         mpt_scaled_multihead_dot_product_attention(query, key, value, self.n_heads,
                                                    past_key_value=past_key_value,
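For reference, the 'hf' branch delegates the actual rotation to apply_rotary_pos_emb. The sketch below shows the standard Hugging Face formulation of that call, assuming cos/sin of shape [seq_len, head_dim] and the [batch, seq, heads, head_dim] layout used above (hence unsqueeze_dim=2 on transformers >= 4.36); it is illustrative and not the exact helper imported by this file.

```python
import torch

def rotate_half(x):
    # Swap the two halves of the head dimension and negate the second half.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb_sketch(q, k, cos, sin, position_ids, unsqueeze_dim=2):
    # Select the cos/sin rows for the current positions and broadcast them
    # over the heads dimension, then rotate queries and keys elementwise.
    cos = cos[position_ids].unsqueeze(unsqueeze_dim)
    sin = sin[position_ids].unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
```

Before transformers 4.36, the helper expects a [batch, heads, seq, head_dim] layout instead, which is why the else branch in the hunk transposes query and key before and after the call.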