diff --git a/python/llm/src/bigdl/llm/transformers/convert.py b/python/llm/src/bigdl/llm/transformers/convert.py
index 3ec171ed..e3e93b3b 100644
--- a/python/llm/src/bigdl/llm/transformers/convert.py
+++ b/python/llm/src/bigdl/llm/transformers/convert.py
@@ -539,6 +539,31 @@ def _optimize_pre(model):
             if model.lm_head.weight.data.device != "meta":
                 norm_weight = nn.functional.normalize(lm_head_weight_data)
                 model.lm_head.weight.data = norm_weight
+    # for yuan 2.0
+    if model.config.model_type == "yuan":
+        def merge_qk_proj_func(module):
+            if "YuanAttention" in module.__class__.__name__:
+                q_weight = module.q_proj.weight.data
+                k_weight = module.k_proj.weight.data
+                num_heads = module.num_heads
+                head_dim = module.head_dim
+                hidden_size = module.hidden_size
+
+                merged_qk_proj = torch.nn.Linear(0, 0, False)
+                weight = torch.cat([
+                    q_weight.view(num_heads, head_dim, hidden_size)[0::2, :, :],
+                    k_weight.view(num_heads, head_dim, hidden_size)[0::2, :, :],
+                    q_weight.view(num_heads, head_dim, hidden_size)[1::2, :, :],
+                    k_weight.view(num_heads, head_dim, hidden_size)[1::2, :, :],
+                ], dim=0).view(num_heads * head_dim * 2, hidden_size)
+                merged_qk_proj.weight = torch.nn.Parameter(weight, requires_grad=False)
+                merged_qk_proj.in_features = hidden_size
+                merged_qk_proj.out_features = num_heads * head_dim * 2
+                module.merged_qk_proj = merged_qk_proj
+
+                del module.q_proj
+                del module.k_proj
+        model.apply(merge_qk_proj_func)
     return model


@@ -1158,6 +1183,7 @@ def _optimize_post(model, lightweight_bmm=False):
         module = importlib.import_module(modeling_module_name)
         from bigdl.llm.transformers.models.yuan import yuan_attention_forward
         from bigdl.llm.transformers.models.yuan import yuan_mlp_forward
+        from bigdl.llm.transformers.models.yuan import yuan_localized_filtering_forward
         convert_forward(model,
                         module.YuanAttention,
                         yuan_attention_forward
@@ -1166,4 +1192,7 @@ def _optimize_post(model, lightweight_bmm=False):
                         module.YuanMLP,
                         yuan_mlp_forward
                         )
+        convert_forward(model,
+                        module.LocalizedFiltering,
+                        yuan_localized_filtering_forward)
     return model
diff --git a/python/llm/src/bigdl/llm/transformers/models/yuan.py b/python/llm/src/bigdl/llm/transformers/models/yuan.py
index 2419fa91..a874925f 100644
--- a/python/llm/src/bigdl/llm/transformers/models/yuan.py
+++ b/python/llm/src/bigdl/llm/transformers/models/yuan.py
@@ -50,6 +50,29 @@ def should_use_fuse_rope(self, hidden_states, position_ids):
     return use_fuse_rope


+def yuan_localized_filtering_forward(
+    self,
+    inputs: torch.Tensor,
+    before_hidden_states: torch.Tensor,
+):
+    if self.conv1.weight.dtype != torch.half:
+        self.half()
+    origin_dtype = inputs.dtype
+    inputs = inputs.half()
+    if before_hidden_states is not None:
+        before_hidden_states = before_hidden_states.half()
+
+    invalidInputError(self.lf_conv2d_num_pad == 1, "padding must be 1")
+    if self.training:
+        lf_output = self._train_forward(inputs)
+    else:
+        lf_output = self._inference_forward(inputs, before_hidden_states)
+
+    lf_output = lf_output.to(origin_dtype)
+
+    return lf_output
+
+
 def yuan_mlp_forward(
     self,
     x: torch.Tensor,
@@ -132,8 +155,8 @@ def yuan_attention_forward(
                 inference_hidden_states_memory[:, -1:, :] = hidden_states[:, -1:, :]
         else:
             hidden_states_tmp = before_hidden_states[:, -1:, :]
-            inference_hidden_states_memory = \
-                copy.deepcopy(torch.cat((hidden_states_tmp, hidden_states), dim=1))
+            inference_hidden_states_memory = torch.cat((hidden_states_tmp,
+                                                        hidden_states), dim=1)

     value_states = \
         self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
@@ -148,15 +171,10 @@ def yuan_attention_forward(
         key_states = key_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
     else:
         hidden_states = self.lf_gate(hidden_states, before_hidden_states)
-        query_states = self.q_proj(hidden_states)
-        key_states = self.k_proj(hidden_states)
-        qk_states = torch.cat([query_states, key_states], dim=-1)
-        qk_states = qk_states.view(bsz, q_len,
-                                   self.num_heads,
-                                   int(qk_states.shape[-1]//self.num_heads))
+        qk_states = self.merged_qk_proj(hidden_states)
         (query_states, key_states) = torch.chunk(qk_states, 2, dim=-1)
-        query_states = query_states.transpose(1, 2)
-        key_states = key_states.transpose(1, 2)
+        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+        key_states = key_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)

     kv_seq_len = key_states.shape[-2]
     if past_key_value is not None:
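
Note on the merged QK projection (illustrative, not part of the patch): Yuan 2.0's original attention computes `q_proj` and `k_proj` separately, concatenates the results, and re-views them per head, which routes even-numbered heads of both projections to the query path and odd-numbered heads to the key path. The `merge_qk_proj_func` hook above bakes that reordering into a single weight matrix once at load time, so each forward pass needs one matmul plus a `chunk` instead of two matmuls, a `cat`, and an extra `view`. The standalone sketch below checks the equivalence; the shapes and seed are made up, and it calls `torch.nn.functional.linear` directly rather than building the `Linear(0, 0, False)` holder used in the patch:

```python
import torch

# Toy shapes; the real values come from the Yuan 2.0 config.
bsz, q_len, num_heads, head_dim = 2, 5, 8, 16
hidden_size = num_heads * head_dim

torch.manual_seed(0)
q_proj = torch.nn.Linear(hidden_size, hidden_size, bias=False)
k_proj = torch.nn.Linear(hidden_size, hidden_size, bias=False)
hidden_states = torch.randn(bsz, q_len, hidden_size)

# Original path: two matmuls, then a cat + view that interleaves q/k heads.
qk_states = torch.cat([q_proj(hidden_states), k_proj(hidden_states)], dim=-1)
qk_states = qk_states.view(bsz, q_len, num_heads,
                           qk_states.shape[-1] // num_heads)
q_ref, k_ref = torch.chunk(qk_states, 2, dim=-1)

# Merged path: reorder the weight rows once (even q-heads, even k-heads,
# odd q-heads, odd k-heads), exactly as merge_qk_proj_func does above.
q_weight = q_proj.weight.data
k_weight = k_proj.weight.data
merged_weight = torch.cat([
    q_weight.view(num_heads, head_dim, hidden_size)[0::2, :, :],
    k_weight.view(num_heads, head_dim, hidden_size)[0::2, :, :],
    q_weight.view(num_heads, head_dim, hidden_size)[1::2, :, :],
    k_weight.view(num_heads, head_dim, hidden_size)[1::2, :, :],
], dim=0).view(num_heads * head_dim * 2, hidden_size)

# One matmul, then chunk/view recovers the same per-head layout.
qk_states = torch.nn.functional.linear(hidden_states, merged_weight)
q_new, k_new = torch.chunk(qk_states, 2, dim=-1)
q_new = q_new.view(bsz, q_len, num_heads, head_dim)
k_new = k_new.view(bsz, q_len, num_heads, head_dim)

assert torch.allclose(q_ref, q_new, atol=1e-5)
assert torch.allclose(k_ref, k_new, atol=1e-5)
print("merged projection reproduces the original q/k head layout")
```

Doing the merge in `_optimize_pre`, before the linears are swapped for low-bit modules, presumably also means the reordered weight is what gets quantized, so the fused projection stays a single low-bit matmul at inference time.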