optimize yuan 2.0 performance (#10244)
parent 6c74b99a28
commit a47989c860
2 changed files with 57 additions and 10 deletions
@@ -539,6 +539,31 @@ def _optimize_pre(model):
         if model.lm_head.weight.data.device != "meta":
             norm_weight = nn.functional.normalize(lm_head_weight_data)
             model.lm_head.weight.data = norm_weight
+    # for yuan 2.0
+    if model.config.model_type == "yuan":
+        def merge_qk_proj_func(module):
+            if "YuanAttention" in module.__class__.__name__:
+                q_weight = module.q_proj.weight.data
+                k_weight = module.k_proj.weight.data
+                num_heads = module.num_heads
+                head_dim = module.head_dim
+                hidden_size = module.hidden_size
+
+                merged_qk_proj = torch.nn.Linear(0, 0, False)
+                weight = torch.cat([
+                    q_weight.view(num_heads, head_dim, hidden_size)[0::2, :, :],
+                    k_weight.view(num_heads, head_dim, hidden_size)[0::2, :, :],
+                    q_weight.view(num_heads, head_dim, hidden_size)[1::2, :, :],
+                    k_weight.view(num_heads, head_dim, hidden_size)[1::2, :, :],
+                ], dim=0).view(num_heads * head_dim * 2, hidden_size)
+                merged_qk_proj.weight = torch.nn.Parameter(weight, requires_grad=False)
+                merged_qk_proj.in_features = hidden_size
+                merged_qk_proj.out_features = num_heads * head_dim * 2
+                module.merged_qk_proj = merged_qk_proj
+
+                del module.q_proj
+                del module.k_proj
+        model.apply(merge_qk_proj_func)
     return model
 
 
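
Note on the interleaved layout: Yuan's original attention concatenates Q and K, views the result as (num_heads, 2 * head_dim) and then chunks it, which effectively routes the even-indexed q/k heads into one half and the odd-indexed ones into the other. Building the merged weight in that order lets the runtime path replace two projections plus a cat/view shuffle with a single GEMM followed by a plain chunk. A minimal standalone sketch (not part of the diff; shapes are illustrative) that checks the equivalence:

import torch
import torch.nn.functional as F

bsz, q_len, num_heads, head_dim = 2, 3, 4, 8
hidden_size = num_heads * head_dim

x = torch.randn(bsz, q_len, hidden_size)
q_weight = torch.randn(num_heads * head_dim, hidden_size)   # stands in for q_proj.weight
k_weight = torch.randn(num_heads * head_dim, hidden_size)   # stands in for k_proj.weight

# Original Yuan path: two projections plus a cat/view/chunk shuffle.
q = F.linear(x, q_weight)
k = F.linear(x, k_weight)
qk = torch.cat([q, k], dim=-1).view(bsz, q_len, num_heads, 2 * head_dim)
q_ref, k_ref = torch.chunk(qk, 2, dim=-1)
q_ref, k_ref = q_ref.transpose(1, 2), k_ref.transpose(1, 2)

# Merged path: one projection over the interleaved weight built in _optimize_pre.
merged_weight = torch.cat([
    q_weight.view(num_heads, head_dim, hidden_size)[0::2, :, :],
    k_weight.view(num_heads, head_dim, hidden_size)[0::2, :, :],
    q_weight.view(num_heads, head_dim, hidden_size)[1::2, :, :],
    k_weight.view(num_heads, head_dim, hidden_size)[1::2, :, :],
], dim=0).view(num_heads * head_dim * 2, hidden_size)
qk_states = F.linear(x, merged_weight)
q_new, k_new = torch.chunk(qk_states, 2, dim=-1)
q_new = q_new.view(bsz, q_len, num_heads, head_dim).transpose(1, 2)
k_new = k_new.view(bsz, q_len, num_heads, head_dim).transpose(1, 2)

assert torch.allclose(q_ref, q_new, atol=1e-5)
assert torch.allclose(k_ref, k_new, atol=1e-5)
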
@@ -1158,6 +1183,7 @@ def _optimize_post(model, lightweight_bmm=False):
         module = importlib.import_module(modeling_module_name)
         from bigdl.llm.transformers.models.yuan import yuan_attention_forward
         from bigdl.llm.transformers.models.yuan import yuan_mlp_forward
+        from bigdl.llm.transformers.models.yuan import yuan_localized_filtering_forward
         convert_forward(model,
                         module.YuanAttention,
                         yuan_attention_forward
@@ -1166,4 +1192,7 @@ def _optimize_post(model, lightweight_bmm=False):
                         module.YuanMLP,
                         yuan_mlp_forward
                         )
+        convert_forward(model,
+                        module.LocalizedFiltering,
+                        yuan_localized_filtering_forward)
     return model
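
For context, these registration calls only require the helper to rebind the forward of every instance of the target class; a simplified stand-in with the same calling convention (not BigDL's actual convert_forward implementation) looks like this:

import types
import torch.nn as nn

def convert_forward_sketch(model: nn.Module, target_class, new_forward) -> nn.Module:
    # Rebind forward on every instance of target_class in the module tree so that
    # module(...) dispatches to the optimized implementation.
    for module in model.modules():
        if isinstance(module, target_class):
            module.forward = types.MethodType(new_forward, module)
    return model

With that behavior, the calls above route YuanAttention, YuanMLP, and the newly covered LocalizedFiltering modules to their optimized forwards.
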
@@ -50,6 +50,29 @@ def should_use_fuse_rope(self, hidden_states, position_ids):
     return use_fuse_rope
 
 
+def yuan_localized_filtering_forward(
+    self,
+    inputs: torch.Tensor,
+    before_hidden_states: torch.Tensor,
+):
+    if self.conv1.weight.dtype != torch.half:
+        self.half()
+
+    inputs = inputs.half()
+    if before_hidden_states is not None:
+        before_hidden_states = before_hidden_states.half()
+
+    invalidInputError(self.lf_conv2d_num_pad == 1, "padding must be 1")
+    if self.training:
+        lf_output = self._train_forward(inputs)
+    else:
+        lf_output = self._inference_forward(inputs, before_hidden_states)
+
+    lf_output = lf_output.to(inputs.dtype)
+
+    return lf_output
+
+
 def yuan_mlp_forward(
     self,
     x: torch.Tensor,
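
The patched forward pins the LocalizedFiltering convolution path to fp16 and casts the result back to the caller's dtype, so the rest of the attention code is unaffected by the switch. A generic sketch of that dtype round-trip (an assumed helper for illustration, not BigDL code):

import torch
import torch.nn as nn

def run_in_half(module: nn.Module, x: torch.Tensor) -> torch.Tensor:
    # One-time cast of the submodule's weights, mirroring the self.half() above.
    if next(module.parameters()).dtype != torch.half:
        module.half()
    out = module(x.half())        # run the conv path in fp16
    return out.to(x.dtype)        # hand the result back in the caller's dtype
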
@@ -132,8 +155,8 @@ def yuan_attention_forward(
             inference_hidden_states_memory[:, -1:, :] = hidden_states[:, -1:, :]
         else:
             hidden_states_tmp = before_hidden_states[:, -1:, :]
-            inference_hidden_states_memory = \
-                copy.deepcopy(torch.cat((hidden_states_tmp, hidden_states), dim=1))
+            inference_hidden_states_memory = torch.cat((hidden_states_tmp,
+                                                        hidden_states), dim=1)
 
     value_states = \
         self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
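
The dropped copy.deepcopy was redundant: torch.cat always materializes a fresh tensor, so the memory buffer cannot alias hidden_states and later in-place writes cannot corrupt it. A quick standalone check (illustrative, not part of the diff):

import torch

a = torch.zeros(1, 1, 4)
b = torch.ones(1, 2, 4)
memory = torch.cat((a, b), dim=1)          # what the patched code stores
a.fill_(7.0)                               # later in-place write to an input
assert memory[:, :1, :].eq(0).all()        # the cat result is unaffected
assert memory.data_ptr() != a.data_ptr()   # separate storage, no aliasing
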
@@ -148,15 +171,10 @@ def yuan_attention_forward(
         key_states = key_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
     else:
         hidden_states = self.lf_gate(hidden_states, before_hidden_states)
-        query_states = self.q_proj(hidden_states)
-        key_states = self.k_proj(hidden_states)
-        qk_states = torch.cat([query_states, key_states], dim=-1)
-        qk_states = qk_states.view(bsz, q_len,
-                                   self.num_heads,
-                                   int(qk_states.shape[-1]//self.num_heads))
+        qk_states = self.merged_qk_proj(hidden_states)
         (query_states, key_states) = torch.chunk(qk_states, 2, dim=-1)
-        query_states = query_states.transpose(1, 2)
-        key_states = key_states.transpose(1, 2)
+        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+        key_states = key_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
 
     kv_seq_len = key_states.shape[-2]
     if past_key_value is not None:
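
This else branch runs once per generated token, so the change trades two small projections plus a concatenation for a single fused projection per decoding step. A rough CPU timing sketch of that trade-off (illustrative sizes, not a rigorous benchmark; the merged weight's head ordering is ignored here since it does not affect timing):

import time
import torch
import torch.nn.functional as F

hidden_size, proj_out = 2048, 2048
x = torch.randn(1, 1, hidden_size)            # single-token decode input
w_q = torch.randn(proj_out, hidden_size)
w_k = torch.randn(proj_out, hidden_size)
w_merged = torch.cat([w_q, w_k], dim=0)       # head ordering omitted, timing only

def split_path():
    # old decode path: two GEMMs plus a concatenation
    return torch.cat([F.linear(x, w_q), F.linear(x, w_k)], dim=-1)

def merged_path():
    # new decode path: one fused GEMM
    return F.linear(x, w_merged)

for fn in (split_path, merged_path):
    fn()                                      # warm-up
    start = time.perf_counter()
    for _ in range(1000):
        fn()
    print(f"{fn.__name__}: {time.perf_counter() - start:.4f}s for 1000 steps")
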