optimize glm edge again (#12539)

Yishuo Wang 2024-12-13 13:52:39 +08:00 committed by GitHub
parent 6596c18489
commit 15219944b8
2 changed files with 15 additions and 1 deletion


@@ -41,6 +41,7 @@ from transformers.models.glm.modeling_glm import repeat_kv, apply_rotary_pos_emb
 from ipex_llm.transformers.kv import DynamicNormalCache, DynamicFp8Cache
 from ipex_llm.transformers.models.common import merge_qkv_base
 from ipex_llm.transformers.models.utils import use_sdp, use_sdp_causal
+from ipex_llm.transformers.models.utils import make_cache_contiguous_inplaced
 from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache
@@ -94,7 +95,12 @@ def glm_attention_forward(
                                                         self.num_key_value_heads], dim=1)
 
     cos, sin = position_embeddings
-    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+    if query_states.device.type == "xpu":
+        import xe_addons
+        make_cache_contiguous_inplaced(cos, sin)
+        xe_addons.rotary_two_with_cache_inplaced(query_states, key_states, cos, sin, True)
+    else:
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
 
     use_quantizekv = isinstance(past_key_value, DynamicFp8Cache)
     # sin and cos are specific to RoPE models; cache_position needed for the static cache
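
Note: the branch added above dispatches on device type. On XPU the cos/sin caches are first made contiguous in place and the rotation is handed to the fused xe_addons.rotary_two_with_cache_inplaced kernel; on other devices the stock apply_rotary_pos_emb path is kept. The sketch below is only a rough sanity check one might run on an XPU device, not part of the change; it assumes the fused call rotates query_states and key_states in place (as its name and the discarded return value suggest) and that ipex-llm with xe_addons is installed.

    import torch
    import xe_addons
    from transformers.models.glm.modeling_glm import apply_rotary_pos_emb
    from ipex_llm.transformers.models.utils import make_cache_contiguous_inplaced

    def fused_rope_matches_eager(query_states, key_states, cos, sin, atol=1e-3):
        # Reference rotation via the eager Hugging Face path (returns new
        # tensors and leaves the inputs untouched).
        q_ref, k_ref = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        # Fused XPU path as in the hunk above: make the caches contiguous,
        # then rotate query_states/key_states in place.
        make_cache_contiguous_inplaced(cos, sin)
        xe_addons.rotary_two_with_cache_inplaced(query_states, key_states, cos, sin, True)

        return (torch.allclose(query_states, q_ref, atol=atol)
                and torch.allclose(key_states, k_ref, atol=atol))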


@@ -493,3 +493,11 @@ def get_q_proj_or_qkv_proj(self):
     elif hasattr(self, "qkv_proj"):
         proj = self.qkv_proj
     return proj
+
+
+def make_cache_contiguous_inplaced(cos: torch.Tensor, sin: torch.Tensor):
+    if not cos.is_contiguous():
+        new_cos = cos.contiguous()
+        new_sin = sin.contiguous()
+        cos.set_(new_cos)
+        sin.set_(new_sin)
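
Note: the new helper relies on torch.Tensor.set_(), which re-points the existing tensor object at the new contiguous storage, so the cos/sin objects already held by the caller become contiguous without being rebound. A minimal sketch of that effect, with purely illustrative shapes (and assuming ipex-llm is installed so the helper can be imported):

    import torch
    from ipex_llm.transformers.models.utils import make_cache_contiguous_inplaced

    # A transpose yields a strided, non-contiguous view over the same storage.
    cos = torch.randn(8, 64).transpose(0, 1)
    sin = torch.randn(8, 64).transpose(0, 1)
    assert not cos.is_contiguous() and not sin.is_contiguous()

    make_cache_contiguous_inplaced(cos, sin)

    # set_() re-pointed the very same Tensor objects at contiguous storage, so a
    # caller such as glm_attention_forward sees contiguous cos/sin without
    # rebinding them (presumably the layout the fused XPU kernel expects).
    assert cos.is_contiguous() and sin.is_contiguous()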