diff --git a/python/llm/src/ipex_llm/transformers/models/common.py b/python/llm/src/ipex_llm/transformers/models/common.py index 8762c297..8117db90 100644 --- a/python/llm/src/ipex_llm/transformers/models/common.py +++ b/python/llm/src/ipex_llm/transformers/models/common.py @@ -365,3 +365,11 @@ def rotary_half_with_cache_inplaced(query_states: torch.Tensor, key_states: torc from ipex_llm.transformers.models.utils import make_cache_contiguous_inplaced make_cache_contiguous_inplaced(cos, sin) xe_addons.rotary_half_with_cache_inplaced(query_states, key_states, cos, sin) + + +def moe_softmax_topk(router_logits: torch.Tensor, top_k: int, norm_topk_prob: bool): + import xe_addons + selected_experts, routing_weights = xe_addons.moe_softmax_topk( + router_logits, top_k, norm_topk_prob + ) + return selected_experts, routing_weights