use new rotary two in chatglm4 (#11312)
* use new rotary two in chatglm4 * rempve
This commit is contained in:
parent
f1410d6823
commit
1b0c4c8cb8
1 changed files with 32 additions and 65 deletions
|
|
@ -47,10 +47,6 @@ def chatglm4_model_forward(
|
||||||
) -> Union[Tuple, BaseModelOutputWithPast]:
|
) -> Union[Tuple, BaseModelOutputWithPast]:
|
||||||
from ipex_llm.transformers.kv import DynamicFp8Cache
|
from ipex_llm.transformers.kv import DynamicFp8Cache
|
||||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||||
# if use_cache and use_quantize_kv_cache(
|
|
||||||
# self.encoder.layers[0].self_attention.query_key_value, input_ids):
|
|
||||||
# if not isinstance(past_key_values, DynamicFp8Cache):
|
|
||||||
# past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
|
|
||||||
return chatglm4_model_forward_internal(
|
return chatglm4_model_forward_internal(
|
||||||
self=self,
|
self=self,
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
|
|
@ -108,25 +104,17 @@ def chatglm4_model_forward_internal(
|
||||||
dtype=torch.int64, device=inputs_embeds.device)
|
dtype=torch.int64, device=inputs_embeds.device)
|
||||||
position_ids = position_ids.repeat(batch_size, 1)
|
position_ids = position_ids.repeat(batch_size, 1)
|
||||||
|
|
||||||
use_fuse_rope = input_ids.device.type == "xpu"
|
if getattr(self.rotary_pos_emb, "cached_dtype", None) != inputs_embeds.dtype:
|
||||||
use_fuse_rope = use_fuse_rope and not self.training
|
rot_dim = self.rotary_pos_emb.dim
|
||||||
|
base = 10000 * getattr(self.rotary_pos_emb, "rope_ratio", 1)
|
||||||
# Rotary positional embeddings
|
# We should generate float inv_freq to avoid overflow, as base is too large.
|
||||||
rotary_pos_emb = self.rotary_pos_emb(self.seq_length)
|
inv_freq = 1.0 / (base ** (torch.arange(0, rot_dim, 2,
|
||||||
if position_ids is not None:
|
dtype=torch.float,
|
||||||
rotary_pos_emb = rotary_pos_emb[position_ids]
|
device=inputs_embeds.device) / rot_dim))
|
||||||
else:
|
self.rotary_pos_emb.register_buffer("inv_freq",
|
||||||
rotary_pos_emb = rotary_pos_emb[None, :seq_length]
|
inv_freq.to(inputs_embeds.dtype),
|
||||||
if use_fuse_rope:
|
persistent=False)
|
||||||
# Repeat cos sin here, call only once for each token.
|
self.rotary_pos_emb.cached = True
|
||||||
# Chatglm2's rotary embedding is similar to gptj's, is rotate_every_two.
|
|
||||||
# If put this to attension forward, it will generate too many times.
|
|
||||||
cos, sin = rotary_pos_emb.split(rotary_pos_emb.shape[-1] // 2, dim=-1)
|
|
||||||
cos = cos.squeeze(-1)
|
|
||||||
sin = sin.squeeze(-1)
|
|
||||||
cos = torch.repeat_interleave(cos[:, :, None, :], 2, 3)
|
|
||||||
sin = torch.repeat_interleave(sin[:, :, None, :], 2, 3)
|
|
||||||
rotary_pos_emb = (cos, sin)
|
|
||||||
|
|
||||||
# `full_attention_mask` is not None only when
|
# `full_attention_mask` is not None only when
|
||||||
# `past_key_values` is not None and `seq_length` > 1
|
# `past_key_values` is not None and `seq_length` > 1
|
||||||
|
|
@ -148,7 +136,7 @@ def chatglm4_model_forward_internal(
|
||||||
|
|
||||||
hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder(
|
hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder(
|
||||||
inputs_embeds, causal_mask,
|
inputs_embeds, causal_mask,
|
||||||
rotary_pos_emb=rotary_pos_emb,
|
rotary_pos_emb=(self.rotary_pos_emb.inv_freq, position_ids),
|
||||||
kv_caches=past_key_values, use_cache=use_cache, output_hidden_states=output_hidden_states
|
kv_caches=past_key_values, use_cache=use_cache, output_hidden_states=output_hidden_states
|
||||||
)
|
)
|
||||||
# ipex-llm changes end
|
# ipex-llm changes end
|
||||||
|
|
@ -172,26 +160,6 @@ def chatglm4_model_forward_internal(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:
|
|
||||||
# x: [b, np, sq, hn]
|
|
||||||
b, np, sq, hn = x.size(0), x.size(1), x.size(2), x.size(3)
|
|
||||||
rot_dim = rope_cache.shape[-2] * 2
|
|
||||||
x, x_pass = x[..., :rot_dim], x[..., rot_dim:]
|
|
||||||
# truncate to support variable sizes
|
|
||||||
rope_cache = rope_cache[:, :sq]
|
|
||||||
xshaped = x.reshape(b, np, sq, rot_dim // 2, 2)
|
|
||||||
rope_cache = rope_cache.view(-1, 1, sq, xshaped.size(3), 2)
|
|
||||||
x_out2 = torch.stack(
|
|
||||||
[
|
|
||||||
xshaped[..., 0] * rope_cache[..., 0] - xshaped[..., 1] * rope_cache[..., 1],
|
|
||||||
xshaped[..., 1] * rope_cache[..., 0] + xshaped[..., 0] * rope_cache[..., 1],
|
|
||||||
],
|
|
||||||
-1,
|
|
||||||
)
|
|
||||||
x_out2 = x_out2.flatten(3)
|
|
||||||
return torch.cat((x_out2, x_pass), dim=-1)
|
|
||||||
|
|
||||||
|
|
||||||
def chatglm4_attention_forward(
|
def chatglm4_attention_forward(
|
||||||
self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True
|
self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True
|
||||||
):
|
):
|
||||||
|
|
@ -209,34 +177,33 @@ def chatglm4_attention_forward(
|
||||||
qkv = self.query_key_value(hidden_states)
|
qkv = self.query_key_value(hidden_states)
|
||||||
# [bs, q_len, np * 3 * hn] -> [bsz, n_head, seq_len, head_dim]
|
# [bs, q_len, np * 3 * hn] -> [bsz, n_head, seq_len, head_dim]
|
||||||
qkv = qkv.view(bsz, q_len, n_head + 2 * n_kv_head, head_dim)
|
qkv = qkv.view(bsz, q_len, n_head + 2 * n_kv_head, head_dim)
|
||||||
|
qkv = qkv.transpose(1, 2)
|
||||||
|
|
||||||
query_states, key_states, value_states = qkv.split([n_head,
|
query_states, key_states, value_states = qkv.split([n_head,
|
||||||
n_kv_head,
|
n_kv_head,
|
||||||
n_kv_head], dim=2)
|
n_kv_head], dim=1)
|
||||||
|
|
||||||
kv_seq_len = key_states.shape[1]
|
kv_seq_len = key_states.shape[2]
|
||||||
if past_key_value is not None:
|
if past_key_value is not None:
|
||||||
kv_seq_len += past_key_value[0].shape[2]
|
kv_seq_len += past_key_value[0].shape[2]
|
||||||
|
|
||||||
if isinstance(rotary_pos_emb, tuple) and len(rotary_pos_emb) == 2:
|
# IPEX-LLM OPT: fuse rope
|
||||||
# use_fuse_rope, see chatglm4_model_forward
|
inv_freq, position_ids = rotary_pos_emb
|
||||||
cos, sin = rotary_pos_emb
|
rot_dim = inv_freq.size(-1) * 2
|
||||||
rot_dim = cos.shape[-1]
|
if should_use_fuse_rope(hidden_states, rotary_pos_emb[1], self.training):
|
||||||
query_layer_cur = query_states[..., :rot_dim]
|
import xe_addons
|
||||||
key_layer_cur = key_states[..., :rot_dim]
|
xe_addons.rotary_two_inplaced(inv_freq, position_ids,
|
||||||
# ipex_llm's apply_rotary_embedding can change the origin storage,
|
query_states[..., :rot_dim], key_states[..., :rot_dim])
|
||||||
# so query_layer will get the result directly.
|
else:
|
||||||
torch.ops.torch_ipex.apply_rotary_embedding(query_layer_cur, sin, cos, query_layer_cur)
|
idx_theta = torch.outer(position_ids[0].float(),
|
||||||
torch.ops.torch_ipex.apply_rotary_embedding(key_layer_cur, sin, cos, key_layer_cur)
|
inv_freq.float()).to(hidden_states.dtype)
|
||||||
query_states = query_states.transpose(1, 2)
|
idx_theta = idx_theta.unsqueeze(0).unsqueeze(0)
|
||||||
key_states = key_states.transpose(1, 2)
|
cos = torch.cos(idx_theta).repeat_interleave(2, -1)
|
||||||
value_states = value_states.transpose(1, 2)
|
sin = torch.sin(idx_theta).repeat_interleave(2, -1)
|
||||||
elif rotary_pos_emb is not None:
|
q_rot, k_rot = apply_rotary_pos_emb(query_states[..., :rot_dim], key_states[..., :rot_dim],
|
||||||
query_states = query_states.transpose(1, 2)
|
cos, sin, position_ids, "chatglm")
|
||||||
key_states = key_states.transpose(1, 2)
|
query_states[..., :rot_dim] = q_rot[...]
|
||||||
value_states = value_states.transpose(1, 2)
|
key_states[..., :rot_dim] = k_rot[...]
|
||||||
query_states = apply_rotary_pos_emb(query_states, rotary_pos_emb)
|
|
||||||
key_states = apply_rotary_pos_emb(key_states, rotary_pos_emb)
|
|
||||||
|
|
||||||
# IPEX-LLM OPT: kv cache and quantize kv
|
# IPEX-LLM OPT: kv cache and quantize kv
|
||||||
use_quantize_kv = use_quantize_kv_cache(self.query_key_value, hidden_states)
|
use_quantize_kv = use_quantize_kv_cache(self.query_key_value, hidden_states)
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue