diff --git a/python/llm/src/ipex_llm/transformers/models/llama.py b/python/llm/src/ipex_llm/transformers/models/llama.py
index 8eda2212..134556f9 100644
--- a/python/llm/src/ipex_llm/transformers/models/llama.py
+++ b/python/llm/src/ipex_llm/transformers/models/llama.py
@@ -1348,7 +1348,11 @@ def llama_attention_forward_4_36_original(
         key_states = repeat_kv(key_states, self.num_key_value_groups)
         value_states = repeat_kv(value_states, self.num_key_value_groups)
         # otherwise, use native attention
-        if not output_attentions:
+        if query_states.device.type == "xpu":
+            dev_name = torch.xpu.get_device_name(query_states.device.index)
+        else:
+            dev_name = "CPU"
+        if not output_attentions and not dev_name.startswith("Intel(R) Data Center GPU Max"):
             attn_output = torch.nn.functional.scaled_dot_product_attention(
                 query_states,
                 key_states,
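
The sketch below restates the device check this hunk introduces as a standalone helper, for readers who want to see the control flow in isolation. `should_use_sdpa` is a hypothetical name, not part of the patch; `torch.xpu.get_device_name` is only available in XPU-enabled PyTorch builds (e.g. with intel_extension_for_pytorch installed), and on other devices the check falls through to a `"CPU"` placeholder, mirroring the diff.

```python
# Minimal sketch (not part of the patch) of the new SDPA gating logic:
# on XPU devices, query the device name and avoid
# torch.nn.functional.scaled_dot_product_attention on
# Intel(R) Data Center GPU Max hardware.
import torch


def should_use_sdpa(query_states: torch.Tensor, output_attentions: bool) -> bool:
    """Return True when native scaled_dot_product_attention may be used."""
    if output_attentions:
        # SDPA cannot return attention weights, so fall back to the manual path.
        return False
    if query_states.device.type == "xpu":
        # Requires an XPU-enabled build; device.index selects the right card.
        dev_name = torch.xpu.get_device_name(query_states.device.index)
    else:
        dev_name = "CPU"
    # Skip SDPA on Intel(R) Data Center GPU Max (PVC) devices.
    return not dev_name.startswith("Intel(R) Data Center GPU Max")
```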