[LLM] Split merged_qk to separated q/k linear (#10299)
* Modify the merged_qk linear into separated q/k linear layers
* Update the attention forward paths accordingly
parent f4d7dbcde2
commit 0ab40917fb

2 changed files with 22 additions and 12 deletions

@@ -555,17 +555,27 @@ def _optimize_pre(model):
                 head_dim = module.head_dim
                 hidden_size = module.hidden_size
 
-                merged_qk_proj = torch.nn.Linear(0, 0, False)
-                weight = torch.cat([
+                weight_q = torch.cat([
                     q_weight.view(num_heads, head_dim, hidden_size)[0::2, :, :],
                     k_weight.view(num_heads, head_dim, hidden_size)[0::2, :, :],
+                ], dim=0).view(num_heads * head_dim, hidden_size)
+
+                weight_k = torch.cat([
                     q_weight.view(num_heads, head_dim, hidden_size)[1::2, :, :],
                     k_weight.view(num_heads, head_dim, hidden_size)[1::2, :, :],
-                ], dim=0).view(num_heads * head_dim * 2, hidden_size)
-                merged_qk_proj.weight = torch.nn.Parameter(weight, requires_grad=False)
-                merged_qk_proj.in_features = hidden_size
-                merged_qk_proj.out_features = num_heads * head_dim * 2
-                module.merged_qk_proj = merged_qk_proj
+                ], dim=0).view(num_heads * head_dim, hidden_size)
+
+                merged_q_proj = torch.nn.Linear(0, 0, False)
+                merged_q_proj.weight = torch.nn.Parameter(weight_q, requires_grad=False)
+                merged_q_proj.in_features = hidden_size
+                merged_q_proj.out_features = num_heads * head_dim
+                module.merged_q_proj = merged_q_proj
+
+                merged_k_proj = torch.nn.Linear(0, 0, False)
+                merged_k_proj.weight = torch.nn.Parameter(weight_k, requires_grad=False)
+                merged_k_proj.in_features = hidden_size
+                merged_k_proj.out_features = num_heads * head_dim
+                module.merged_k_proj = merged_k_proj
 
                 del module.q_proj
                 del module.k_proj
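The split is numerics-preserving: rows 0 .. num_heads*head_dim-1 of the old merged weight are exactly weight_q, and the remaining rows are weight_k, so the two new projections reproduce what merged_qk_proj followed by torch.chunk used to return. A minimal sketch checking this with plain tensors (the sizes num_heads=4, head_dim=8, hidden_size=32 and the random weights are made up for illustration, not taken from the model):

    import torch
    import torch.nn.functional as F

    num_heads, head_dim, hidden_size = 4, 8, 32
    q_weight = torch.randn(num_heads * head_dim, hidden_size)
    k_weight = torch.randn(num_heads * head_dim, hidden_size)

    # old layout: even-/odd-indexed heads of q and k stacked into one merged weight
    weight = torch.cat([
        q_weight.view(num_heads, head_dim, hidden_size)[0::2, :, :],
        k_weight.view(num_heads, head_dim, hidden_size)[0::2, :, :],
        q_weight.view(num_heads, head_dim, hidden_size)[1::2, :, :],
        k_weight.view(num_heads, head_dim, hidden_size)[1::2, :, :],
    ], dim=0).view(num_heads * head_dim * 2, hidden_size)

    # new layout: the same rows, split into two separate weights
    weight_q = weight[:num_heads * head_dim]
    weight_k = weight[num_heads * head_dim:]

    hidden_states = torch.randn(2, 5, hidden_size)

    # old path: one projection, then chunk the output in two
    qk_states = F.linear(hidden_states, weight)
    query_old, key_old = torch.chunk(qk_states, 2, dim=-1)

    # new path: two independent projections
    query_new = F.linear(hidden_states, weight_q)
    key_new = F.linear(hidden_states, weight_k)

    assert torch.allclose(query_old, query_new)
    assert torch.allclose(key_old, key_new)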

@@ -147,7 +147,7 @@ def yuan_attention_forward(
     output_attentions: bool = False,
     use_cache: bool = False,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-    if use_quantize_kv_cache(self.merged_qk_proj, hidden_states):
+    if use_quantize_kv_cache(self.merged_q_proj, hidden_states):
         forward_function = yuan_attention_forward_quantized
     else:
         forward_function = yuan_attention_forward_origin

@@ -206,8 +206,8 @@ def yuan_attention_forward_quantized(
     else:
         hidden_states = yuan_localized_filtering_forward(self.lf_gate, hidden_states,
                                                          this_hidden_states, hidden_states.dtype)
-    qk_states = self.merged_qk_proj(hidden_states)
-    (query_states, key_states) = torch.chunk(qk_states, 2, dim=-1)
+    query_states = self.merged_q_proj(hidden_states)
+    key_states = self.merged_k_proj(hidden_states)
     query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
     key_states = key_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
 

@@ -360,8 +360,8 @@ def yuan_attention_forward_origin(
     else:
         hidden_states = yuan_localized_filtering_forward(self.lf_gate, hidden_states,
                                                          this_hidden_states, hidden_states.dtype)
-    qk_states = self.merged_qk_proj(hidden_states)
-    (query_states, key_states) = torch.chunk(qk_states, 2, dim=-1)
+    query_states = self.merged_q_proj(hidden_states)
+    key_states = self.merged_k_proj(hidden_states)
     query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
     key_states = key_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
 
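Taken together, the construction in _optimize_pre and the rewritten forward compose as below. This is a standalone sketch with toy shapes (bsz=2, q_len=5, num_heads=4, head_dim=8, hidden_size=32 are illustrative values, and no real Yuan attention module is involved); it only exercises the projection and reshape steps shown in the hunks above:

    import torch

    num_heads, head_dim, hidden_size = 4, 8, 32
    bsz, q_len = 2, 5

    q_weight = torch.randn(num_heads * head_dim, hidden_size)
    k_weight = torch.randn(num_heads * head_dim, hidden_size)

    # build merged_q_proj / merged_k_proj the same way the new _optimize_pre does
    weight_q = torch.cat([
        q_weight.view(num_heads, head_dim, hidden_size)[0::2, :, :],
        k_weight.view(num_heads, head_dim, hidden_size)[0::2, :, :],
    ], dim=0).view(num_heads * head_dim, hidden_size)
    weight_k = torch.cat([
        q_weight.view(num_heads, head_dim, hidden_size)[1::2, :, :],
        k_weight.view(num_heads, head_dim, hidden_size)[1::2, :, :],
    ], dim=0).view(num_heads * head_dim, hidden_size)

    merged_q_proj = torch.nn.Linear(0, 0, False)
    merged_q_proj.weight = torch.nn.Parameter(weight_q, requires_grad=False)
    merged_q_proj.in_features = hidden_size
    merged_q_proj.out_features = num_heads * head_dim

    merged_k_proj = torch.nn.Linear(0, 0, False)
    merged_k_proj.weight = torch.nn.Parameter(weight_k, requires_grad=False)
    merged_k_proj.in_features = hidden_size
    merged_k_proj.out_features = num_heads * head_dim

    # the two calls that replace merged_qk_proj + torch.chunk in the forward
    hidden_states = torch.randn(bsz, q_len, hidden_size)
    query_states = merged_q_proj(hidden_states)
    key_states = merged_k_proj(hidden_states)
    query_states = query_states.view(bsz, q_len, num_heads, head_dim).transpose(1, 2)
    key_states = key_states.view(bsz, q_len, num_heads, head_dim).transpose(1, 2)

    assert query_states.shape == (bsz, num_heads, q_len, head_dim)
    assert key_states.shape == (bsz, num_heads, q_len, head_dim)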