optimize glm edge again (#12539)
parent 6596c18489
commit 15219944b8

2 changed files with 15 additions and 1 deletion
@@ -41,6 +41,7 @@ from transformers.models.glm.modeling_glm import repeat_kv, apply_rotary_pos_emb
 from ipex_llm.transformers.kv import DynamicNormalCache, DynamicFp8Cache
 from ipex_llm.transformers.models.common import merge_qkv_base
 from ipex_llm.transformers.models.utils import use_sdp, use_sdp_causal
+from ipex_llm.transformers.models.utils import make_cache_contiguous_inplaced
 from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache
@@ -94,6 +95,11 @@ def glm_attention_forward(
                                                         self.num_key_value_heads], dim=1)

     cos, sin = position_embeddings
-    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+    if query_states.device.type == "xpu":
+        import xe_addons
+        make_cache_contiguous_inplaced(cos, sin)
+        xe_addons.rotary_two_with_cache_inplaced(query_states, key_states, cos, sin, True)
+    else:
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

     use_quantizekv = isinstance(past_key_value, DynamicFp8Cache)
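This hunk is the core of the optimization: on Intel XPU devices, query and key are rotated in place by a single fused xe_addons kernel instead of the out-of-place apply_rotary_pos_emb, after make_cache_contiguous_inplaced guarantees the cos/sin caches have the contiguous layout the kernel expects; other devices keep the stock transformers path. A minimal sketch of that dispatch, pulled out of the attention function (the wrapper name rotate_query_key is made up for illustration; xe_addons is only importable in an ipex-llm XPU build, everything inside the branches comes from the diff):

import torch
from transformers.models.glm.modeling_glm import apply_rotary_pos_emb


def rotate_query_key(query_states: torch.Tensor, key_states: torch.Tensor,
                     cos: torch.Tensor, sin: torch.Tensor):
    # Hypothetical wrapper around the dispatch added in glm_attention_forward.
    if query_states.device.type == "xpu":
        # XPU fast path: rotate q and k in place with one fused kernel,
        # after making the rotary caches contiguous.
        import xe_addons
        from ipex_llm.transformers.models.utils import make_cache_contiguous_inplaced
        make_cache_contiguous_inplaced(cos, sin)
        xe_addons.rotary_two_with_cache_inplaced(query_states, key_states, cos, sin, True)
    else:
        # Generic fallback: out-of-place rotation from transformers.
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
    return query_states, key_states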
@@ -493,3 +493,11 @@ def get_q_proj_or_qkv_proj(self):
     elif hasattr(self, "qkv_proj"):
         proj = self.qkv_proj
     return proj
+
+
+def make_cache_contiguous_inplaced(cos: torch.Tensor, sin: torch.Tensor):
+    if not cos.is_contiguous():
+        new_cos = cos.contiguous()
+        new_sin = sin.contiguous()
+        cos.set_(new_cos)
+        sin.set_(new_sin)
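The new helper relies on Tensor.set_(), so the existing cos/sin tensor objects are retargeted at a contiguous copy in place and no caller has to rebind them. A small standalone demonstration of that behaviour (the function body is copied from the hunk above; the tensor shapes are made up):

import torch


def make_cache_contiguous_inplaced(cos: torch.Tensor, sin: torch.Tensor):
    # Copied from the hunk above: if the caches are non-contiguous, swap their
    # underlying storage for a contiguous copy so every existing reference to
    # these tensor objects sees the new layout.
    if not cos.is_contiguous():
        new_cos = cos.contiguous()
        new_sin = sin.contiguous()
        cos.set_(new_cos)
        sin.set_(new_sin)


cos = torch.randn(64, 8).t()   # .t() yields a non-contiguous view
sin = torch.randn(64, 8).t()
assert not cos.is_contiguous()

make_cache_contiguous_inplaced(cos, sin)

assert cos.is_contiguous() and sin.is_contiguous()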