Fix pvc llama (#10798)

* fix

* update
Ruonan Wang 2024-04-18 19:54:19 +08:00 committed by Yang Wang
parent 439c834ed3
commit 754b0ffecf


@@ -1348,7 +1348,11 @@ def llama_attention_forward_4_36_original(
         key_states = repeat_kv(key_states, self.num_key_value_groups)
         value_states = repeat_kv(value_states, self.num_key_value_groups)
         # otherwise, use native attention
-        if not output_attentions:
+        if query_states.device.type == "xpu":
+            dev_name = torch.xpu.get_device_name(query_states.device.index)
+        else:
+            dev_name = "CPU"
+        if not output_attentions and not dev_name.startswith("Intel(R) Data Center GPU Max"):
             attn_output = torch.nn.functional.scaled_dot_product_attention(
                 query_states,
                 key_states,
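
For context, a minimal standalone sketch of the device check this patch introduces: it reads the XPU device name and avoids the scaled_dot_product_attention path on PVC (Intel Data Center GPU Max) devices. The helper name is_pvc_device is illustrative and not part of the commit.

import torch

def is_pvc_device(device: torch.device) -> bool:
    # Only XPU devices can be PVC; anything else (e.g. CPU) is not.
    if device.type != "xpu":
        return False
    # Device name looks like "Intel(R) Data Center GPU Max 1100/1550" on PVC.
    dev_name = torch.xpu.get_device_name(device.index)
    return dev_name.startswith("Intel(R) Data Center GPU Max")

# Usage mirroring the patched condition: take the SDPA fast path only when
# attention weights are not requested and the tensors are not on a PVC GPU.
# use_sdpa = not output_attentions and not is_pvc_device(query_states.device)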