LLM: update split qkv native sdp. (#10895)
* LLM: update split qkv native sdp.
* fix typo.
parent 990535b1cf
commit 9752ffe979

2 changed files with 9 additions and 16 deletions
@@ -258,16 +258,14 @@ def chatglm2_quantized_attention_forward_8eb45c(
             query_split = torch.split(query_layer, block_size, dim=1)
             key_split = torch.split(key, block_size, dim=1)
             value_split = torch.split(value, block_size, dim=1)
-            context_layer = torch.empty(batch_size, n_head, seq_len,
-                                        head_dim, dtype=key.dtype).to(query_layer.device)
-            idx = 0
+            results = []
             for q, k, v in zip(query_split, key_split, value_split):
                 if attention_mask is None:
                     result = F.scaled_dot_product_attention(q, k, v, is_causal=True)
                 else:
                     result = F.scaled_dot_product_attention(q, k, v, attention_mask)
-                context_layer[:, idx:idx+q.shape[1], :, :] = result
-                idx = idx + q.shape[1]
+                results.append(result)
+            context_layer = torch.cat(results, dim=1)
         else:
             if attention_mask is None:
                 context_layer = F.scaled_dot_product_attention(query_layer, key,
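For reference, the new code path follows a simple pattern: split Q/K/V into blocks along the head dimension, run scaled_dot_product_attention per block, collect the per-block outputs in a list, and torch.cat them back along dim=1 instead of writing into a preallocated context_layer. A minimal standalone sketch of that pattern (the function name, shapes, and block_size default are illustrative assumptions, not taken from the repository):

import torch
import torch.nn.functional as F

def split_heads_sdpa(query, key, value, attention_mask=None, block_size=8):
    # query/key/value: [batch, n_head, seq_len, head_dim]; block_size is illustrative
    query_split = torch.split(query, block_size, dim=1)
    key_split = torch.split(key, block_size, dim=1)
    value_split = torch.split(value, block_size, dim=1)
    results = []
    for q, k, v in zip(query_split, key_split, value_split):
        if attention_mask is None:
            # causal attention when no explicit mask is given
            results.append(F.scaled_dot_product_attention(q, k, v, is_causal=True))
        else:
            results.append(F.scaled_dot_product_attention(q, k, v, attention_mask))
    # concatenate the per-block outputs back along the head dimension
    return torch.cat(results, dim=1)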
@@ -541,14 +539,11 @@ def core_attn_forward_8eb45c(query_layer, key_layer, value_layer, attention_mask
                 query_split = torch.split(query_layer.to(key_layer.dtype), block_size, dim=1)
                 key_split = torch.split(key_layer, block_size, dim=1)
                 value_split = torch.split(value_layer, block_size, dim=1)
-                batch_size, n_head, seq_len, head_dim = query_layer.shape
-                context_layer = torch.empty(batch_size, n_head, seq_len,
-                                            head_dim, dtype=key_layer.dtype).to(query_layer.device)
-                idx = 0
+                results = []
                 for q, k, v in zip(query_split, key_split, value_split):
                     result = F.scaled_dot_product_attention(q, k, v, is_causal=True).to(k.dtype)
-                    context_layer[:, idx:idx+q.shape[1], :, :] = result
-                    idx = idx + q.shape[1]
+                    results.append(result)
+                context_layer = torch.cat(results, dim=1)
             else:
                 context_layer = F.scaled_dot_product_attention(query_layer.to(key_layer.dtype),
                                                                key_layer,
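Splitting along the head dimension does not change the attention result, since heads are independent; a quick sanity check that the append-and-concat path matches a single full SDPA call (illustrative shapes and block size, not part of the patch):

import torch
import torch.nn.functional as F

batch, n_head, seq_len, head_dim = 1, 32, 16, 64
q = torch.randn(batch, n_head, seq_len, head_dim)
k = torch.randn(batch, n_head, seq_len, head_dim)
v = torch.randn(batch, n_head, seq_len, head_dim)

# one full call vs. per-head-block calls concatenated along dim=1
full = F.scaled_dot_product_attention(q, k, v, is_causal=True)
blocked = torch.cat(
    [F.scaled_dot_product_attention(qb, kb, vb, is_causal=True)
     for qb, kb, vb in zip(q.split(8, dim=1), k.split(8, dim=1), v.split(8, dim=1))],
    dim=1,
)
assert torch.allclose(full, blocked, atol=1e-5)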
@@ -1423,8 +1423,7 @@ def native_sdp_split_qkv_tensor(query, key, value, attention_mask,
     query_split = torch.split(query.to(key.dtype), block_size, dim=1)
     key_split = torch.split(key.transpose(2, 3), block_size, dim=1)
     value_split = torch.split(value, block_size, dim=1)
-    attn_output = torch.empty(bsz, num_heads, q_len, head_dim).to(query.device)
-    idx = 0
+    attn_outputs = []
     for q, k, v in zip(query_split, key_split, value_split):
         attn_weights_split = torch.matmul(q, k) / math.sqrt(head_dim)
         block_actual_size = attn_weights_split.size(1)
@@ -1442,9 +1441,8 @@ def native_sdp_split_qkv_tensor(query, key, value, attention_mask,
                                   f"but is {attention_mask.size()}")
             attn_weights_split = attn_weights_split + attention_mask
         attn_weights_split = nn.functional.softmax(attn_weights_split, dim=-1)
-        attn_weights_split = torch.matmul(attn_weights_split, v)
-        attn_output[:, idx:idx+block_actual_size, :, :] = attn_weights_split
-        idx = idx + block_actual_size
+        attn_outputs.append(torch.matmul(attn_weights_split, v))
+    attn_output = torch.cat(attn_outputs, dim=1)
     return attn_output.to(key.dtype), None
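native_sdp_split_qkv_tensor applies the same idea to the manual SDP path: each head block computes q @ k^T / sqrt(head_dim), optionally adds the attention mask, softmaxes, multiplies by v, and the per-block outputs are concatenated once at the end instead of being written into a preallocated attn_output. A minimal sketch under assumed shapes (the helper name and block_size default are illustrative, not from the repository):

import math
import torch
import torch.nn as nn

def native_sdp_blocked(query, key, value, attention_mask=None, block_size=8):
    # query/key/value: [bsz, num_heads, q_len, head_dim]
    head_dim = query.size(-1)
    query_split = torch.split(query.to(key.dtype), block_size, dim=1)
    key_split = torch.split(key.transpose(2, 3), block_size, dim=1)  # pre-transpose K
    value_split = torch.split(value, block_size, dim=1)
    attn_outputs = []
    for q, k, v in zip(query_split, key_split, value_split):
        # scaled attention scores for this head block
        attn_weights = torch.matmul(q, k) / math.sqrt(head_dim)
        if attention_mask is not None:
            attn_weights = attn_weights + attention_mask
        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
        attn_outputs.append(torch.matmul(attn_weights, v))
    # single concatenation along the head dimension replaces index assignment
    return torch.cat(attn_outputs, dim=1).to(key.dtype)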