From 754b0ffecfebec3be05b26b09a59a51a19064725 Mon Sep 17 00:00:00 2001
From: Ruonan Wang
Date: Thu, 18 Apr 2024 19:54:19 +0800
Subject: [PATCH] Fix pvc llama (#10798)

* ifx

* update
---
 python/llm/src/ipex_llm/transformers/models/llama.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/python/llm/src/ipex_llm/transformers/models/llama.py b/python/llm/src/ipex_llm/transformers/models/llama.py
index 8eda2212..134556f9 100644
--- a/python/llm/src/ipex_llm/transformers/models/llama.py
+++ b/python/llm/src/ipex_llm/transformers/models/llama.py
@@ -1348,7 +1348,11 @@ def llama_attention_forward_4_36_original(
         key_states = repeat_kv(key_states, self.num_key_value_groups)
         value_states = repeat_kv(value_states, self.num_key_value_groups)
         # otherwise, use native attention
-        if not output_attentions:
+        if query_states.device.type == "xpu":
+            dev_name = torch.xpu.get_device_name(query_states.device.index)
+        else:
+            dev_name = "CPU"
+        if not output_attentions and not dev_name.startswith("Intel(R) Data Center GPU Max"):
             attn_output = torch.nn.functional.scaled_dot_product_attention(
                 query_states,
                 key_states,
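
For context, below is a minimal standalone sketch of the condition this patch introduces: SDPA is used only when attention weights are not requested and the tensors are not on an Intel Data Center GPU Max (PVC) device. The helper name should_use_sdpa is illustrative and not part of the patch; the sketch assumes a PyTorch build with Intel XPU support so that torch.xpu.get_device_name is available.

import torch


def should_use_sdpa(query_states: torch.Tensor, output_attentions: bool) -> bool:
    """Illustrative helper mirroring the patched condition.

    Skip torch.nn.functional.scaled_dot_product_attention when attention
    weights are requested, or when the query tensor lives on an Intel
    Data Center GPU Max (PVC) device.
    """
    if query_states.device.type == "xpu":
        # e.g. "Intel(R) Data Center GPU Max 1100" on PVC
        dev_name = torch.xpu.get_device_name(query_states.device.index)
    else:
        dev_name = "CPU"
    return (not output_attentions
            and not dev_name.startswith("Intel(R) Data Center GPU Max"))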