From f6441b4e3dbe97f39db4b0db4ca99b92a795a321 Mon Sep 17 00:00:00 2001
From: Yina Chen <33650826+cyita@users.noreply.github.com>
Date: Tue, 13 May 2025 14:50:59 +0800
Subject: [PATCH] Add moe_softmax_topk (#13157)

* add moe_softmax_topk

* address comments

* update
---
 python/llm/src/ipex_llm/transformers/models/common.py | 8 ++++++++
 1 file changed, 8 insertions(+)

diff --git a/python/llm/src/ipex_llm/transformers/models/common.py b/python/llm/src/ipex_llm/transformers/models/common.py
index 8762c297..8117db90 100644
--- a/python/llm/src/ipex_llm/transformers/models/common.py
+++ b/python/llm/src/ipex_llm/transformers/models/common.py
@@ -365,3 +365,11 @@ def rotary_half_with_cache_inplaced(query_states: torch.Tensor, key_states: torc
     from ipex_llm.transformers.models.utils import make_cache_contiguous_inplaced
     make_cache_contiguous_inplaced(cos, sin)
     xe_addons.rotary_half_with_cache_inplaced(query_states, key_states, cos, sin)
+
+
+def moe_softmax_topk(router_logits: torch.Tensor, top_k: int, norm_topk_prob: bool):
+    import xe_addons
+    selected_experts, routing_weights = xe_addons.moe_softmax_topk(
+        router_logits, top_k, norm_topk_prob
+    )
+    return selected_experts, routing_weights
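
For context, the new wrapper dispatches router logits to a fused XPU kernel. Below is a minimal plain-PyTorch sketch of the softmax-then-top-k routing such a kernel is expected to perform, assuming the common Mixtral-style MoE routing semantics; the patch does not specify the fused kernel's exact dtype or layout behavior, and moe_softmax_topk_reference is a hypothetical name used only for illustration.

import torch
import torch.nn.functional as F


def moe_softmax_topk_reference(router_logits: torch.Tensor, top_k: int, norm_topk_prob: bool):
    # Softmax over the expert dimension turns logits into routing probabilities.
    routing_weights = F.softmax(router_logits, dim=-1, dtype=torch.float)
    # Keep only the top-k experts per token.
    routing_weights, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
    if norm_topk_prob:
        # Renormalize so the retained weights sum to 1 for each token.
        routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)
    # Mirror the wrapper's return order: expert indices first, then weights.
    return selected_experts, routing_weights


# Example: 4 tokens routed across 8 experts, picking 2 experts per token.
logits = torch.randn(4, 8)
experts, weights = moe_softmax_topk_reference(logits, top_k=2, norm_topk_prob=True)
print(experts.shape, weights.shape)  # torch.Size([4, 2]) torch.Size([4, 2])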