Add moe_softmax_topk (#13157)

* add moe_softmax_topk

* address comments

* update
This commit is contained in:
Yina Chen 2025-05-13 14:50:59 +08:00 committed by GitHub
parent aa12f69bbf
commit f6441b4e3d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -365,3 +365,11 @@ def rotary_half_with_cache_inplaced(query_states: torch.Tensor, key_states: torc
from ipex_llm.transformers.models.utils import make_cache_contiguous_inplaced
make_cache_contiguous_inplaced(cos, sin)
xe_addons.rotary_half_with_cache_inplaced(query_states, key_states, cos, sin)
def moe_softmax_topk(router_logits: torch.Tensor, top_k: int, norm_topk_prob: bool):
    """Fused softmax + top-k expert routing via the XPU `xe_addons` kernel.

    Thin wrapper that forwards the router logits to the native
    `xe_addons.moe_softmax_topk` op and returns its two outputs.

    :param router_logits: gating-network logits for each token
        (presumably shaped ``(num_tokens, num_experts)`` — confirm against caller).
    :param top_k: number of experts to select per token.
    :param norm_topk_prob: whether the kernel renormalizes the selected
        probabilities (semantics live in the native op — TODO confirm).
    :return: tuple ``(selected_experts, routing_weights)`` as produced by the kernel.
    """
    # Imported lazily, matching the rotary helper above: the extension is
    # only available on XPU builds, so keep it off the module import path.
    import xe_addons
    experts, weights = xe_addons.moe_softmax_topk(
        router_logits, top_k, norm_topk_prob
    )
    return experts, weights