ChatGLM2 RoPE optimization on XPU (#9350)
parent 833e4dbc8d
commit 1420e45cc0
2 changed files with 96 additions and 7 deletions
@@ -284,6 +284,7 @@ def _optimize_post(model):
         from bigdl.llm.transformers.models.chatglm2 import chatglm2_attention_forward_8eb45c
         from bigdl.llm.transformers.models.chatglm2 import core_attn_forward_8eb45c
         from bigdl.llm.transformers.models.chatglm2 import chatglm_rms_norm_forward
+        from bigdl.llm.transformers.models.chatglm2 import chatglm2_model_forward
         convert_forward(model,
                         module.SelfAttention,
                         chatglm2_attention_forward_8eb45c
@@ -291,6 +292,9 @@ def _optimize_post(model):
         convert_forward(model,
                         module.CoreAttention,
                         core_attn_forward_8eb45c)
+        convert_forward(model,
+                        module.ChatGLMModel,
+                        chatglm2_model_forward)
         convert_forward(model,
                         module.RMSNorm,
                         chatglm_rms_norm_forward)
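Note: `convert_forward` is defined elsewhere in this module and is not part of the diff. A minimal sketch of what such a helper plausibly does, assuming the usual recursive walk-and-rebind pattern (the signature and body here are illustrative, not the library's actual implementation):

import torch.nn as nn

def convert_forward(m: nn.Module, target_m: type, new_forward):
    # Walk the module tree; wherever an instance of target_m appears,
    # bind new_forward to that instance as its forward method.
    for _, sub_m in m.named_children():
        if isinstance(sub_m, target_m):
            sub_m.forward = new_forward.__get__(sub_m, sub_m.__class__)
        convert_forward(sub_m, target_m, new_forward)

With this, the added lines route every `ChatGLMModel` instance through `chatglm2_model_forward` below, which is where the fused-rope decision is made.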
@@ -18,8 +18,9 @@
 #
 
 import torch
-from typing import Optional, Tuple, Union, List, Callable, Dict, Any
+from typing import Optional, Tuple, List
 import torch.nn.functional as F
+from transformers.modeling_outputs import BaseModelOutputWithPast
 from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
 
 
@@ -54,7 +55,7 @@ def split_tensor_along_last_dim(
 
 
 @torch.jit.script
-def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:
+def apply_rotary_pos_emb_chatglm(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:
     # x: [sq, b, np, hn]
     sq, b, np, hn = x.size(0), x.size(1), x.size(2), x.size(3)
     rot_dim = rope_cache.shape[-2] * 2
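The eager kernel is only renamed (to `apply_rotary_pos_emb_chatglm`), marking it as the non-XPU fallback for the fused path added below. For reference, the rotate_every_two rotation it implements can be written as follows; this is an illustrative re-derivation, not code from the patch, and it assumes cos/sin are already repeated per channel pair as done in `chatglm2_model_forward` below:

import torch

def rotate_every_two_ref(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # x, cos, sin: [..., rot_dim]; each (even, odd) channel pair is rotated
    # by the same angle, GPT-J style.
    x1 = x[..., 0::2]                                     # even channels
    x2 = x[..., 1::2]                                     # odd channels
    rotated = torch.stack((-x2, x1), dim=-1).flatten(-2)  # interleave (-x2, x1)
    return x * cos + rotated * sin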
@@ -87,6 +88,77 @@ def chatglm_rms_norm_forward(self, hidden_states):
     return hidden_states
 
 
+def chatglm2_model_forward(
+        self,
+        input_ids,
+        position_ids: Optional[torch.Tensor]=None,
+        attention_mask: Optional[torch.BoolTensor]=None,
+        full_attention_mask: Optional[torch.BoolTensor]=None,
+        past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]]=None,
+        inputs_embeds: Optional[torch.Tensor]=None,
+        use_cache: Optional[bool]=None,
+        output_hidden_states: Optional[bool]=None,
+        return_dict: Optional[bool]=None,
+):
+    output_hidden_states = (
+        output_hidden_states if output_hidden_states is not None
+        else self.config.output_hidden_states
+    )
+    use_cache = use_cache if use_cache is not None else self.config.use_cache
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+    batch_size, seq_length = input_ids.shape
+
+    if inputs_embeds is None:
+        inputs_embeds = self.embedding(input_ids)
+
+    if full_attention_mask is None:
+        if (attention_mask is not None and not attention_mask.all()) or (
+                past_key_values and seq_length != 1):
+            full_attention_mask = self.get_masks(input_ids,
+                                                 past_key_values,
+                                                 padding_mask=attention_mask)
+
+    use_fuse_rope = input_ids.device.type == "xpu"
+    use_fuse_rope = use_fuse_rope and not self.training
+
+    # Rotary positional embeddings
+    rotary_pos_emb = self.rotary_pos_emb(self.seq_length)
+    if position_ids is not None:
+        rotary_pos_emb = rotary_pos_emb[position_ids]
+    else:
+        rotary_pos_emb = rotary_pos_emb[None, :seq_length]
+    if use_fuse_rope:
+        # Repeat cos and sin here, so it is done only once per token.
+        # ChatGLM2's rotary embedding is similar to GPT-J's: rotate_every_two.
+        # Doing this in the attention forward would repeat it once per layer.
+        cos, sin = rotary_pos_emb.split(rotary_pos_emb.shape[-1] // 2, dim=-1)
+        cos = cos.squeeze(-1)
+        sin = sin.squeeze(-1)
+        cos = torch.repeat_interleave(cos[:, :, None, :], 2, 3)
+        sin = torch.repeat_interleave(sin[:, :, None, :], 2, 3)
+        rotary_pos_emb = (cos, sin)
+    else:
+        rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous()
+
+    # Run encoder.
+    hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder(
+        inputs_embeds, full_attention_mask, rotary_pos_emb=rotary_pos_emb,
+        kv_caches=past_key_values, use_cache=use_cache, output_hidden_states=output_hidden_states
+    )
+
+    if not return_dict:
+        return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions]
+                     if v is not None)
+
+    return BaseModelOutputWithPast(
+        last_hidden_state=hidden_states,
+        past_key_values=presents,
+        hidden_states=all_hidden_states,
+        attentions=all_self_attentions,
+    )
+
+
 def chatglm2_attention_forward_8eb45c(
         self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True
 ):
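The fused branch reshapes the rope cache from `[b, sq, rot_dim/2, 2]` (cos and sin stacked in the last dimension) into a `(cos, sin)` pair with each value duplicated per channel pair, once per forward pass. A standalone shape walk-through of that transformation; the concrete sizes are illustrative, not from the patch:

import torch

b, sq, rot_half = 1, 8, 32                    # illustrative sizes
rope_cache = torch.randn(b, sq, rot_half, 2)  # last dim holds (cos, sin)

cos, sin = rope_cache.split(rope_cache.shape[-1] // 2, dim=-1)
cos, sin = cos.squeeze(-1), sin.squeeze(-1)   # [b, sq, rot_half]
cos = torch.repeat_interleave(cos[:, :, None, :], 2, 3)
sin = torch.repeat_interleave(sin[:, :, None, :], 2, 3)
print(cos.shape, sin.shape)                   # [1, 8, 1, 64] each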
@@ -132,12 +204,26 @@ def chatglm2_attention_forward_8eb45c(
     # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
     (query_layer, key_layer, value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3)
 
+    cur_length, batch_size = query_layer.shape[0], query_layer.shape[1]
+
     # apply relative positional encoding (rotary embedding)
     if rotary_pos_emb is not None:
-        query_layer = apply_rotary_pos_emb(query_layer, rotary_pos_emb)
-        key_layer = apply_rotary_pos_emb(key_layer, rotary_pos_emb)
-
-    cur_length, batch_size = query_layer.shape[0], query_layer.shape[1]
+        if len(rotary_pos_emb) == 2:  # use_fuse_rope, see chatglm2_model_forward
+            cos, sin = rotary_pos_emb
+            rot_dim = cos.shape[-1]
+            query_layer = query_layer.transpose(0, 1)
+            key_layer = key_layer.transpose(0, 1)
+            query_layer_cur = query_layer[..., :rot_dim]
+            key_layer_cur = key_layer[..., :rot_dim]
+            # ipex's apply_rotary_embedding writes through the original storage,
+            # so query_layer and key_layer receive the result directly.
+            torch.ops.torch_ipex.apply_rotary_embedding(query_layer_cur, sin, cos, query_layer_cur)
+            torch.ops.torch_ipex.apply_rotary_embedding(key_layer_cur, sin, cos, key_layer_cur)
+            query_layer = query_layer.transpose(0, 1)
+            key_layer = key_layer.transpose(0, 1)
+        else:
+            query_layer = apply_rotary_pos_emb_chatglm(query_layer, rotary_pos_emb)
+            key_layer = apply_rotary_pos_emb_chatglm(key_layer, rotary_pos_emb)
 
     if self.multi_query_attention:
         key_length = key_layer.size(0)
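The transposes around the IPEX calls matter because `query_layer[..., :rot_dim]` must be a view whose layout the kernel can write through; since a view shares storage with its base tensor, the in-place write lands in `query_layer` itself. A generic PyTorch illustration of that mechanism, using `mul_` as a stand-in for the IPEX kernel:

import torch

q = torch.randn(2, 4, 8)                   # stand-in for [b, sq, hn]
rot_dim = 6
q_cur = q[..., :rot_dim]                   # a view sharing q's storage
q_cur.mul_(0.0)                            # in-place write through the view...
assert bool(q[..., :rot_dim].eq(0).all())  # ...is visible in q itself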
@@ -200,7 +286,6 @@ def chatglm2_attention_forward_8eb45c(
     # ==================================
     # core attention computation
     # ==================================
 
     context_layer = self.core_attention(query_layer, key_layer, value_layer, attention_mask)
-
     # =================
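Taken together, a ChatGLM2 model optimized through `_optimize_post` picks the fused rope path automatically whenever its inputs live on an XPU device and the model is in eval mode (`use_fuse_rope` checks both). A hypothetical end-to-end usage sketch; the model id and loading flags are illustrative and follow bigdl-llm's transformers-style API rather than anything shown in this diff:

from bigdl.llm.transformers import AutoModel

# Low-bit loading applies BigDL-LLM's model optimizations,
# which include the forward replacements wired up above.
model = AutoModel.from_pretrained("THUDM/chatglm2-6b",
                                  load_in_4bit=True,
                                  trust_remote_code=True)
model = model.eval().to("xpu")  # eval + xpu => use_fuse_rope is True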