add rotary_half_with_cache_inplaced to ipex_llm.transformers.models.common (#13143)
* update
* small fix
This commit is contained in:
parent f2598b119e
commit f5d9c49a2a
4 changed files with 14 additions and 8 deletions
@@ -357,3 +357,11 @@ def rotary_two_with_cache_inplaced(query_states: torch.Tensor, key_states: torch
     import xe_addons
     xe_addons.rotary_two_with_cache_inplaced(query_states, key_states,
                                              cos, sin, half_layout)
+
+
+def rotary_half_with_cache_inplaced(query_states: torch.Tensor, key_states: torch.Tensor,
+                                    cos: torch.Tensor, sin: torch.Tensor):
+    import xe_addons
+    from ipex_llm.transformers.models.utils import make_cache_contiguous_inplaced
+    make_cache_contiguous_inplaced(cos, sin)
+    xe_addons.rotary_half_with_cache_inplaced(query_states, key_states, cos, sin)
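For reference, a minimal usage sketch of the new shared helper outside the patched attention paths. It assumes ipex-llm is installed with the xe_addons XPU kernels and that an Intel GPU is available as device "xpu"; the tensor shapes and dtype below are illustrative assumptions, not taken from this commit.

import torch
from ipex_llm.transformers.models.common import rotary_half_with_cache_inplaced

# Illustrative shapes following the usual [batch, heads, seq, head_dim] layout.
bsz, n_heads, seq_len, head_dim = 1, 32, 16, 128
query_states = torch.randn(bsz, n_heads, seq_len, head_dim, dtype=torch.float16, device="xpu")
key_states = torch.randn(bsz, n_heads, seq_len, head_dim, dtype=torch.float16, device="xpu")
# cos/sin would normally come from the model's rotary embedding (position_embeddings);
# the [batch, seq, head_dim] shape here is an assumption.
cos = torch.randn(bsz, seq_len, head_dim, dtype=torch.float16, device="xpu")
sin = torch.randn(bsz, seq_len, head_dim, dtype=torch.float16, device="xpu")

# The helper makes cos/sin contiguous and applies half-layout RoPE to
# query_states/key_states in place; nothing is returned.
rotary_half_with_cache_inplaced(query_states, key_states, cos, sin)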
@@ -162,9 +162,8 @@ def llama_attention_forward(
                                            query_states, key_states)
         else:
             # transformers >= 4.46
-            cos, sin = position_embeddings
-            make_cache_contiguous_inplaced(cos, sin)
-            xe_addons.rotary_half_with_cache_inplaced(query_states, key_states, cos, sin)
+            from ipex_llm.transformers.models.common import rotary_half_with_cache_inplaced
+            rotary_half_with_cache_inplaced(query_states, key_states, cos, sin)
     else:
         if position_embeddings is None:
             if isinstance(getattr(self.rotary_emb, "cos_cached", None), torch.Tensor):
@@ -62,8 +62,8 @@ def qwen2_5_omni_attention_forward(
 
     cos, sin = position_embeddings
     if query_states.device.type == "xpu":
-        import xe_addons
-        xe_addons.rotary_half_with_cache_inplaced(query_states, key_states, cos, sin)
+        from ipex_llm.transformers.models.common import rotary_half_with_cache_inplaced
+        rotary_half_with_cache_inplaced(query_states, key_states, cos, sin)
     else:
         query_states, key_states = apply_multimodal_rotary_pos_emb(
             query_states, key_states, cos, sin, self.rope_scaling["mrope_section"]
@@ -93,9 +93,8 @@ def qwen3_attention_forward(
 
     cos, sin = position_embeddings
     if device.type == "xpu":
-        import xe_addons
-        make_cache_contiguous_inplaced(cos, sin)
-        xe_addons.rotary_half_with_cache_inplaced(query_states, key_states, cos, sin)
+        from ipex_llm.transformers.models.common import rotary_half_with_cache_inplaced
+        rotary_half_with_cache_inplaced(query_states, key_states, cos, sin)
     else:
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
 
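The patched call sites share one dispatch pattern: on XPU they call the fused in-place helper, otherwise they fall back to the stock rotate-half formula. Below is a minimal sketch of that pattern with a pure-PyTorch fallback; the fallback mirrors the standard Hugging Face apply_rotary_pos_emb and is included for illustration only, not code from this commit.

import torch


def _rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Split the last dimension in two and rotate: (x1, x2) -> (-x2, x1).
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rope(query_states, key_states, cos, sin):
    if query_states.device.type == "xpu":
        # Fused in-place path introduced by this commit (requires xe_addons).
        from ipex_llm.transformers.models.common import rotary_half_with_cache_inplaced
        rotary_half_with_cache_inplaced(query_states, key_states, cos, sin)
        return query_states, key_states
    # Out-of-place rotate-half RoPE, as in stock transformers.
    cos = cos.unsqueeze(1)  # broadcast [batch, seq, head_dim] over the head axis
    sin = sin.unsqueeze(1)
    q = query_states * cos + _rotate_half(query_states) * sin
    k = key_states * cos + _rotate_half(key_states) * sin
    return q, k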