use new fuse rope in stablelm family (#11497)
This commit is contained in:
		
							parent
							
								
									18c973dc3e
								
							
						
					
					
						commit
						d97c2664ce
					
				
					 1 changed files with 10 additions and 20 deletions
				
			
		| 
						 | 
				
			
			@ -141,34 +141,24 @@ def stablelm_attention_forward(
 | 
			
		|||
        kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
 | 
			
		||||
 | 
			
		||||
    # Partial rotary embedding
 | 
			
		||||
    query_rot, query_pass = (
 | 
			
		||||
        query_states[..., : self.rotary_emb.dim],
 | 
			
		||||
        query_states[..., self.rotary_emb.dim:],
 | 
			
		||||
    )
 | 
			
		||||
    key_rot, key_pass = (
 | 
			
		||||
        key_states[..., : self.rotary_emb.dim],
 | 
			
		||||
        key_states[..., self.rotary_emb.dim:],
 | 
			
		||||
    )
 | 
			
		||||
    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
 | 
			
		||||
    # [batch_size, num_heads, seq_length, head_dim // config.partial_rotary_factor]
 | 
			
		||||
    # [batch_size, num_heads, seq_length, head_dim * config.partial_rotary_factor]
 | 
			
		||||
    rot_dim = self.rotary_emb.dim
 | 
			
		||||
    if should_use_fuse_rope(hidden_states, position_ids, self.training):
 | 
			
		||||
        query_rot, key_rot = apply_rotary_pos_emb_cache_freq_xpu(query_rot,
 | 
			
		||||
                                                                 key_rot,
 | 
			
		||||
                                                                 sin,
 | 
			
		||||
                                                                 cos,
 | 
			
		||||
                                                                 "stablelm",
 | 
			
		||||
                                                                 position_ids)
 | 
			
		||||
        import xe_addons
 | 
			
		||||
        xe_addons.rotary_half_inplaced(self.rotary_emb.inv_freq, position_ids,
 | 
			
		||||
                                       query_states[..., :rot_dim], key_states[..., :rot_dim])
 | 
			
		||||
    else:
 | 
			
		||||
        query_rot, query_pass = query_states[..., :rot_dim], query_states[..., rot_dim:]
 | 
			
		||||
        key_rot, key_pass = key_states[..., :rot_dim], key_states[..., rot_dim:]
 | 
			
		||||
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
 | 
			
		||||
        query_rot, key_rot = apply_rotary_pos_emb(query_rot,
 | 
			
		||||
                                                  key_rot,
 | 
			
		||||
                                                  cos,
 | 
			
		||||
                                                  sin,
 | 
			
		||||
                                                  position_ids,
 | 
			
		||||
                                                  "stablelm")
 | 
			
		||||
 | 
			
		||||
    # [batch_size, num_heads, seq_length, head_dim]
 | 
			
		||||
    query_states = torch.cat((query_rot, query_pass), dim=-1)
 | 
			
		||||
    key_states = torch.cat((key_rot, key_pass), dim=-1)
 | 
			
		||||
        query_states = torch.cat((query_rot, query_pass), dim=-1)
 | 
			
		||||
        key_states = torch.cat((key_rot, key_pass), dim=-1)
 | 
			
		||||
 | 
			
		||||
    key_states, value_states = past_key_value.update(key_states, value_states,
 | 
			
		||||
                                                     self.layer_idx, None)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in a new issue