diff --git a/python/llm/src/ipex_llm/transformers/models/mixtral.py b/python/llm/src/ipex_llm/transformers/models/mixtral.py
index 01c420bd..6e03815b 100644
--- a/python/llm/src/ipex_llm/transformers/models/mixtral.py
+++ b/python/llm/src/ipex_llm/transformers/models/mixtral.py
@@ -102,7 +102,7 @@ def mixtral_moeblock_forward(self,
                 final_hidden_states = expert_layer(hidden_states, weight)
             else:
                 final_hidden_states = final_hidden_states + expert_layer(hidden_states, weight)
-    elif bs < 256:
+    elif bs < 256 and hidden_states.device.type == 'xpu':
         final_hidden_states = torch.zeros((batch_size * sequence_length, hidden_dim),
                                           dtype=hidden_states.dtype, device=hidden_states.device)
         import linear_q4_0
diff --git a/python/llm/src/ipex_llm/transformers/models/qwen2_moe.py b/python/llm/src/ipex_llm/transformers/models/qwen2_moe.py
index 369eed54..fe1245a4 100644
--- a/python/llm/src/ipex_llm/transformers/models/qwen2_moe.py
+++ b/python/llm/src/ipex_llm/transformers/models/qwen2_moe.py
@@ -567,7 +567,7 @@ def qwen2moe_moeblock_forward(self, hidden_states: torch.Tensor):
                 final_hidden_states = expert_layer(hidden_states) * weight
             else:
                 final_hidden_states = final_hidden_states + expert_layer(hidden_states) * weight
-    elif bs < 256:
+    elif bs < 256 and hidden_states.device.type == 'xpu':
         final_hidden_states = torch.zeros((batch_size * sequence_length, hidden_dim),
                                           dtype=hidden_states.dtype, device=hidden_states.device)
         import linear_q4_0
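
Both hunks add the same guard: the batched MoE fast path imports the XPU-only linear_q4_0 extension, so it should only be taken when the hidden states actually live on an XPU device; other devices fall back to the generic per-expert loop. Below is a minimal, hypothetical sketch of that dispatch pattern, not the IPEX-LLM implementation; generic_moe_path and fused_moe_path are placeholder names for the fallback and the linear_q4_0-backed path.

import torch


def generic_moe_path(hidden_states: torch.Tensor) -> torch.Tensor:
    # Placeholder for the per-expert fallback loop (identity stub).
    return hidden_states


def fused_moe_path(hidden_states: torch.Tensor) -> torch.Tensor:
    # Placeholder for the batched path that the real code backs with the
    # XPU-only linear_q4_0 extension (identity stub).
    return hidden_states


def moe_forward(hidden_states: torch.Tensor) -> torch.Tensor:
    bs = hidden_states.shape[0]
    if bs == 1:
        return generic_moe_path(hidden_states)
    # The added device-type check keeps the XPU-only fused kernel from being
    # selected for CPU or CUDA tensors; those fall through to the fallback.
    elif bs < 256 and hidden_states.device.type == 'xpu':
        return fused_moe_path(hidden_states)
    return generic_moe_path(hidden_states)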