Support for MPT rotary embedding (#10208)
parent 5e1fee5e05
commit c876d9b5ca
1 changed file with 40 additions and 1 deletion
@@ -30,7 +30,9 @@ KV_CACHE_ALLOC_BLOCK_LENGTH = 256
 
 def mpt_multihead_attention_forward(self, x, past_key_value=None, attn_bias=None,
-                                    attention_mask=None, is_causal=True, needs_weights=False):
+                                    attention_mask=None, is_causal=True,
+                                    needs_weights=False, rotary_emb_w_meta_info=None,
+                                    **kwargs):
     qkv = self.Wqkv(x)
     if self.clip_qkv:
         qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
 
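The second hunk below reads exactly four keys out of the new rotary_emb_w_meta_info argument. A minimal sketch of what a caller might pass, inferred from those reads; the rotary_emb module, seq_len, and position_ids names here are illustrative assumptions, not values this commit defines:

# Hypothetical caller-side dict; only the four keys below are read by the patch.
rotary_emb_w_meta_info = {
    'impl': 'hf',                 # selects the 'hf' or 'dail' branch below
    'rotary_emb': rotary_emb,     # 'hf': module returning (cos, sin);
                                  # 'dail': a DAIL/flash-attn rotary module
    'seq_len': seq_len,           # sequence length to build embeddings for
    'offset_info': position_ids,  # 'hf': position-ids tensor;
                                  # 'dail': an integer sequence-length offset
}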
@@ -40,6 +42,43 @@ def mpt_multihead_attention_forward(self, x, past_key_value=None, attn_bias=None
     dtype = query.dtype
     query = self.q_ln(query).to(dtype)
     key = self.k_ln(key).to(dtype)
+
+    if rotary_emb_w_meta_info is not None:
+        rotary_emb = rotary_emb_w_meta_info['rotary_emb']
+        seq_len = rotary_emb_w_meta_info['seq_len']
+        offset_info = rotary_emb_w_meta_info['offset_info']
+        bsz, seqlen = query.shape[:2]
+        query = query.view(bsz, seqlen, -1, self.head_dim)
+        key = key.view(bsz, seqlen, -1, self.head_dim)
+
+        if rotary_emb_w_meta_info['impl'] == 'dail':
+            value = value.view(bsz, seqlen, -1, self.head_dim)
+
+            kv = torch.stack([key, value], dim=2)
+            query, kv = rotary_emb(query,
+                                   kv,
+                                   seqlen_offset=offset_info,
+                                   max_seqlen=seq_len)
+            [key, value] = torch.unbind(kv, dim=2)
+
+            value = value.view(bsz, seqlen, self.kv_n_heads * self.head_dim)
+        elif rotary_emb_w_meta_info['impl'] == 'hf':
+            (cos, sin) = rotary_emb(value, seq_len)
+            if is_transformers_version_gte('4.36'):
+                query, key = apply_rotary_pos_emb(query,
+                                                  key,
+                                                  cos,
+                                                  sin,
+                                                  offset_info,
+                                                  unsqueeze_dim=2)
+            else:
+                query = query.transpose(1, 2)
+                key = key.transpose(1, 2)
+                query, key = apply_rotary_pos_emb(query, key, cos, sin,
+                                                  offset_info)
+                query = query.transpose(1, 2)
+                key = key.transpose(1, 2)
+
     (context, attn_weights, past_key_value) = \
         mpt_scaled_multihead_dot_product_attention(query, key, value, self.n_heads,
                                                    past_key_value=past_key_value,
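In the 'dail' branch above, key and value are packed into a single tensor around the rotary call, presumably because the DAIL (flash-attn) rotary module consumes a combined (bsz, seqlen, 2, n_heads, head_dim) tensor and rotates only the key slice in place. A shape-only sketch of that pack/unpack round trip, with no flash-attn dependency; the rotary call itself is elided:

# Shape-only illustration of the kv pack/unpack around the 'dail' rotary call.
import torch

bsz, seqlen, n_heads, head_dim = 1, 8, 4, 16
key = torch.randn(bsz, seqlen, n_heads, head_dim)
value = torch.randn(bsz, seqlen, n_heads, head_dim)

kv = torch.stack([key, value], dim=2)  # (bsz, seqlen, 2, n_heads, head_dim)
# ... query, kv = rotary_emb(query, kv, seqlen_offset=..., max_seqlen=...) ...
key, value = torch.unbind(kv, dim=2)   # each back to (bsz, seqlen, n_heads, head_dim)
assert key.shape == value.shape == (bsz, seqlen, n_heads, head_dim)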
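The 'hf' branch delegates to transformers' apply_rotary_pos_emb, which gained its unsqueeze_dim parameter in version 4.36; on the else path for older versions the code therefore transposes to (bsz, n_heads, seqlen, head_dim) and back. A self-contained PyTorch sketch of the rotate-half math both calls implement, following the Hugging Face convention; an illustration under stated assumptions, not the code this commit ships:

import torch

def rotate_half(x):
    # (x1, x2) -> (-x2, x1) on the last (head_dim) axis, HF convention.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_sketch(query, key, position_ids, base=10000.0):
    # query/key: (bsz, seqlen, n_heads, head_dim), the layout the patched
    # forward keeps when it can pass unsqueeze_dim=2.
    head_dim = query.shape[-1]
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
    angles = position_ids[..., None].float() * inv_freq  # (bsz, seqlen, head_dim/2)
    emb = torch.cat((angles, angles), dim=-1)
    cos = emb.cos()[:, :, None, :]  # broadcast over heads: what unsqueeze_dim=2 does
    sin = emb.sin()[:, :, None, :]
    q_rot = query * cos + rotate_half(query) * sin
    k_rot = key * cos + rotate_half(key) * sin
    return q_rot, k_rot

bsz, seqlen, n_heads, head_dim = 1, 8, 4, 16
q = torch.randn(bsz, seqlen, n_heads, head_dim)
k = torch.randn(bsz, seqlen, n_heads, head_dim)
pos = torch.arange(seqlen)[None, :].expand(bsz, -1)
q_rot, k_rot = apply_rotary_sketch(q, k, pos)
assert q_rot.shape == q.shape and k_rot.shape == k.shape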