From 16b2a418be6a262645c5d579a6e7dd51767347e0 Mon Sep 17 00:00:00 2001
From: SONG Ge <38711238+sgwhat@users.noreply.github.com>
Date: Thu, 16 May 2024 17:15:37 +0800
Subject: [PATCH] hotfix native_sdp ut (#11046)

* hotfix native_sdp

* update

---
 .../llm/src/ipex_llm/transformers/models/llama.py | 13 +++++++++----
 1 file changed, 9 insertions(+), 4 deletions(-)

diff --git a/python/llm/src/ipex_llm/transformers/models/llama.py b/python/llm/src/ipex_llm/transformers/models/llama.py
index 92333c34..931ec210 100644
--- a/python/llm/src/ipex_llm/transformers/models/llama.py
+++ b/python/llm/src/ipex_llm/transformers/models/llama.py
@@ -855,11 +855,13 @@ def llama_attention_selective_batching_forward_4_31(
                 current_key_states = repeat_kv(current_key_states, self.num_key_value_groups)
                 current_value_states = repeat_kv(current_value_states, self.num_key_value_groups)
 
+                cache_position = None
                 current_query_states = query_states[batch: batch + 1, :, :, :]
                 attn_output, attn_weights = native_sdp(current_query_states,
                                                        current_key_states,
                                                        current_value_states,
                                                        attention_mask[batch],
+                                                       cache_position,
                                                        1,
                                                        1,
                                                        current_kv_len,
@@ -901,10 +903,12 @@ def llama_attention_selective_batching_forward_4_31(
         if isinstance(attention_mask, list):
             # For decoding fast path
             attention_mask = attention_mask[0]
+        cache_position = None
         attn_output, attn_weights = native_sdp(query_states,
                                                key_states,
                                                value_states,
                                                attention_mask,
+                                               cache_position,
                                                bsz,
                                                q_len,
                                                kv_seq_len,
@@ -1445,7 +1449,7 @@ def llama_attention_forward_4_38_original(
 def native_sdp(query, key, value, attention_mask, cache_position,
                bsz, q_len, kv_seq_len, head_dim, num_heads, output_attentions):
     if should_split_qkv_tensor(query, bsz, num_heads, q_len, kv_seq_len, output_attentions):
-        return native_sdp_split_qkv_tensor(query, key, value, attention_mask,
+        return native_sdp_split_qkv_tensor(query, key, value, attention_mask, cache_position,
                                            bsz, q_len, kv_seq_len, head_dim, num_heads)
     else:
         attn_weights = torch.matmul(query.to(key.dtype),
@@ -1502,14 +1506,14 @@ def native_sdp_split_qkv_tensor(query, key, value, attention_mask, cache_positio
         if cache_position is not None:
             # for transformers 4.38.0
             causal_mask = attention_mask[:, :, cache_position, : kv_seq_len]
-            attn_weights = attn_weights + causal_mask
+            attn_weights_split = attn_weights_split + causal_mask
         else:
             attn_mask_size = (bsz, 1, q_len, kv_seq_len)
             if attention_mask.size() != attn_mask_size:
                 invalidInputError(False,
                                   f"Attention mask should be of size {attn_mask_size}, "
                                   f"but is {attention_mask.size()}")
-            attn_weights = attn_weights + attention_mask
+            attn_weights_split = attn_weights_split + attention_mask
         attn_weights_split = nn.functional.softmax(attn_weights_split, dim=-1)
         attn_outputs.append(torch.matmul(attn_weights_split, v))
     attn_output = torch.cat(attn_outputs, dim=1)
@@ -1767,8 +1771,9 @@ def llama_attention_fast_forward(
     key_states = repeat_kv(key_states, self.num_key_value_groups)
     value_states = repeat_kv(value_states, self.num_key_value_groups)
 
+    cache_position = None
     attn_output, attn_weights = native_sdp(query_states, key_states, value_states,
-                                           attention_mask,
+                                           attention_mask, cache_position,
                                            bsz, q_len, kv_seq_len,
                                            self.head_dim, self.num_heads, output_attentions)
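
For reference, below is a minimal, self-contained sketch, not ipex-llm's actual implementation, of the split-QKV SDP path these hunks touch. The function name sdp_split_qkv_sketch, the block_size argument, and all tensor shapes are illustrative assumptions. It shows the two behaviours the patch changes: legacy callers now pass cache_position (None on the pre-transformers-4.38 paths), and the attention mask is added to the per-chunk scores attn_weights_split rather than to a whole-tensor variable.

# Illustrative sketch only; shapes and names do not mirror ipex-llm internals.
import math
import torch
import torch.nn as nn


def sdp_split_qkv_sketch(query, key, value, attention_mask, cache_position,
                         bsz, q_len, kv_seq_len, head_dim, num_heads,
                         block_size=2):
    # query/key/value: (bsz, num_heads, seq_len, head_dim)
    attn_outputs = []
    q_split = torch.split(query, block_size, dim=1)   # split along the head dim
    k_split = torch.split(key, block_size, dim=1)
    v_split = torch.split(value, block_size, dim=1)
    for q, k, v in zip(q_split, k_split, v_split):
        attn_weights_split = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(head_dim)
        if attention_mask is not None:
            if cache_position is not None:
                # transformers >= 4.38 style: slice the causal mask by position
                causal_mask = attention_mask[:, :, cache_position, :kv_seq_len]
            else:
                causal_mask = attention_mask
            # the fix: add the mask to the per-chunk scores, not a stale variable
            attn_weights_split = attn_weights_split + causal_mask
        attn_weights_split = nn.functional.softmax(attn_weights_split, dim=-1)
        attn_outputs.append(torch.matmul(attn_weights_split, v))
    return torch.cat(attn_outputs, dim=1)


if __name__ == "__main__":
    # Legacy-path style call: cache_position is simply None.
    bsz, num_heads, q_len, kv_seq_len, head_dim = 1, 4, 3, 3, 8
    q = torch.randn(bsz, num_heads, q_len, head_dim)
    k = torch.randn(bsz, num_heads, kv_seq_len, head_dim)
    v = torch.randn(bsz, num_heads, kv_seq_len, head_dim)
    mask = torch.zeros(bsz, 1, q_len, kv_seq_len)
    out = sdp_split_qkv_sketch(q, k, v, mask, None,
                               bsz, q_len, kv_seq_len, head_dim, num_heads)
    print(out.shape)  # torch.Size([1, 4, 3, 8])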