Disable sdpa (#10814)
parent 57edf2033c
commit caf75beef8
1 changed file with 7 additions and 20 deletions
@@ -1346,26 +1346,10 @@ def llama_attention_forward_4_36_original(
     key_states = repeat_kv(key_states, self.num_key_value_groups)
     value_states = repeat_kv(value_states, self.num_key_value_groups)
-    # otherwise, use native attention
-    if query_states.device.type == "xpu":
-        dev_name = torch.xpu.get_device_name(query_states.device.index)
-    else:
-        dev_name = "CPU"
-    if not output_attentions and not dev_name.startswith("Intel(R) Data Center GPU Max"):
-        attn_output = torch.nn.functional.scaled_dot_product_attention(
-            query_states,
-            key_states,
-            value_states,
-            attn_mask=attention_mask,
-            dropout_p=self.attention_dropout if self.training else 0.0,
-            # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that
-            # does not create a causal mask in case q_len == 1.
-            is_causal=self.is_causal and attention_mask is None and q_len > 1,
-        )
-    else:
-        attn_output, attn_weights = native_sdp(query_states, key_states, value_states,
-                                               attention_mask,
-                                               bsz, q_len, kv_seq_len,
-                                               self.head_dim, self.num_heads, output_attentions)
+    attn_output, attn_weights = native_sdp(query_states, key_states, value_states,
+                                           attention_mask,
+                                           bsz, q_len, kv_seq_len,
+                                           self.head_dim, self.num_heads, output_attentions)
 
     attn_output_size = (bsz, self.num_heads, q_len, self.head_dim)
     if attn_output.size() != attn_output_size:
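After this hunk, the attention output is always computed by native_sdp, whose definition sits elsewhere in the file and is not part of this diff. For orientation only, below is a minimal sketch of what a manual scaled-dot-product fallback computes for tensors shaped (bsz, num_heads, seq_len, head_dim); the name manual_sdp is hypothetical and the sketch is not the ipex-llm implementation.

import math
import torch

def manual_sdp(query, key, value, attention_mask=None):
    # Illustrative fallback only; not the native_sdp defined in this file.
    head_dim = query.size(-1)
    # QK^T / sqrt(d), optionally offset by an additive attention mask
    attn_weights = torch.matmul(query, key.transpose(2, 3)) / math.sqrt(head_dim)
    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask
    attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32)
    attn_weights = attn_weights.to(query.dtype)
    attn_output = torch.matmul(attn_weights, value)
    return attn_output, attn_weights

# Quick check against torch's fused kernel (no mask, no dropout):
q = torch.randn(1, 2, 8, 16)
k = torch.randn(1, 2, 8, 16)
v = torch.randn(1, 2, 8, 16)
ref = torch.nn.functional.scaled_dot_product_attention(q, k, v)
out, _ = manual_sdp(q, k, v)
assert torch.allclose(ref, out, atol=1e-5)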
@@ -1789,6 +1773,9 @@ def llama_model_forward_4_36_internal(
     if inputs_embeds is None:
         inputs_embeds = self.embed_tokens(input_ids)
 
+    # IPEX-LLM modifications:
+    # Disable sdpa for CPU
+    self._use_sdpa = False
     if self._use_flash_attention_2:
         # 2d mask is passed through the layers
         attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) \
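The flag flipped here is the one transformers 4.36 consults when deciding whether to prepare the attention mask for SDPA, so forcing it to False keeps the model on the eager mask path. For comparison, a user-level way to avoid SDPA without patching the forward is to request the eager implementation at load time; the checkpoint name below is a placeholder.

from transformers import AutoModelForCausalLM

# Placeholder checkpoint; any Llama-style model behaves the same way.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    attn_implementation="eager",  # bypass the SDPA code path in transformers >= 4.36
)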