optimize yuan 2.0 performance (#10244)

This commit is contained in:
Yishuo Wang 2024-02-26 17:20:10 +08:00 committed by GitHub
parent 6c74b99a28
commit a47989c860
2 changed files with 57 additions and 10 deletions

View file

@ -539,6 +539,31 @@ def _optimize_pre(model):
if model.lm_head.weight.data.device != "meta":
norm_weight = nn.functional.normalize(lm_head_weight_data)
model.lm_head.weight.data = norm_weight
# for yuan 2.0
if model.config.model_type == "yuan":
def merge_qk_proj_func(module):
if "YuanAttention" in module.__class__.__name__:
q_weight = module.q_proj.weight.data
k_weight = module.k_proj.weight.data
num_heads = module.num_heads
head_dim = module.head_dim
hidden_size = module.hidden_size
merged_qk_proj = torch.nn.Linear(0, 0, False)
weight = torch.cat([
q_weight.view(num_heads, head_dim, hidden_size)[0::2, :, :],
k_weight.view(num_heads, head_dim, hidden_size)[0::2, :, :],
q_weight.view(num_heads, head_dim, hidden_size)[1::2, :, :],
k_weight.view(num_heads, head_dim, hidden_size)[1::2, :, :],
], dim=0).view(num_heads * head_dim * 2, hidden_size)
merged_qk_proj.weight = torch.nn.Parameter(weight, requires_grad=False)
merged_qk_proj.in_features = hidden_size
merged_qk_proj.out_features = num_heads * head_dim * 2
module.merged_qk_proj = merged_qk_proj
del module.q_proj
del module.k_proj
model.apply(merge_qk_proj_func)
return model
@ -1158,6 +1183,7 @@ def _optimize_post(model, lightweight_bmm=False):
module = importlib.import_module(modeling_module_name)
from bigdl.llm.transformers.models.yuan import yuan_attention_forward
from bigdl.llm.transformers.models.yuan import yuan_mlp_forward
from bigdl.llm.transformers.models.yuan import yuan_localized_filtering_forward
convert_forward(model,
module.YuanAttention,
yuan_attention_forward
@ -1166,4 +1192,7 @@ def _optimize_post(model, lightweight_bmm=False):
module.YuanMLP,
yuan_mlp_forward
)
convert_forward(model,
module.LocalizedFiltering,
yuan_localized_filtering_forward)
return model

View file

@ -50,6 +50,29 @@ def should_use_fuse_rope(self, hidden_states, position_ids):
return use_fuse_rope
def yuan_localized_filtering_forward(
self,
inputs: torch.Tensor,
before_hidden_states: torch.Tensor,
):
if self.conv1.weight.dtype != torch.half:
self.half()
inputs = inputs.half()
if before_hidden_states is not None:
before_hidden_states = before_hidden_states.half()
invalidInputError(self.lf_conv2d_num_pad == 1, "padding must be 1")
if self.training:
lf_output = self._train_forward(inputs)
else:
lf_output = self._inference_forward(inputs, before_hidden_states)
lf_output = lf_output.to(inputs.dtype)
return lf_output
def yuan_mlp_forward(
self,
x: torch.Tensor,
@ -132,8 +155,8 @@ def yuan_attention_forward(
inference_hidden_states_memory[:, -1:, :] = hidden_states[:, -1:, :]
else:
hidden_states_tmp = before_hidden_states[:, -1:, :]
inference_hidden_states_memory = \
copy.deepcopy(torch.cat((hidden_states_tmp, hidden_states), dim=1))
inference_hidden_states_memory = torch.cat((hidden_states_tmp,
hidden_states), dim=1)
value_states = \
self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
@ -148,15 +171,10 @@ def yuan_attention_forward(
key_states = key_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
else:
hidden_states = self.lf_gate(hidden_states, before_hidden_states)
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
qk_states = torch.cat([query_states, key_states], dim=-1)
qk_states = qk_states.view(bsz, q_len,
self.num_heads,
int(qk_states.shape[-1]//self.num_heads))
qk_states = self.merged_qk_proj(hidden_states)
(query_states, key_states) = torch.chunk(qk_states, 2, dim=-1)
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
kv_seq_len = key_states.shape[-2]
if past_key_value is not None: