add rotary_half_with_cache_inplaced to ipex_llm.transformers.models.common (#13143)

* update

* small fix
commit f5d9c49a2a (parent f2598b119e)
Author: Ruonan Wang, 2025-05-09 09:20:44 +08:00 (committed by GitHub)
4 changed files with 14 additions and 8 deletions

File: ipex_llm/transformers/models/common.py

@@ -357,3 +357,11 @@ def rotary_two_with_cache_inplaced(query_states: torch.Tensor, key_states: torch
     import xe_addons
     xe_addons.rotary_two_with_cache_inplaced(query_states, key_states,
                                              cos, sin, half_layout)
+
+
+def rotary_half_with_cache_inplaced(query_states: torch.Tensor, key_states: torch.Tensor,
+                                    cos: torch.Tensor, sin: torch.Tensor):
+    import xe_addons
+    from ipex_llm.transformers.models.utils import make_cache_contiguous_inplaced
+    make_cache_contiguous_inplaced(cos, sin)
+    xe_addons.rotary_half_with_cache_inplaced(query_states, key_states, cos, sin)
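A minimal usage sketch of the new helper, for context: it assumes an XPU build of ipex-llm with the xe_addons native extension available, and the tensor shapes, dtype, and values below are purely illustrative.

import torch
from ipex_llm.transformers.models.common import rotary_half_with_cache_inplaced

# Illustrative shapes: (batch, heads, seq_len, head_dim) for query/key and
# (batch, seq_len, head_dim) for the cos/sin caches.
query_states = torch.randn(1, 32, 16, 128, dtype=torch.float16, device="xpu")
key_states = torch.randn(1, 8, 16, 128, dtype=torch.float16, device="xpu")
cos = torch.randn(1, 16, 128, dtype=torch.float16, device="xpu")
sin = torch.randn(1, 16, 128, dtype=torch.float16, device="xpu")

# The helper makes cos/sin contiguous and applies the fused rotary embedding
# to query_states and key_states in place.
rotary_half_with_cache_inplaced(query_states, key_states, cos, sin)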

File: ipex_llm/transformers/models/llama.py

@@ -162,9 +162,8 @@ def llama_attention_forward(
                                            query_states, key_states)
         else:
             # transformers >= 4.46
             cos, sin = position_embeddings
-            make_cache_contiguous_inplaced(cos, sin)
-            xe_addons.rotary_half_with_cache_inplaced(query_states, key_states, cos, sin)
+            from ipex_llm.transformers.models.common import rotary_half_with_cache_inplaced
+            rotary_half_with_cache_inplaced(query_states, key_states, cos, sin)
     else:
         if position_embeddings is None:
             if isinstance(getattr(self.rotary_emb, "cos_cached", None), torch.Tensor):

File: ipex_llm/transformers/models/qwen2_5_omni.py

@@ -62,8 +62,8 @@ def qwen2_5_omni_attention_forward(
     cos, sin = position_embeddings
     if query_states.device.type == "xpu":
-        import xe_addons
-        xe_addons.rotary_half_with_cache_inplaced(query_states, key_states, cos, sin)
+        from ipex_llm.transformers.models.common import rotary_half_with_cache_inplaced
+        rotary_half_with_cache_inplaced(query_states, key_states, cos, sin)
     else:
         query_states, key_states = apply_multimodal_rotary_pos_emb(
             query_states, key_states, cos, sin, self.rope_scaling["mrope_section"]

File: ipex_llm/transformers/models/qwen3.py

@@ -93,9 +93,8 @@ def qwen3_attention_forward(
     cos, sin = position_embeddings
     if device.type == "xpu":
-        import xe_addons
-        make_cache_contiguous_inplaced(cos, sin)
-        xe_addons.rotary_half_with_cache_inplaced(query_states, key_states, cos, sin)
+        from ipex_llm.transformers.models.common import rotary_half_with_cache_inplaced
+        rotary_half_with_cache_inplaced(query_states, key_states, cos, sin)
     else:
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
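The non-XPU branches above fall back to the Hugging Face rotary helpers (apply_rotary_pos_emb / apply_multimodal_rotary_pos_emb), so the math the fused xe_addons kernel is expected to reproduce is the standard "half layout" rotary position embedding. A reference sketch of that computation follows; the function name apply_rotary_half_reference is illustrative, not part of ipex-llm.

import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Rotate the second half of the last dimension in front of the first half.
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2:]
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_half_reference(q, k, cos, sin):
    # cos/sin are (batch, seq_len, head_dim); unsqueeze the head dimension so
    # they broadcast against (batch, heads, seq_len, head_dim) query/key tensors.
    cos = cos.unsqueeze(1)
    sin = sin.unsqueeze(1)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed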