Fuse MOE indexes computation (#10716)
* try moe
* use c++ cpu to compute indexes
* fix style
parent 70ed9397f9
commit 019293e1b9
1 changed file with 28 additions and 11 deletions
@@ -91,7 +91,34 @@ def mixtral_moeblock_forward(self,
     # we cast back to the input dtype
     routing_weights = routing_weights.to(hidden_states.dtype)
 
-    if bs > 1:
+    if bs == 1:
+        selected_experts = selected_experts[0].cpu().tolist()
+        for idx in range(self.top_k):
+            exp_id = selected_experts[idx]
+            expert_layer = self.experts[exp_id]
+            weight = routing_weights[:, idx]
+            if idx == 0:
+                final_hidden_states = expert_layer(hidden_states, weight)
+            else:
+                final_hidden_states = final_hidden_states + expert_layer(hidden_states, weight)
+    elif bs < 256:
+        final_hidden_states = torch.zeros((batch_size * sequence_length, hidden_dim),
+                                          dtype=hidden_states.dtype, device=hidden_states.device)
+        import linear_q4_0
+        indexes = linear_q4_0.get_moe_indexes(selected_experts.to(torch.int32).cpu(), 8)
+        for expert_idx in range(self.num_experts):
+            expert_layer = self.experts[expert_idx]
+            idx_list = indexes[0][expert_idx]
+            top_x_list = indexes[1][expert_idx]
+            if len(idx_list) == 0:
+                continue
+
+            top_x = torch.tensor(top_x_list, dtype=torch.long, device=hidden_states.device)
+            current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim)
+            current_hidden_states = expert_layer(current_state,
+                                                 routing_weights[top_x_list, idx_list, None])
+            final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
+    else:
         final_hidden_states = torch.zeros(
             (batch_size * sequence_length, hidden_dim),
             dtype=hidden_states.dtype,
@@ -124,16 +151,6 @@ def mixtral_moeblock_forward(self,
             # However `index_add_` only support torch tensors for indexing so we'll use
             # the `top_x` tensor here.
             final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
-    else:
-        selected_experts = selected_experts[0].cpu().tolist()
-        for idx in range(self.top_k):
-            exp_id = selected_experts[idx]
-            expert_layer = self.experts[exp_id]
-            weight = routing_weights[:, idx]
-            if idx == 0:
-                final_hidden_states = expert_layer(hidden_states, weight)
-            else:
-                final_hidden_states = final_hidden_states + expert_layer(hidden_states, weight)
 
     final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
     return final_hidden_states, router_logits
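Note on the fused index computation: the new `elif bs < 256` path replaces the per-expert `torch.where(expert_mask[expert_idx])` lookups from the upstream Hugging Face Mixtral block (which the retained `else:` branch presumably still uses) with a single native call, and the hard-coded `8` presumably matches Mixtral's eight experts. The sketch below is a hypothetical pure-Python reference for the output layout that `linear_q4_0.get_moe_indexes` is assumed to produce, inferred only from how `indexes[0][expert_idx]` and `indexes[1][expert_idx]` are consumed above; the actual C++ CPU implementation is not shown here.

import torch

def get_moe_indexes_reference(selected_experts: torch.Tensor, num_experts: int):
    # selected_experts: (num_tokens, top_k) tensor of expert ids.
    # Returns two per-expert lists:
    #   idx_lists[e]   -> which top-k slot routed each token to expert e
    #   top_x_lists[e] -> which token rows are routed to expert e
    idx_lists = [[] for _ in range(num_experts)]
    top_x_lists = [[] for _ in range(num_experts)]
    for token, row in enumerate(selected_experts.tolist()):
        for slot, expert in enumerate(row):
            idx_lists[expert].append(slot)
            top_x_lists[expert].append(token)
    return idx_lists, top_x_lists

With this layout, `routing_weights[top_x_list, idx_list, None]` picks out the gate weight each expert applies to its routed tokens, which is what the loop over `self.num_experts` in the diff relies on.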