diff --git a/python/llm/src/ipex_llm/transformers/models/mixtral.py b/python/llm/src/ipex_llm/transformers/models/mixtral.py
index 9bf3af14..80ddd785 100644
--- a/python/llm/src/ipex_llm/transformers/models/mixtral.py
+++ b/python/llm/src/ipex_llm/transformers/models/mixtral.py
@@ -91,7 +91,34 @@ def mixtral_moeblock_forward(self,
     # we cast back to the input dtype
     routing_weights = routing_weights.to(hidden_states.dtype)
 
-    if bs > 1:
+    if bs == 1:
+        selected_experts = selected_experts[0].cpu().tolist()
+        for idx in range(self.top_k):
+            exp_id = selected_experts[idx]
+            expert_layer = self.experts[exp_id]
+            weight = routing_weights[:, idx]
+            if idx == 0:
+                final_hidden_states = expert_layer(hidden_states, weight)
+            else:
+                final_hidden_states = final_hidden_states + expert_layer(hidden_states, weight)
+    elif bs < 256:
+        final_hidden_states = torch.zeros((batch_size * sequence_length, hidden_dim),
+                                          dtype=hidden_states.dtype, device=hidden_states.device)
+        import linear_q4_0
+        indexes = linear_q4_0.get_moe_indexes(selected_experts.to(torch.int32).cpu(), 8)
+        for expert_idx in range(self.num_experts):
+            expert_layer = self.experts[expert_idx]
+            idx_list = indexes[0][expert_idx]
+            top_x_list = indexes[1][expert_idx]
+            if len(idx_list) == 0:
+                continue
+
+            top_x = torch.tensor(top_x_list, dtype=torch.long, device=hidden_states.device)
+            current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim)
+            current_hidden_states = expert_layer(current_state,
+                                                 routing_weights[top_x_list, idx_list, None])
+            final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
+    else:
         final_hidden_states = torch.zeros(
             (batch_size * sequence_length, hidden_dim),
             dtype=hidden_states.dtype,
@@ -124,16 +151,6 @@ def mixtral_moeblock_forward(self,
             # However `index_add_` only support torch tensors for indexing so we'll use
             # the `top_x` tensor here.
             final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
-    else:
-        selected_experts = selected_experts[0].cpu().tolist()
-        for idx in range(self.top_k):
-            exp_id = selected_experts[idx]
-            expert_layer = self.experts[exp_id]
-            weight = routing_weights[:, idx]
-            if idx == 0:
-                final_hidden_states = expert_layer(hidden_states, weight)
-            else:
-                final_hidden_states = final_hidden_states + expert_layer(hidden_states, weight)
 
     final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
     return final_hidden_states, router_logits
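For context, below is a minimal, self-contained sketch of the per-expert dispatch pattern that the new `elif bs < 256:` branch relies on. Assumptions: plain PyTorch only (no `ipex_llm` or `linear_q4_0`), experts modeled as ordinary `nn.Linear` layers, and routing weights applied outside the expert call, so the names and signatures here are illustrative rather than the ipex_llm API. The `torch.where` over a one-hot expert mask is the pure-PyTorch analogue of what `linear_q4_0.get_moe_indexes` precomputes natively in the patch.

```python
import torch
import torch.nn.functional as F


def moe_dispatch(hidden_states, routing_weights, selected_experts, experts):
    # hidden_states:    (num_tokens, hidden_dim)
    # routing_weights:  (num_tokens, top_k), already normalized per token
    # selected_experts: (num_tokens, top_k) integer expert ids
    # experts:          list of modules mapping (n, hidden_dim) -> (n, hidden_dim)
    num_tokens, hidden_dim = hidden_states.shape
    num_experts = len(experts)
    final_hidden_states = torch.zeros(num_tokens, hidden_dim,
                                      dtype=hidden_states.dtype,
                                      device=hidden_states.device)
    # (num_experts, top_k, num_tokens) one-hot mask of which tokens chose which expert
    expert_mask = F.one_hot(selected_experts, num_classes=num_experts).permute(2, 1, 0)
    for expert_idx in range(num_experts):
        # idx: which top-k slot picked this expert; top_x: which tokens picked it
        idx, top_x = torch.where(expert_mask[expert_idx])
        if top_x.numel() == 0:
            continue  # no token routed to this expert in this batch
        current_state = hidden_states[top_x]
        current_hidden_states = experts[expert_idx](current_state) \
            * routing_weights[top_x, idx, None]
        # scatter-accumulate each expert's contribution back onto its tokens
        final_hidden_states.index_add_(0, top_x,
                                       current_hidden_states.to(hidden_states.dtype))
    return final_hidden_states


if __name__ == "__main__":
    torch.manual_seed(0)
    num_tokens, hidden_dim, num_experts, top_k = 6, 16, 8, 2
    experts = [torch.nn.Linear(hidden_dim, hidden_dim) for _ in range(num_experts)]
    hidden_states = torch.randn(num_tokens, hidden_dim)
    router_logits = torch.randn(num_tokens, num_experts)
    routing_weights, selected_experts = torch.topk(router_logits.softmax(dim=-1),
                                                   top_k, dim=-1)
    routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)
    out = moe_dispatch(hidden_states, routing_weights, selected_experts, experts)
    print(out.shape)  # torch.Size([6, 16])
```

The `bs == 1` branch in the patch skips this indexing entirely: with a single token there are exactly `top_k` active experts, so it just sums the `top_k` weighted expert outputs directly, which is why that path is moved to the front of the `if` chain.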