parent 439c834ed3
commit 754b0ffecf
1 changed file with 5 additions and 1 deletion
@@ -1348,7 +1348,11 @@ def llama_attention_forward_4_36_original(
     key_states = repeat_kv(key_states, self.num_key_value_groups)
     value_states = repeat_kv(value_states, self.num_key_value_groups)
     # otherwise, use native attention
-    if not output_attentions:
+    if query_states.device.type == "xpu":
+        dev_name = torch.xpu.get_device_name(query_states.device.index)
+    else:
+        dev_name = "CPU"
+    if not output_attentions and not dev_name.startswith("Intel(R) Data Center GPU Max"):
         attn_output = torch.nn.functional.scaled_dot_product_attention(
             query_states,
             key_states,
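The added branch resolves a device name before calling torch.nn.functional.scaled_dot_product_attention, so the fused SDPA kernel is skipped both when attention weights are requested and when running on an Intel Data Center GPU Max (XPU) device; the diff does not show which path those cases take instead. A minimal sketch of the gating predicate in isolation, where the helper name should_use_sdpa is hypothetical and only illustrates the condition:

import torch

def should_use_sdpa(query_states: torch.Tensor, output_attentions: bool) -> bool:
    # Resolve a readable device name; anything not on an XPU counts as "CPU".
    if query_states.device.type == "xpu":
        dev_name = torch.xpu.get_device_name(query_states.device.index)
    else:
        dev_name = "CPU"
    # Use the fused SDPA kernel only when attention weights are not requested
    # and the device is not a Data Center GPU Max part.
    return (not output_attentions
            and not dev_name.startswith("Intel(R) Data Center GPU Max"))

Matching on the device-name prefix rather than the device type lets the check target one specific XPU product line while leaving other XPU devices on the SDPA path.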