Fuse MOE indexes computation (#10716)
* try moe * use c++ cpu to compute indexes * fix style
This commit is contained in:
parent
70ed9397f9
commit
019293e1b9
1 changed files with 28 additions and 11 deletions
|
|
@ -91,7 +91,34 @@ def mixtral_moeblock_forward(self,
|
|||
# we cast back to the input dtype
|
||||
routing_weights = routing_weights.to(hidden_states.dtype)
|
||||
|
||||
if bs > 1:
|
||||
if bs == 1:
|
||||
selected_experts = selected_experts[0].cpu().tolist()
|
||||
for idx in range(self.top_k):
|
||||
exp_id = selected_experts[idx]
|
||||
expert_layer = self.experts[exp_id]
|
||||
weight = routing_weights[:, idx]
|
||||
if idx == 0:
|
||||
final_hidden_states = expert_layer(hidden_states, weight)
|
||||
else:
|
||||
final_hidden_states = final_hidden_states + expert_layer(hidden_states, weight)
|
||||
elif bs < 256:
|
||||
final_hidden_states = torch.zeros((batch_size * sequence_length, hidden_dim),
|
||||
dtype=hidden_states.dtype, device=hidden_states.device)
|
||||
import linear_q4_0
|
||||
indexes = linear_q4_0.get_moe_indexes(selected_experts.to(torch.int32).cpu(), 8)
|
||||
for expert_idx in range(self.num_experts):
|
||||
expert_layer = self.experts[expert_idx]
|
||||
idx_list = indexes[0][expert_idx]
|
||||
top_x_list = indexes[1][expert_idx]
|
||||
if len(idx_list) == 0:
|
||||
continue
|
||||
|
||||
top_x = torch.tensor(top_x_list, dtype=torch.long, device=hidden_states.device)
|
||||
current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim)
|
||||
current_hidden_states = expert_layer(current_state,
|
||||
routing_weights[top_x_list, idx_list, None])
|
||||
final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
|
||||
else:
|
||||
final_hidden_states = torch.zeros(
|
||||
(batch_size * sequence_length, hidden_dim),
|
||||
dtype=hidden_states.dtype,
|
||||
|
|
@ -124,16 +151,6 @@ def mixtral_moeblock_forward(self,
|
|||
# However `index_add_` only support torch tensors for indexing so we'll use
|
||||
# the `top_x` tensor here.
|
||||
final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
|
||||
else:
|
||||
selected_experts = selected_experts[0].cpu().tolist()
|
||||
for idx in range(self.top_k):
|
||||
exp_id = selected_experts[idx]
|
||||
expert_layer = self.experts[exp_id]
|
||||
weight = routing_weights[:, idx]
|
||||
if idx == 0:
|
||||
final_hidden_states = expert_layer(hidden_states, weight)
|
||||
else:
|
||||
final_hidden_states = final_hidden_states + expert_layer(hidden_states, weight)
|
||||
|
||||
final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
|
||||
return final_hidden_states, router_logits
|
||||
|
|
|
|||
Loading…
Reference in a new issue