Fix several models based on sdp api change (#13075)
* fix baichuan based on sdp api change
* fix several models based on api change
* fix style
parent 7826152f5a
commit e08c6bd018

3 changed files with 11 additions and 4 deletions
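The xe_addons SDP kernels now expect the attention scale as an explicit trailing argument, so each caller computes scale = 1 / math.sqrt(head_dim) and appends it to the sdp / sdp_fp8 call. A minimal sketch of the updated calling convention, assuming query_states, key_states, value_states, attention_mask and use_quantize_kv are already prepared as in the attention forwards below (xe_addons is the XPU kernel extension used by these forwards and only imports on a supported GPU build):

import math
import xe_addons  # XPU kernel extension; only importable on a supported GPU build

# compute the standard attention scale from the head dimension
head_dim = query_states.shape[-1]
scale = 1 / math.sqrt(head_dim)

if use_quantize_kv:
    # FP8-quantized KV cache path
    attn_output = xe_addons.sdp_fp8(query_states, key_states, value_states,
                                    attention_mask, scale)
else:
    attn_output = xe_addons.sdp(query_states, key_states, value_states,
                                attention_mask, scale)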
				
			
@@ -326,14 +326,17 @@ def baichuan_attention_forward_13b(
         else:
             attention_mask = attention_mask[None, :, -q_len:, :]
 
+    head_dim = query_states.shape[-1]
+    scale = 1 / math.sqrt(head_dim)
+
     if use_sdp(q_len, kv_seq_len, self.head_dim, query_states):
         import xe_addons
         if use_quantize_kv:
             attn_output = xe_addons.sdp_fp8(query_states, key_states, value_states,
-                                            attention_mask)
+                                            attention_mask, scale)
         else:
             attn_output = xe_addons.sdp(query_states, key_states, value_states,
-                                        attention_mask)
+                                        attention_mask, scale)
         attn_weights = None
     else:
         if use_quantize_kv:
@@ -68,7 +68,9 @@ def glm_sdpa(query, key, value, attention_mask=None, is_causal=False):
         if use_sdp(query.shape[2], key.shape[2],
                    query.shape[-1], query):
             import xe_addons
-            attn_output = xe_addons.sdp(query, key, value, attn_bias)
+            head_dim = query.shape[-1]
+            scale = 1 / math.sqrt(head_dim)
+            attn_output = xe_addons.sdp(query, key, value, attn_bias, scale)
             context_layer = attn_output.view(query.shape)
         else:
             head_dim = query.size(-1)
@@ -164,7 +164,9 @@ def qwen_attention_forward_vl(
     if not self.training and not hidden_states.requires_grad and \
             use_sdp(q_len, key.shape[2], self.head_dim, query):
         import xe_addons
-        attn_output = xe_addons.sdp(query, key, value, attention_mask)
+        head_dim = query.shape[-1]
+        scale = 1 / math.sqrt(head_dim)
+        attn_output = xe_addons.sdp(query, key, value, attention_mask, scale)
         attn_output = attn_output.view(query.shape)
         attn_output = attn_output.transpose(1, 2)
         attn_weight = None
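For reference, 1 / math.sqrt(head_dim) is the standard scaled-dot-product attention factor, and it is also the value PyTorch's torch.nn.functional.scaled_dot_product_attention applies by default when scale is None (the explicit scale keyword requires PyTorch >= 2.1). A quick stand-alone check with stock PyTorch, purely illustrative and independent of the xe_addons kernels:

import math
import torch
import torch.nn.functional as F

# Toy tensors with shape (batch, num_heads, seq_len, head_dim)
q = torch.randn(1, 2, 4, 8)
k = torch.randn(1, 2, 4, 8)
v = torch.randn(1, 2, 4, 8)

scale = 1 / math.sqrt(q.shape[-1])

# The default (scale=None) already applies 1/sqrt(head_dim);
# passing the same value explicitly gives identical output.
out_default = F.scaled_dot_product_attention(q, k, v)
out_explicit = F.scaled_dot_product_attention(q, k, v, scale=scale)
assert torch.allclose(out_default, out_explicit)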