Fuse MOE indexes computation (#10716)

* try MoE

* use C++ on the CPU to compute the indexes (see the sketch before the diff)

* fix style
Author: Yang Wang, 2024-04-11 10:12:55 -07:00 (committed by GitHub)
parent 70ed9397f9
commit 019293e1b9

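The diff below consumes `indexes[0][expert_idx]` (top-k slot positions) and `indexes[1][expert_idx]` (token positions), so the fused `linear_q4_0.get_moe_indexes` op presumably bucketizes `selected_experts` by expert in a single CPU pass, replacing the per-expert `one_hot` + `torch.where` of the stock HuggingFace path. Here is a minimal Python sketch of that bucketing, written only from how the diff uses the result; the helper name `get_moe_indexes_ref` is made up, and the real op is implemented in C++ inside `linear_q4_0`:

```python
import torch

def get_moe_indexes_ref(selected_experts: torch.Tensor, num_experts: int):
    # selected_experts: (num_tokens, top_k) int tensor of routed expert ids.
    # Returns (idx_lists, top_x_lists): for each expert e, idx_lists[e] holds
    # the top-k slot positions and top_x_lists[e] the token positions routed
    # to e -- the same layout the diff reads as indexes[0] / indexes[1].
    idx_lists = [[] for _ in range(num_experts)]
    top_x_lists = [[] for _ in range(num_experts)]
    for token, row in enumerate(selected_experts.tolist()):
        for slot, expert in enumerate(row):
            idx_lists[expert].append(slot)
            top_x_lists[expert].append(token)
    return idx_lists, top_x_lists
```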

@@ -91,7 +91,34 @@ def mixtral_moeblock_forward(self,
         # we cast back to the input dtype
         routing_weights = routing_weights.to(hidden_states.dtype)
-        if bs > 1:
+        if bs == 1:
+            selected_experts = selected_experts[0].cpu().tolist()
+            for idx in range(self.top_k):
+                exp_id = selected_experts[idx]
+                expert_layer = self.experts[exp_id]
+                weight = routing_weights[:, idx]
+                if idx == 0:
+                    final_hidden_states = expert_layer(hidden_states, weight)
+                else:
+                    final_hidden_states = final_hidden_states + expert_layer(hidden_states, weight)
+        elif bs < 256:
+            final_hidden_states = torch.zeros((batch_size * sequence_length, hidden_dim),
+                                              dtype=hidden_states.dtype, device=hidden_states.device)
+            import linear_q4_0
+            indexes = linear_q4_0.get_moe_indexes(selected_experts.to(torch.int32).cpu(), 8)
+            for expert_idx in range(self.num_experts):
+                expert_layer = self.experts[expert_idx]
+                idx_list = indexes[0][expert_idx]
+                top_x_list = indexes[1][expert_idx]
+                if len(idx_list) == 0:
+                    continue
+                top_x = torch.tensor(top_x_list, dtype=torch.long, device=hidden_states.device)
+                current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim)
+                current_hidden_states = expert_layer(current_state,
+                                                     routing_weights[top_x_list, idx_list, None])
+                final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
+        else:
             final_hidden_states = torch.zeros(
                 (batch_size * sequence_length, hidden_dim),
                 dtype=hidden_states.dtype,
@@ -124,16 +151,6 @@ def mixtral_moeblock_forward(self,
             # However `index_add_` only support torch tensors for indexing so we'll use
             # the `top_x` tensor here.
             final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
-        else:
-            selected_experts = selected_experts[0].cpu().tolist()
-            for idx in range(self.top_k):
-                exp_id = selected_experts[idx]
-                expert_layer = self.experts[exp_id]
-                weight = routing_weights[:, idx]
-                if idx == 0:
-                    final_hidden_states = expert_layer(hidden_states, weight)
-                else:
-                    final_hidden_states = final_hidden_states + expert_layer(hidden_states, weight)
         final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
         return final_hidden_states, router_logits
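
After this change the forward dispatches on batch size: `bs == 1` loops over the top-k experts directly, `1 < bs < 256` uses the fused CPU index computation, and larger batches keep the original dense path (the hardcoded `8` matches Mixtral's eight experts). Every branch combines per-expert outputs with the same gather/scatter pattern; below is a minimal runnable sketch of that pattern with made-up shapes and a scalar multiply standing in for the expert MLP (in the real code, `expert_layer` applies the routing weight itself):

```python
import torch

num_tokens, hidden_dim, top_k = 4, 8, 2
hidden_states = torch.randn(num_tokens, hidden_dim)
routing_weights = torch.rand(num_tokens, top_k)
final_hidden_states = torch.zeros(num_tokens, hidden_dim)

# Token positions / top-k slots routed to one expert,
# as get_moe_indexes would produce them.
top_x_list, idx_list = [0, 2], [1, 0]
top_x = torch.tensor(top_x_list, dtype=torch.long)

current_state = hidden_states[top_x]               # gather the routed tokens
expert_out = current_state * 2.0                   # stand-in for the expert MLP
weighted = expert_out * routing_weights[top_x_list, idx_list, None]
final_hidden_states.index_add_(0, top_x, weighted) # scatter back into place
```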