commit 16b2a418be (parent 6be70283b7)
1 changed file with 9 additions and 4 deletions
@@ -855,11 +855,13 @@ def llama_attention_selective_batching_forward_4_31(
                 current_key_states = repeat_kv(current_key_states, self.num_key_value_groups)
                 current_value_states = repeat_kv(current_value_states, self.num_key_value_groups)
 
+                cache_position = None
                 current_query_states = query_states[batch: batch + 1, :, :, :]
                 attn_output, attn_weights = native_sdp(current_query_states,
                                                        current_key_states,
                                                        current_value_states,
                                                        attention_mask[batch],
+                                                       cache_position,
                                                        1,
                                                        1,
                                                        current_kv_len,
@@ -901,10 +903,12 @@ def llama_attention_selective_batching_forward_4_31(
     if isinstance(attention_mask, list):
         # For decoding fast path
         attention_mask = attention_mask[0]
+    cache_position = None
     attn_output, attn_weights = native_sdp(query_states,
                                            key_states,
                                            value_states,
                                            attention_mask,
+                                           cache_position,
                                            bsz,
                                            q_len,
                                            kv_seq_len,
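Aside (not part of the patch): because the new cache_position parameter is inserted between attention_mask and bsz in native_sdp's signature, both selective-batching call sites above must now pass it explicitly, even though it is just None here. A tiny stand-in sketch with the same parameter order and placeholder values, only to make the positional mapping explicit:

# Illustrative stand-in mirroring native_sdp's parameter order in this patch;
# no attention math, placeholder values only.
def native_sdp_stub(query, key, value, attention_mask, cache_position,
                    bsz, q_len, kv_seq_len, head_dim, num_heads, output_attentions):
    return {"cache_position": cache_position, "bsz": bsz, "q_len": q_len}

# Mirrors the decoding fast path above: cache_position=None keeps the old masking path.
print(native_sdp_stub("q", "k", "v", "mask", None, 1, 1, 8, 64, 32, False))
# {'cache_position': None, 'bsz': 1, 'q_len': 1}
# If a caller were not updated, bsz would silently be consumed as cache_position.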
@@ -1445,7 +1449,7 @@ def llama_attention_forward_4_38_original(
 def native_sdp(query, key, value, attention_mask, cache_position,
                bsz, q_len, kv_seq_len, head_dim, num_heads, output_attentions):
     if should_split_qkv_tensor(query, bsz, num_heads, q_len, kv_seq_len, output_attentions):
-        return native_sdp_split_qkv_tensor(query, key, value, attention_mask,
+        return native_sdp_split_qkv_tensor(query, key, value, attention_mask, cache_position,
                                            bsz, q_len, kv_seq_len, head_dim, num_heads)
     else:
         attn_weights = torch.matmul(query.to(key.dtype),
@@ -1502,14 +1506,14 @@ def native_sdp_split_qkv_tensor(query, key, value, attention_mask, cache_position,
             if cache_position is not None:
                 # for transformers 4.38.0
                 causal_mask = attention_mask[:, :, cache_position, : kv_seq_len]
-                attn_weights = attn_weights + causal_mask
+                attn_weights_split = attn_weights_split + causal_mask
             else:
                 attn_mask_size = (bsz, 1, q_len, kv_seq_len)
                 if attention_mask.size() != attn_mask_size:
                     invalidInputError(False,
                                       f"Attention mask should be of size {attn_mask_size}, "
                                       f"but is {attention_mask.size()}")
-                attn_weights = attn_weights + attention_mask
+                attn_weights_split = attn_weights_split + attention_mask
         attn_weights_split = nn.functional.softmax(attn_weights_split, dim=-1)
         attn_outputs.append(torch.matmul(attn_weights_split, v))
     attn_output = torch.cat(attn_outputs, dim=1)
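For clarity, a self-contained toy sketch of the masking branch this hunk fixes, using made-up shapes (bsz=1, num_heads=2, q_len=1, kv_seq_len=8 are placeholders, not values from the patch): when a transformers-4.38-style 4-D mask and cache_position are available, the mask rows for the current tokens are sliced out; otherwise the mask must already be (bsz, 1, q_len, kv_seq_len). Either way the mask is added to the per-split weights (attn_weights_split), which is what the two changed lines correct.

# Toy sketch only (not the IPEX-LLM code); shapes are placeholders.
import torch

bsz, num_heads, q_len, kv_seq_len = 1, 2, 1, 8
attn_weights_split = torch.randn(bsz, num_heads, q_len, kv_seq_len)

# transformers 4.38 style: a 4-D causal mask plus the absolute positions of the
# tokens currently being processed.
attention_mask = torch.zeros(bsz, 1, kv_seq_len, kv_seq_len)
cache_position = torch.arange(kv_seq_len - q_len, kv_seq_len)

if cache_position is not None:
    # slice out the mask rows for the current tokens, as in the hunk above
    causal_mask = attention_mask[:, :, cache_position, :kv_seq_len]
    attn_weights_split = attn_weights_split + causal_mask
else:
    # pre-4.38 path: the mask is expected to already be (bsz, 1, q_len, kv_seq_len)
    assert attention_mask.size() == (bsz, 1, q_len, kv_seq_len)
    attn_weights_split = attn_weights_split + attention_mask

attn_weights_split = torch.nn.functional.softmax(attn_weights_split, dim=-1)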
@@ -1767,8 +1771,9 @@ def llama_attention_fast_forward(
     key_states = repeat_kv(key_states, self.num_key_value_groups)
     value_states = repeat_kv(value_states, self.num_key_value_groups)
 
+    cache_position = None
     attn_output, attn_weights = native_sdp(query_states, key_states, value_states,
-                                           attention_mask,
+                                           attention_mask, cache_position,
                                            bsz, q_len, kv_seq_len,
                                            self.head_dim, self.num_heads, output_attentions)
 
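A hedged note on the callers patched above: all three now pass cache_position explicitly but set it to None, which keeps the pre-4.38 mask handling inside native_sdp. In transformers 4.38 itself, cache_position is (roughly) a 1-D tensor holding the absolute positions of the tokens in the current forward pass; a minimal sketch of that shape, with placeholder lengths:

# Hedged sketch: how a transformers-4.38-style cache_position is typically shaped
# (absolute positions of the current tokens). The lengths below are placeholders.
import torch

past_len = 16   # tokens already held in the KV cache
q_len = 4       # tokens in the current forward pass

cache_position = torch.arange(past_len, past_len + q_len)
# e.g. tensor([16, 17, 18, 19]); native_sdp_split_qkv_tensor uses it to slice the
# 4-D causal mask: attention_mask[:, :, cache_position, :kv_seq_len]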