diff --git a/python/llm/src/ipex_llm/transformers/models/common.py b/python/llm/src/ipex_llm/transformers/models/common.py
index 517572ab..8762c297 100644
--- a/python/llm/src/ipex_llm/transformers/models/common.py
+++ b/python/llm/src/ipex_llm/transformers/models/common.py
@@ -357,3 +357,11 @@ def rotary_two_with_cache_inplaced(query_states: torch.Tensor, key_states: torch
     import xe_addons
     xe_addons.rotary_two_with_cache_inplaced(query_states, key_states,
                                              cos, sin, half_layout)
+
+
+def rotary_half_with_cache_inplaced(query_states: torch.Tensor, key_states: torch.Tensor,
+                                    cos: torch.Tensor, sin: torch.Tensor):
+    import xe_addons
+    from ipex_llm.transformers.models.utils import make_cache_contiguous_inplaced
+    make_cache_contiguous_inplaced(cos, sin)
+    xe_addons.rotary_half_with_cache_inplaced(query_states, key_states, cos, sin)
diff --git a/python/llm/src/ipex_llm/transformers/models/llama.py b/python/llm/src/ipex_llm/transformers/models/llama.py
index ac101a55..74547794 100644
--- a/python/llm/src/ipex_llm/transformers/models/llama.py
+++ b/python/llm/src/ipex_llm/transformers/models/llama.py
@@ -162,9 +162,9 @@ def llama_attention_forward(
                                            query_states, key_states)
         else:
             # transformers >= 4.46
             cos, sin = position_embeddings
-            make_cache_contiguous_inplaced(cos, sin)
-            xe_addons.rotary_half_with_cache_inplaced(query_states, key_states, cos, sin)
+            from ipex_llm.transformers.models.common import rotary_half_with_cache_inplaced
+            rotary_half_with_cache_inplaced(query_states, key_states, cos, sin)
     else:
         if position_embeddings is None:
             if isinstance(getattr(self.rotary_emb, "cos_cached", None), torch.Tensor):
diff --git a/python/llm/src/ipex_llm/transformers/models/qwen2_5_omni.py b/python/llm/src/ipex_llm/transformers/models/qwen2_5_omni.py
index 5efe9632..46654e79 100644
--- a/python/llm/src/ipex_llm/transformers/models/qwen2_5_omni.py
+++ b/python/llm/src/ipex_llm/transformers/models/qwen2_5_omni.py
@@ -62,8 +62,8 @@ def qwen2_5_omni_attention_forward(
     cos, sin = position_embeddings
     if query_states.device.type == "xpu":
-        import xe_addons
-        xe_addons.rotary_half_with_cache_inplaced(query_states, key_states, cos, sin)
+        from ipex_llm.transformers.models.common import rotary_half_with_cache_inplaced
+        rotary_half_with_cache_inplaced(query_states, key_states, cos, sin)
     else:
         query_states, key_states = apply_multimodal_rotary_pos_emb(
             query_states, key_states, cos, sin, self.rope_scaling["mrope_section"]
         )
diff --git a/python/llm/src/ipex_llm/transformers/models/qwen3.py b/python/llm/src/ipex_llm/transformers/models/qwen3.py
index b29fc0c2..9f448f38 100644
--- a/python/llm/src/ipex_llm/transformers/models/qwen3.py
+++ b/python/llm/src/ipex_llm/transformers/models/qwen3.py
@@ -93,9 +93,8 @@ def qwen3_attention_forward(
 
     cos, sin = position_embeddings
     if device.type == "xpu":
-        import xe_addons
-        make_cache_contiguous_inplaced(cos, sin)
-        xe_addons.rotary_half_with_cache_inplaced(query_states, key_states, cos, sin)
+        from ipex_llm.transformers.models.common import rotary_half_with_cache_inplaced
+        rotary_half_with_cache_inplaced(query_states, key_states, cos, sin)
     else:
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
 
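
Usage note (not part of the patch): after this refactor, llama.py, qwen3.py, and qwen2_5_omni.py all dispatch through the same common helper, and make_cache_contiguous_inplaced is handled inside it rather than at every call site. The sketch below is illustrative only; it assumes an Intel GPU environment with ipex-llm and its xe_addons extension installed, and the apply_rope wrapper name plus the non-XPU fallback via transformers' apply_rotary_pos_emb are assumptions rather than code from this diff.

def apply_rope(query_states, key_states, cos, sin):
    if query_states.device.type == "xpu":
        # XPU: fused in-place rotary embedding via the new shared helper,
        # which makes cos/sin contiguous before calling into xe_addons.
        from ipex_llm.transformers.models.common import rotary_half_with_cache_inplaced
        rotary_half_with_cache_inplaced(query_states, key_states, cos, sin)
    else:
        # Non-XPU fallback, mirroring the else branches in qwen3.py / llama.py.
        from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
    return query_states, key_states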