LLM: Fix MoE index error on CPU (#10791)

Wang, Jian4 authored 2024-04-18 15:56:52 +08:00, committed by GitHub
parent cbe7b5753f
commit 14ca42a048
2 changed files with 2 additions and 2 deletions
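
Both files receive the same one-line guard: the fused bs < 256 MoE branch, which imports the native linear_q4_0 extension, is now taken only when hidden_states.device.type == 'xpu'. CPU tensors presumably fall through to the non-fused path instead of triggering the index error named in the title.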


@@ -102,7 +102,7 @@ def mixtral_moeblock_forward(self,
                 final_hidden_states = expert_layer(hidden_states, weight)
             else:
                 final_hidden_states = final_hidden_states + expert_layer(hidden_states, weight)
-    elif bs < 256:
+    elif bs < 256 and hidden_states.device.type == 'xpu':
         final_hidden_states = torch.zeros((batch_size * sequence_length, hidden_dim),
                                           dtype=hidden_states.dtype, device=hidden_states.device)
         import linear_q4_0
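
For context, here is a minimal sketch of the dispatch pattern this change produces. The function name and return strings are illustrative, the bs == 1 condition on the first branch is assumed from the surrounding code (the diff does not show it), and linear_q4_0 is taken to provide XPU-only kernels, which is what the added device guard implies:

import torch

def moe_dispatch(hidden_states: torch.Tensor, bs: int) -> str:
    # Illustrative dispatcher mirroring the branch structure in the diff.
    if bs == 1:
        # Single-token path: accumulate each selected expert's output.
        return "per-expert loop"
    elif bs < 256 and hidden_states.device.type == 'xpu':
        # Fused low-batch path: imports linear_q4_0, whose kernels run on
        # XPU, hence the device-type guard added by this commit.
        return "fused linear_q4_0 path"
    else:
        # CPU tensors (and larger batches) take the generic path, avoiding
        # the index error the commit title describes.
        return "generic path"

# A CPU tensor no longer satisfies the elif guard:
x = torch.randn(4, 32)
print(x.device.type)          # 'cpu'
print(moe_dispatch(x, bs=4))  # 'generic path'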


@@ -567,7 +567,7 @@ def qwen2moe_moeblock_forward(self, hidden_states: torch.Tensor):
                 final_hidden_states = expert_layer(hidden_states) * weight
             else:
                 final_hidden_states = final_hidden_states + expert_layer(hidden_states) * weight
-    elif bs < 256:
+    elif bs < 256 and hidden_states.device.type == 'xpu':
         final_hidden_states = torch.zeros((batch_size * sequence_length, hidden_dim),
                                           dtype=hidden_states.dtype, device=hidden_states.device)
         import linear_q4_0
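
Note that the guard tests device.type rather than excluding CPU explicitly, so any non-XPU device takes the generic path as well. Per the commit summary (2 additions, 2 deletions), only the two elif lines change; the surrounding branches in both files are untouched.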