Support for MPT rotary embedding (#10208)
parent 5e1fee5e05
commit c876d9b5ca

1 changed file with 40 additions and 1 deletion
@@ -30,7 +30,9 @@ KV_CACHE_ALLOC_BLOCK_LENGTH = 256
 
 
 def mpt_multihead_attention_forward(self, x, past_key_value=None, attn_bias=None,
-                                    attention_mask=None, is_causal=True, needs_weights=False):
+                                    attention_mask=None, is_causal=True,
+                                    needs_weights=False, rotary_emb_w_meta_info=None,
+                                    **kwargs):
     qkv = self.Wqkv(x)
     if self.clip_qkv:
         qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
@@ -40,6 +42,43 @@ def mpt_multihead_attention_forward(self, x, past_key_value=None, attn_bias=None
         dtype = query.dtype
         query = self.q_ln(query).to(dtype)
         key = self.k_ln(key).to(dtype)
+
+    if rotary_emb_w_meta_info is not None:
+            rotary_emb = rotary_emb_w_meta_info['rotary_emb']
+            seq_len = rotary_emb_w_meta_info['seq_len']
+            offset_info = rotary_emb_w_meta_info['offset_info']
+            bsz, seqlen = query.shape[:2]
+            query = query.view(bsz, seqlen, -1, self.head_dim)
+            key = key.view(bsz, seqlen, -1, self.head_dim)
+
+            if rotary_emb_w_meta_info['impl'] == 'dail':
+                value = value.view(bsz, seqlen, -1, self.head_dim)
+
+                kv = torch.stack([key, value], dim=2)
+                query, kv = rotary_emb(query,
+                                       kv,
+                                       seqlen_offset=offset_info,
+                                       max_seqlen=seq_len)
+                [key, value] = torch.unbind(kv, dim=2)
+
+                value = value.view(bsz, seqlen, self.kv_n_heads * self.head_dim)
+            elif rotary_emb_w_meta_info['impl'] == 'hf':
+                (cos, sin) = rotary_emb(value, seq_len)
+                if is_transformers_version_gte('4.36'):
+                    query, key = apply_rotary_pos_emb(query,
+                                                      key,
+                                                      cos,
+                                                      sin,
+                                                      offset_info,
+                                                      unsqueeze_dim=2)
+                else:
+                    query = query.transpose(1, 2)
+                    key = key.transpose(1, 2)
+                    query, key = apply_rotary_pos_emb(query, key, cos, sin,
+                                                      offset_info)
+                    query = query.transpose(1, 2)
+                    key = key.transpose(1, 2)
+
     (context, attn_weights, past_key_value) = \
         mpt_scaled_multihead_dot_product_attention(query, key, value, self.n_heads,
                                                    past_key_value=past_key_value,
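For context, the new rotary_emb_w_meta_info argument is consumed above as a plain dict with the keys 'impl' (either 'dail' or 'hf'), 'rotary_emb', 'seq_len', and 'offset_info'. The sketch below illustrates the 'hf' branch in isolation; ToyRotaryEmbedding and the toy apply_rotary_pos_emb helper are illustrative stand-ins (not the MPT modules or the transformers implementation), but the cos/sin construction and the unsqueeze_dim=2 usage mirror the Hugging Face-style rotary math applied to (bsz, seqlen, n_heads, head_dim) tensors in the hunk above.

# Illustrative sketch only: a toy HF-style rotary embedding applied to
# (bsz, seqlen, n_heads, head_dim) query/key tensors, mirroring the 'hf'
# branch added in this commit. Names below are stand-ins, not MPT's modules.
import torch


class ToyRotaryEmbedding(torch.nn.Module):
    def __init__(self, head_dim, base=10000):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
        self.register_buffer("inv_freq", inv_freq)

    def forward(self, x, seq_len):
        # Build cos/sin tables of shape (seq_len, head_dim), as HF-style rotary modules do.
        t = torch.arange(seq_len, dtype=self.inv_freq.dtype)
        freqs = torch.outer(t, self.inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        return emb.cos(), emb.sin()


def rotate_half(x):
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
    # Select the cos/sin rows for the given positions, then broadcast over the head dim.
    cos = cos[position_ids].unsqueeze(unsqueeze_dim)
    sin = sin[position_ids].unsqueeze(unsqueeze_dim)
    return q * cos + rotate_half(q) * sin, k * cos + rotate_half(k) * sin


bsz, seqlen, n_heads, head_dim = 1, 8, 4, 32
query = torch.randn(bsz, seqlen, n_heads, head_dim)
key = torch.randn(bsz, seqlen, n_heads, head_dim)

rotary_emb_w_meta_info = {
    "impl": "hf",
    "rotary_emb": ToyRotaryEmbedding(head_dim),
    "seq_len": seqlen,
    "offset_info": torch.arange(seqlen).unsqueeze(0),  # position ids, batch-first
}

rotary_emb = rotary_emb_w_meta_info["rotary_emb"]
cos, sin = rotary_emb(query, rotary_emb_w_meta_info["seq_len"])
# unsqueeze_dim=2 because heads sit on dim 2 of (bsz, seqlen, n_heads, head_dim)
query, key = apply_rotary_pos_emb(query, key, cos, sin,
                                  rotary_emb_w_meta_info["offset_info"],
                                  unsqueeze_dim=2)
print(query.shape, key.shape)  # torch.Size([1, 8, 4, 32]) for both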