diff --git a/python/llm/src/ipex_llm/transformers/models/qwen2.py b/python/llm/src/ipex_llm/transformers/models/qwen2.py index 90de62ab..60191943 100644 --- a/python/llm/src/ipex_llm/transformers/models/qwen2.py +++ b/python/llm/src/ipex_llm/transformers/models/qwen2.py @@ -507,7 +507,7 @@ def qwen2_mlp_forward( x_2d.shape[0], x_2d.shape[1], self.gate_proj.out_len, SILU, qtype )) - elif not self.training: + elif x.device.type == "xpu" and not self.training: import xe_addons gate = self.gate_proj(x) up = self.up_proj(x)