From 14ca42a048f90352fa11ad1868e98b68e7f211e4 Mon Sep 17 00:00:00 2001
From: "Wang, Jian4" <61138589+hzjane@users.noreply.github.com>
Date: Thu, 18 Apr 2024 15:56:52 +0800
Subject: [PATCH] =?UTF-8?q?LLM=EF=BC=9AFix=20moe=20indexs=20error=20on=20c?=
 =?UTF-8?q?pu=20(#10791)?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 python/llm/src/ipex_llm/transformers/models/mixtral.py   | 2 +-
 python/llm/src/ipex_llm/transformers/models/qwen2_moe.py | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/python/llm/src/ipex_llm/transformers/models/mixtral.py b/python/llm/src/ipex_llm/transformers/models/mixtral.py
index 01c420bd..6e03815b 100644
--- a/python/llm/src/ipex_llm/transformers/models/mixtral.py
+++ b/python/llm/src/ipex_llm/transformers/models/mixtral.py
@@ -102,7 +102,7 @@ def mixtral_moeblock_forward(self,
                 final_hidden_states = expert_layer(hidden_states, weight)
             else:
                 final_hidden_states = final_hidden_states + expert_layer(hidden_states, weight)
-    elif bs < 256:
+    elif bs < 256 and hidden_states.device.type == 'xpu':
         final_hidden_states = torch.zeros((batch_size * sequence_length, hidden_dim),
                                           dtype=hidden_states.dtype, device=hidden_states.device)
         import linear_q4_0
diff --git a/python/llm/src/ipex_llm/transformers/models/qwen2_moe.py b/python/llm/src/ipex_llm/transformers/models/qwen2_moe.py
index 369eed54..fe1245a4 100644
--- a/python/llm/src/ipex_llm/transformers/models/qwen2_moe.py
+++ b/python/llm/src/ipex_llm/transformers/models/qwen2_moe.py
@@ -567,7 +567,7 @@ def qwen2moe_moeblock_forward(self, hidden_states: torch.Tensor):
                 final_hidden_states = expert_layer(hidden_states) * weight
             else:
                 final_hidden_states = final_hidden_states + expert_layer(hidden_states) * weight
-    elif bs < 256:
+    elif bs < 256 and hidden_states.device.type == 'xpu':
         final_hidden_states = torch.zeros((batch_size * sequence_length, hidden_dim),
                                           dtype=hidden_states.dtype, device=hidden_states.device)
         import linear_q4_0
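
Note on the change: the `elif bs < 256` branch dispatches into the fused MoE kernels imported from `linear_q4_0`, which belong to ipex-llm's XPU backend; on CPU this branch produced the MoE index error named in the commit title. Adding `hidden_states.device.type == 'xpu'` makes CPU inputs fall through to the generic expert loop instead. A minimal sketch of the resulting branch structure, where the function name and elided bodies are hypothetical and only mirror the shape of the patched mixtral_moeblock_forward:

    import torch

    # Illustrative only -- not the actual ipex-llm code, just the dispatch shape.
    def moeblock_forward_sketch(hidden_states: torch.Tensor, bs: int) -> None:
        if bs == 1:
            pass  # per-expert loop for a single token; works on any device
        elif bs < 256 and hidden_states.device.type == 'xpu':
            pass  # fused low-bit MoE kernels from linear_q4_0 (XPU-only path)
        else:
            pass  # generic expert dispatch; CPU batches now always land here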