LLM: Fix MoE index error on CPU (#10791)
This commit is contained in:
parent
cbe7b5753f
commit
14ca42a048
2 changed files with 2 additions and 2 deletions
|
|
@ -102,7 +102,7 @@ def mixtral_moeblock_forward(self,
|
|||
final_hidden_states = expert_layer(hidden_states, weight)
|
||||
else:
|
||||
final_hidden_states = final_hidden_states + expert_layer(hidden_states, weight)
|
||||
elif bs < 256:
|
||||
elif bs < 256 and hidden_states.device.type == 'xpu':
|
||||
final_hidden_states = torch.zeros((batch_size * sequence_length, hidden_dim),
|
||||
dtype=hidden_states.dtype, device=hidden_states.device)
|
||||
import linear_q4_0
|
||||
|
|
|
|||
|
|
@ -567,7 +567,7 @@ def qwen2moe_moeblock_forward(self, hidden_states: torch.Tensor):
|
|||
final_hidden_states = expert_layer(hidden_states) * weight
|
||||
else:
|
||||
final_hidden_states = final_hidden_states + expert_layer(hidden_states) * weight
|
||||
elif bs < 256:
|
||||
elif bs < 256 and hidden_states.device.type == 'xpu':
|
||||
final_hidden_states = torch.zeros((batch_size * sequence_length, hidden_dim),
|
||||
dtype=hidden_states.dtype, device=hidden_states.device)
|
||||
import linear_q4_0
|
||||
|
|
|
|||
Loading…
Reference in a new issue