optimize yuan 2.0 performance (#10244)

parent 6c74b99a28 · commit a47989c860
2 changed files with 57 additions and 10 deletions
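In short, the commit speeds up Yuan 2.0 along three lines: _optimize_pre fuses q_proj and k_proj into a single merged_qk_proj linear layer whose weight pre-applies Yuan's even/odd head interleaving, so the attention forward needs one matmul instead of two plus a cat/view/chunk shuffle on activations; _optimize_post additionally patches LocalizedFiltering with an fp16 forward (yuan_localized_filtering_forward); and the attention forward drops a redundant copy.deepcopy around torch.cat. The changes touch the conversion logic (_optimize_pre/_optimize_post) and the Yuan-specific forwards in bigdl.llm.transformers.models.yuan.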
@@ -539,6 +539,31 @@ def _optimize_pre(model):
         if model.lm_head.weight.data.device != "meta":
             norm_weight = nn.functional.normalize(lm_head_weight_data)
             model.lm_head.weight.data = norm_weight
+    # for yuan 2.0
+    if model.config.model_type == "yuan":
+        def merge_qk_proj_func(module):
+            if "YuanAttention" in module.__class__.__name__:
+                q_weight = module.q_proj.weight.data
+                k_weight = module.k_proj.weight.data
+                num_heads = module.num_heads
+                head_dim = module.head_dim
+                hidden_size = module.hidden_size
+
+                merged_qk_proj = torch.nn.Linear(0, 0, False)
+                weight = torch.cat([
+                    q_weight.view(num_heads, head_dim, hidden_size)[0::2, :, :],
+                    k_weight.view(num_heads, head_dim, hidden_size)[0::2, :, :],
+                    q_weight.view(num_heads, head_dim, hidden_size)[1::2, :, :],
+                    k_weight.view(num_heads, head_dim, hidden_size)[1::2, :, :],
+                ], dim=0).view(num_heads * head_dim * 2, hidden_size)
+                merged_qk_proj.weight = torch.nn.Parameter(weight, requires_grad=False)
+                merged_qk_proj.in_features = hidden_size
+                merged_qk_proj.out_features = num_heads * head_dim * 2
+                module.merged_qk_proj = merged_qk_proj
+
+                del module.q_proj
+                del module.k_proj
+        model.apply(merge_qk_proj_func)
     return model
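The [0::2]/[1::2] slices reproduce, inside the weight matrix, the even/odd head shuffle that Yuan's attention previously performed on activations (see the last hunk below). A minimal standalone sketch of the resulting row layout, with made-up sizes; not part of the commit:

    import torch

    num_heads, head_dim, hidden_size = 4, 2, 8
    q = torch.arange(num_heads * head_dim * hidden_size,
                     dtype=torch.float32).view(num_heads * head_dim, hidden_size)
    k = -q  # any tensor distinguishable from q works for the demonstration

    merged = torch.cat([
        q.view(num_heads, head_dim, hidden_size)[0::2],  # q heads 0, 2
        k.view(num_heads, head_dim, hidden_size)[0::2],  # k heads 0, 2
        q.view(num_heads, head_dim, hidden_size)[1::2],  # q heads 1, 3
        k.view(num_heads, head_dim, hidden_size)[1::2],  # k heads 1, 3
    ], dim=0).view(num_heads * head_dim * 2, hidden_size)

    # Row layout is [q_even | k_even | q_odd | k_odd]: rows 0..head_dim-1 hold
    # q head 0, and the k_even block starts num_heads//2 heads later.
    assert torch.equal(merged[:head_dim], q[:head_dim])              # q head 0
    half = num_heads // 2 * head_dim
    assert torch.equal(merged[half:half + head_dim], k[:head_dim])   # k head 0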
@@ -1158,6 +1183,7 @@ def _optimize_post(model, lightweight_bmm=False):
         module = importlib.import_module(modeling_module_name)
         from bigdl.llm.transformers.models.yuan import yuan_attention_forward
         from bigdl.llm.transformers.models.yuan import yuan_mlp_forward
+        from bigdl.llm.transformers.models.yuan import yuan_localized_filtering_forward
         convert_forward(model,
                         module.YuanAttention,
                         yuan_attention_forward
@@ -1166,4 +1192,7 @@ def _optimize_post(model, lightweight_bmm=False):
                         module.YuanMLP,
                         yuan_mlp_forward
                         )
+        convert_forward(model,
+                        module.LocalizedFiltering,
+                        yuan_localized_filtering_forward)
     return model
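convert_forward is BigDL's helper for swapping out a module class's forward method model-wide. Conceptually it works like the sketch below (an assumption about the idea, not the actual BigDL implementation):

    import types
    import torch.nn as nn

    def convert_forward_sketch(model: nn.Module, target_class, new_forward):
        # Rebind `forward` on every submodule that is an instance of target_class,
        # so the patched function runs against that module's own weights via `self`.
        for module in model.modules():
            if isinstance(module, target_class):
                module.forward = types.MethodType(new_forward, module)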
@@ -50,6 +50,29 @@ def should_use_fuse_rope(self, hidden_states, position_ids):
     return use_fuse_rope


+def yuan_localized_filtering_forward(
+    self,
+    inputs: torch.Tensor,
+    before_hidden_states: torch.Tensor,
+):
+    if self.conv1.weight.dtype != torch.half:
+        self.half()
+
+    inputs = inputs.half()
+    if before_hidden_states is not None:
+        before_hidden_states = before_hidden_states.half()
+
+    invalidInputError(self.lf_conv2d_num_pad == 1, "padding must be 1")
+    if self.training:
+        lf_output = self._train_forward(inputs)
+    else:
+        lf_output = self._inference_forward(inputs, before_hidden_states)
+
+    lf_output = lf_output.to(inputs.dtype)
+
+    return lf_output
+
+
 def yuan_mlp_forward(
     self,
     x: torch.Tensor,
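The new forward pins LocalizedFiltering to fp16 once (self.half() only fires on the first call, while conv1 is still in a wider dtype) and casts the activations to match, instead of paying for dtype conversions inside the layer on every step. A toy illustration of the cast-once pattern (hypothetical module, not BigDL code):

    import torch
    import torch.nn as nn

    class CastOnce(nn.Module):
        def __init__(self):
            super().__init__()
            self.scale = nn.Parameter(torch.ones(8))

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            if self.scale.dtype != torch.half:
                self.half()          # one-time weight conversion, amortized over all calls
            return x.half() * self.scale  # only the input is cast per call

    m = CastOnce()
    out = m(torch.randn(2, 8))       # first call converts the weights, later calls skip it
    assert out.dtype == torch.half and m.scale.dtype == torch.half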
@@ -132,8 +155,8 @@ def yuan_attention_forward(
             inference_hidden_states_memory[:, -1:, :] = hidden_states[:, -1:, :]
         else:
             hidden_states_tmp = before_hidden_states[:, -1:, :]
-            inference_hidden_states_memory = \
-                copy.deepcopy(torch.cat((hidden_states_tmp, hidden_states), dim=1))
+            inference_hidden_states_memory = torch.cat((hidden_states_tmp,
+                                                        hidden_states), dim=1)

     value_states = \
         self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
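The removed copy.deepcopy was redundant: torch.cat always materializes its result in freshly allocated memory, so the concatenated tensor never aliases hidden_states. A quick standalone check of that property:

    import torch

    a = torch.randn(1, 1, 4)
    b = torch.randn(1, 2, 4)
    cat = torch.cat((a, b), dim=1)

    b_before = b.clone()
    cat.zero_()                      # mutate the concatenated tensor in place
    assert torch.equal(b, b_before)  # sources are untouched: no aliasing, no copy needed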
@@ -148,15 +171,10 @@ def yuan_attention_forward(
         key_states = key_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
     else:
         hidden_states = self.lf_gate(hidden_states, before_hidden_states)
-        query_states = self.q_proj(hidden_states)
-        key_states = self.k_proj(hidden_states)
-        qk_states = torch.cat([query_states, key_states], dim=-1)
-        qk_states = qk_states.view(bsz, q_len,
-                                   self.num_heads,
-                                   int(qk_states.shape[-1]//self.num_heads))
+        qk_states = self.merged_qk_proj(hidden_states)
         (query_states, key_states) = torch.chunk(qk_states, 2, dim=-1)
-        query_states = query_states.transpose(1, 2)
-        key_states = key_states.transpose(1, 2)
+        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+        key_states = key_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)

     kv_seq_len = key_states.shape[-2]
     if past_key_value is not None:
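To convince yourself the fused projection is numerically equivalent to the code it replaces, the following self-contained check mimics both paths on random weights (sizes are made up; the final transposes are omitted since they are identical on both sides):

    import torch

    bsz, q_len, num_heads, head_dim = 2, 5, 8, 16
    hidden_size = num_heads * head_dim
    q_proj = torch.nn.Linear(hidden_size, num_heads * head_dim, bias=False)
    k_proj = torch.nn.Linear(hidden_size, num_heads * head_dim, bias=False)
    x = torch.randn(bsz, q_len, hidden_size)

    # Original Yuan path: two matmuls, then cat/view/chunk on the activations.
    qk = torch.cat([q_proj(x), k_proj(x)], dim=-1)
    qk = qk.view(bsz, q_len, num_heads, qk.shape[-1] // num_heads)
    q_ref, k_ref = torch.chunk(qk, 2, dim=-1)

    # Optimized path: one matmul against the interleaved weight, then chunk/view.
    q_w, k_w = q_proj.weight.data, k_proj.weight.data
    merged_w = torch.cat([
        q_w.view(num_heads, head_dim, hidden_size)[0::2],
        k_w.view(num_heads, head_dim, hidden_size)[0::2],
        q_w.view(num_heads, head_dim, hidden_size)[1::2],
        k_w.view(num_heads, head_dim, hidden_size)[1::2],
    ], dim=0).view(num_heads * head_dim * 2, hidden_size)
    qk_fused = torch.nn.functional.linear(x, merged_w)
    q_new, k_new = torch.chunk(qk_fused, 2, dim=-1)
    q_new = q_new.view(bsz, q_len, num_heads, head_dim)
    k_new = k_new.view(bsz, q_len, num_heads, head_dim)

    torch.testing.assert_close(q_new, q_ref)
    torch.testing.assert_close(k_new, k_ref)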