Support for MPT rotary embedding (#10208)
This commit is contained in:
parent
5e1fee5e05
commit
c876d9b5ca
1 changed files with 40 additions and 1 deletions
|
|
@ -30,7 +30,9 @@ KV_CACHE_ALLOC_BLOCK_LENGTH = 256
|
|||
|
||||
|
||||
def mpt_multihead_attention_forward(self, x, past_key_value=None, attn_bias=None,
|
||||
attention_mask=None, is_causal=True, needs_weights=False):
|
||||
attention_mask=None, is_causal=True,
|
||||
needs_weights=False, rotary_emb_w_meta_info=None,
|
||||
**kwargs):
|
||||
qkv = self.Wqkv(x)
|
||||
if self.clip_qkv:
|
||||
qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
|
||||
|
|
@ -40,6 +42,43 @@ def mpt_multihead_attention_forward(self, x, past_key_value=None, attn_bias=None
|
|||
dtype = query.dtype
|
||||
query = self.q_ln(query).to(dtype)
|
||||
key = self.k_ln(key).to(dtype)
|
||||
|
||||
if rotary_emb_w_meta_info is not None:
|
||||
rotary_emb = rotary_emb_w_meta_info['rotary_emb']
|
||||
seq_len = rotary_emb_w_meta_info['seq_len']
|
||||
offset_info = rotary_emb_w_meta_info['offset_info']
|
||||
bsz, seqlen = query.shape[:2]
|
||||
query = query.view(bsz, seqlen, -1, self.head_dim)
|
||||
key = key.view(bsz, seqlen, -1, self.head_dim)
|
||||
|
||||
if rotary_emb_w_meta_info['impl'] == 'dail':
|
||||
value = value.view(bsz, seqlen, -1, self.head_dim)
|
||||
|
||||
kv = torch.stack([key, value], dim=2)
|
||||
query, kv = rotary_emb(query,
|
||||
kv,
|
||||
seqlen_offset=offset_info,
|
||||
max_seqlen=seq_len)
|
||||
[key, value] = torch.unbind(kv, dim=2)
|
||||
|
||||
value = value.view(bsz, seqlen, self.kv_n_heads * self.head_dim)
|
||||
elif rotary_emb_w_meta_info['impl'] == 'hf':
|
||||
(cos, sin) = rotary_emb(value, seq_len)
|
||||
if is_transformers_version_gte('4.36'):
|
||||
query, key = apply_rotary_pos_emb(query,
|
||||
key,
|
||||
cos,
|
||||
sin,
|
||||
offset_info,
|
||||
unsqueeze_dim=2)
|
||||
else:
|
||||
query = query.transpose(1, 2)
|
||||
key = key.transpose(1, 2)
|
||||
query, key = apply_rotary_pos_emb(query, key, cos, sin,
|
||||
offset_info)
|
||||
query = query.transpose(1, 2)
|
||||
key = key.transpose(1, 2)
|
||||
|
||||
(context, attn_weights, past_key_value) = \
|
||||
mpt_scaled_multihead_dot_product_attention(query, key, value, self.n_heads,
|
||||
past_key_value=past_key_value,
|
||||
|
|
|
|||
Loading…
Reference in a new issue