optimize yuan 2.0 performance (#10244)
This commit is contained in:
		
							parent
							
								
									6c74b99a28
								
							
						
					
					
						commit
						a47989c860
					
				
					 2 changed files with 57 additions and 10 deletions
				
			
		| 
						 | 
				
			
			@ -539,6 +539,31 @@ def _optimize_pre(model):
 | 
			
		|||
            if model.lm_head.weight.data.device != "meta":
 | 
			
		||||
                norm_weight = nn.functional.normalize(lm_head_weight_data)
 | 
			
		||||
                model.lm_head.weight.data = norm_weight
 | 
			
		||||
    # for yuan 2.0
 | 
			
		||||
    if model.config.model_type == "yuan":
 | 
			
		||||
        def merge_qk_proj_func(module):
 | 
			
		||||
            if "YuanAttention" in module.__class__.__name__:
 | 
			
		||||
                q_weight = module.q_proj.weight.data
 | 
			
		||||
                k_weight = module.k_proj.weight.data
 | 
			
		||||
                num_heads = module.num_heads
 | 
			
		||||
                head_dim = module.head_dim
 | 
			
		||||
                hidden_size = module.hidden_size
 | 
			
		||||
 | 
			
		||||
                merged_qk_proj = torch.nn.Linear(0, 0, False)
 | 
			
		||||
                weight = torch.cat([
 | 
			
		||||
                    q_weight.view(num_heads, head_dim, hidden_size)[0::2, :, :],
 | 
			
		||||
                    k_weight.view(num_heads, head_dim, hidden_size)[0::2, :, :],
 | 
			
		||||
                    q_weight.view(num_heads, head_dim, hidden_size)[1::2, :, :],
 | 
			
		||||
                    k_weight.view(num_heads, head_dim, hidden_size)[1::2, :, :],
 | 
			
		||||
                ], dim=0).view(num_heads * head_dim * 2, hidden_size)
 | 
			
		||||
                merged_qk_proj.weight = torch.nn.Parameter(weight, requires_grad=False)
 | 
			
		||||
                merged_qk_proj.in_features = hidden_size
 | 
			
		||||
                merged_qk_proj.out_features = num_heads * head_dim * 2
 | 
			
		||||
                module.merged_qk_proj = merged_qk_proj
 | 
			
		||||
 | 
			
		||||
                del module.q_proj
 | 
			
		||||
                del module.k_proj
 | 
			
		||||
        model.apply(merge_qk_proj_func)
 | 
			
		||||
    return model
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -1158,6 +1183,7 @@ def _optimize_post(model, lightweight_bmm=False):
 | 
			
		|||
        module = importlib.import_module(modeling_module_name)
 | 
			
		||||
        from bigdl.llm.transformers.models.yuan import yuan_attention_forward
 | 
			
		||||
        from bigdl.llm.transformers.models.yuan import yuan_mlp_forward
 | 
			
		||||
        from bigdl.llm.transformers.models.yuan import yuan_localized_filtering_forward
 | 
			
		||||
        convert_forward(model,
 | 
			
		||||
                        module.YuanAttention,
 | 
			
		||||
                        yuan_attention_forward
 | 
			
		||||
| 
						 | 
				
			
			@ -1166,4 +1192,7 @@ def _optimize_post(model, lightweight_bmm=False):
 | 
			
		|||
                        module.YuanMLP,
 | 
			
		||||
                        yuan_mlp_forward
 | 
			
		||||
                        )
 | 
			
		||||
        convert_forward(model,
 | 
			
		||||
                        module.LocalizedFiltering,
 | 
			
		||||
                        yuan_localized_filtering_forward)
 | 
			
		||||
    return model
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -50,6 +50,29 @@ def should_use_fuse_rope(self, hidden_states, position_ids):
 | 
			
		|||
    return use_fuse_rope
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def yuan_localized_filtering_forward(
 | 
			
		||||
    self,
 | 
			
		||||
    inputs: torch.Tensor,
 | 
			
		||||
    before_hidden_states: torch.Tensor,
 | 
			
		||||
):
 | 
			
		||||
    if self.conv1.weight.dtype != torch.half:
 | 
			
		||||
        self.half()
 | 
			
		||||
 | 
			
		||||
    inputs = inputs.half()
 | 
			
		||||
    if before_hidden_states is not None:
 | 
			
		||||
        before_hidden_states = before_hidden_states.half()
 | 
			
		||||
 | 
			
		||||
    invalidInputError(self.lf_conv2d_num_pad == 1, "padding must be 1")
 | 
			
		||||
    if self.training:
 | 
			
		||||
        lf_output = self._train_forward(inputs)
 | 
			
		||||
    else:
 | 
			
		||||
        lf_output = self._inference_forward(inputs, before_hidden_states)
 | 
			
		||||
 | 
			
		||||
    lf_output = lf_output.to(inputs.dtype)
 | 
			
		||||
 | 
			
		||||
    return lf_output
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def yuan_mlp_forward(
 | 
			
		||||
    self,
 | 
			
		||||
    x: torch.Tensor,
 | 
			
		||||
| 
						 | 
				
			
			@ -132,8 +155,8 @@ def yuan_attention_forward(
 | 
			
		|||
                inference_hidden_states_memory[:, -1:, :] = hidden_states[:, -1:, :]
 | 
			
		||||
        else:
 | 
			
		||||
            hidden_states_tmp = before_hidden_states[:, -1:, :]
 | 
			
		||||
            inference_hidden_states_memory = \
 | 
			
		||||
                copy.deepcopy(torch.cat((hidden_states_tmp, hidden_states), dim=1))
 | 
			
		||||
            inference_hidden_states_memory = torch.cat((hidden_states_tmp,
 | 
			
		||||
                                                        hidden_states), dim=1)
 | 
			
		||||
 | 
			
		||||
    value_states = \
 | 
			
		||||
        self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
 | 
			
		||||
| 
						 | 
				
			
			@ -148,15 +171,10 @@ def yuan_attention_forward(
 | 
			
		|||
        key_states = key_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
 | 
			
		||||
    else:
 | 
			
		||||
        hidden_states = self.lf_gate(hidden_states, before_hidden_states)
 | 
			
		||||
        query_states = self.q_proj(hidden_states)
 | 
			
		||||
        key_states = self.k_proj(hidden_states)
 | 
			
		||||
        qk_states = torch.cat([query_states, key_states], dim=-1)
 | 
			
		||||
        qk_states = qk_states.view(bsz, q_len,
 | 
			
		||||
                                   self.num_heads,
 | 
			
		||||
                                   int(qk_states.shape[-1]//self.num_heads))
 | 
			
		||||
        qk_states = self.merged_qk_proj(hidden_states)
 | 
			
		||||
        (query_states, key_states) = torch.chunk(qk_states, 2, dim=-1)
 | 
			
		||||
        query_states = query_states.transpose(1, 2)
 | 
			
		||||
        key_states = key_states.transpose(1, 2)
 | 
			
		||||
        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
 | 
			
		||||
        key_states = key_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
 | 
			
		||||
 | 
			
		||||
    kv_seq_len = key_states.shape[-2]
 | 
			
		||||
    if past_key_value is not None:
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in a new issue