Fuse MOE indexes computation (#10716)
* try moe
* use c++ cpu to compute indexes
* fix style
parent 70ed9397f9
commit 019293e1b9

1 changed file with 28 additions and 11 deletions
@@ -91,7 +91,34 @@ def mixtral_moeblock_forward(self,
     # we cast back to the input dtype
     routing_weights = routing_weights.to(hidden_states.dtype)
 
-    if bs > 1:
+    if bs == 1:
+        selected_experts = selected_experts[0].cpu().tolist()
+        for idx in range(self.top_k):
+            exp_id = selected_experts[idx]
+            expert_layer = self.experts[exp_id]
+            weight = routing_weights[:, idx]
+            if idx == 0:
+                final_hidden_states = expert_layer(hidden_states, weight)
+            else:
+                final_hidden_states = final_hidden_states + expert_layer(hidden_states, weight)
+    elif bs < 256:
+        final_hidden_states = torch.zeros((batch_size * sequence_length, hidden_dim),
+                                          dtype=hidden_states.dtype, device=hidden_states.device)
+        import linear_q4_0
+        indexes = linear_q4_0.get_moe_indexes(selected_experts.to(torch.int32).cpu(), 8)
+        for expert_idx in range(self.num_experts):
+            expert_layer = self.experts[expert_idx]
+            idx_list = indexes[0][expert_idx]
+            top_x_list = indexes[1][expert_idx]
+            if len(idx_list) == 0:
+                continue
+
+            top_x = torch.tensor(top_x_list, dtype=torch.long, device=hidden_states.device)
+            current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim)
+            current_hidden_states = expert_layer(current_state,
+                                                 routing_weights[top_x_list, idx_list, None])
+            final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
+    else:
         final_hidden_states = torch.zeros(
             (batch_size * sequence_length, hidden_dim),
             dtype=hidden_states.dtype,
@@ -124,16 +151,6 @@ def mixtral_moeblock_forward(self,
             # However `index_add_` only support torch tensors for indexing so we'll use
             # the `top_x` tensor here.
             final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
-    else:
-        selected_experts = selected_experts[0].cpu().tolist()
-        for idx in range(self.top_k):
-            exp_id = selected_experts[idx]
-            expert_layer = self.experts[exp_id]
-            weight = routing_weights[:, idx]
-            if idx == 0:
-                final_hidden_states = expert_layer(hidden_states, weight)
-            else:
-                final_hidden_states = final_hidden_states + expert_layer(hidden_states, weight)
 
     final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
     return final_hidden_states, router_logits
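In the elif bs < 256 branch above, the fused helper's result is consumed as indexes[0][expert_idx] (the top-k slot of each assignment) and indexes[1][expert_idx] (the token of each assignment). Below is a minimal pure-Python sketch of what linear_q4_0.get_moe_indexes is assumed to compute, inferred only from how the diff indexes routing_weights and hidden_states; the real implementation is the C++ CPU kernel the commit message refers to, and get_moe_indexes_reference is a hypothetical name used here for illustration.

import torch

def get_moe_indexes_reference(selected_experts: torch.Tensor, num_experts: int):
    # selected_experts: [num_tokens, top_k] tensor of chosen expert ids
    # (the diff passes it as int32 on CPU, with num_experts hard-coded to 8).
    # For each expert e, collect the (top-k slot, token) pairs routed to it,
    # matching how the diff reads indexes[0][e] (idx_list, the top-k column)
    # and indexes[1][e] (top_x_list, the token row).
    idx_lists = [[] for _ in range(num_experts)]    # assumed layout of indexes[0]
    top_x_lists = [[] for _ in range(num_experts)]  # assumed layout of indexes[1]
    for token, experts in enumerate(selected_experts.tolist()):
        for slot, expert in enumerate(experts):
            idx_lists[expert].append(slot)
            top_x_lists[expert].append(token)
    return idx_lists, top_x_lists

A single pass over the [num_tokens, top_k] routing table yields every expert's gather lists at once, which is why the branch makes one fused call up front instead of searching for each expert's tokens inside the loop.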
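For comparison, the stock Hugging Face Mixtral block derives the same gather lists eagerly with a one-hot mask and one torch.where per expert. The sketch below reproduces that upstream pattern from memory and is not code from this repository.

import torch

num_experts, top_k, num_tokens = 8, 2, 4
selected_experts = torch.randint(0, num_experts, (num_tokens, top_k))

# one_hot -> [num_tokens, top_k, num_experts]; permute so each expert owns a
# [top_k, num_tokens] mask plane that torch.where can scan independently.
expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=num_experts).permute(2, 1, 0)
for expert_idx in range(num_experts):
    idx, top_x = torch.where(expert_mask[expert_idx])  # top-k slots, token ids for this expert

Replacing these num_experts separate torch.where scans with one C++ CPU pass over the routing table is the fusion named in the commit title.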