LLM: Fix rope of chatglm3 to support speculative decoding on CPU (#9926)

This commit is contained in:
Ruonan Wang 2024-01-18 09:28:10 +08:00 committed by GitHub
parent 18cd1f1432
commit 054952f82f

View file

@@ -218,7 +218,8 @@ def chatglm2_attention_forward_8eb45c(
         # apply relative positional encoding (rotary embedding)
         if rotary_pos_emb is not None:
-            if len(rotary_pos_emb) == 2:  # use_fuse_rope, see chatglm2_model_forward
+            if len(rotary_pos_emb) == 2 and isinstance(rotary_pos_emb, tuple):
+                # use_fuse_rope, see chatglm2_model_forward
                 cos, sin = rotary_pos_emb
                 rot_dim = cos.shape[-1]
                 query_layer = query_layer.transpose(0, 1)