use sdp in baichuan2 13b (#11198)
This commit is contained in:
		
							parent
							
								
									9f8074c653
								
							
						
					
					
						commit
						6454655dcc
					
				
					 1 changed files with 13 additions and 12 deletions
				
			
		| 
						 | 
				
			
			@ -204,24 +204,25 @@ def baichuan_attention_forward_13b(
 | 
			
		|||
        else:
 | 
			
		||||
            attention_mask = attention_mask[:, None, -q_len:, :]
 | 
			
		||||
 | 
			
		||||
    if use_quantize_kv and q_len == 1:
 | 
			
		||||
    if use_sdp(q_len, kv_seq_len, self.head_dim, query_states):
 | 
			
		||||
        import xe_addons
 | 
			
		||||
        attn_weights = xe_addons.query_key_fp8_matmul(query_states, key_states)
 | 
			
		||||
        if use_quantize_kv:
 | 
			
		||||
            attn_output = xe_addons.sdp_fp8(query_states, key_states, value_states,
 | 
			
		||||
                                            attention_mask)
 | 
			
		||||
        else:
 | 
			
		||||
            attn_output = xe_addons.sdp(query_states, key_states, value_states,
 | 
			
		||||
                                        attention_mask)
 | 
			
		||||
        attn_weights = None
 | 
			
		||||
    else:
 | 
			
		||||
        if use_quantize_kv:
 | 
			
		||||
            key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
 | 
			
		||||
                                                            query_states.dtype)
 | 
			
		||||
        attn_weights = torch.matmul(query_states,
 | 
			
		||||
                                    key_states.transpose(2, 3))
 | 
			
		||||
    attn_weights = attn_weights / math.sqrt(self.head_dim)
 | 
			
		||||
    if attention_mask is not None:
 | 
			
		||||
        attn_weights = attn_weights + attention_mask
 | 
			
		||||
    attn_weights = attn_weights.to(query_states.dtype)
 | 
			
		||||
    attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
 | 
			
		||||
    if use_quantize_kv and q_len == 1:
 | 
			
		||||
        import xe_addons
 | 
			
		||||
        attn_output = xe_addons.attn_value_fp8_matmul(attn_weights, value_states)
 | 
			
		||||
    else:
 | 
			
		||||
                                    key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
 | 
			
		||||
        if attention_mask is not None:
 | 
			
		||||
            attn_weights = attn_weights + attention_mask
 | 
			
		||||
        attn_weights = attn_weights.to(query_states.dtype)
 | 
			
		||||
        attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
 | 
			
		||||
        attn_output = torch.matmul(attn_weights.to(dtype=value_states.dtype), value_states)
 | 
			
		||||
    attn_output = attn_output.transpose(1, 2)
 | 
			
		||||
    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in a new issue