support new baichuan model (#10404)
parent a90e9b6ec2
commit 06a851afa9

2 changed files with 5 additions and 5 deletions

@@ -953,7 +953,7 @@ def _optimize_post(model, lightweight_bmm=False):
     elif model.config.model_type == "baichuan" and model.config.vocab_size == 125696:
         # baichuan2
-        if model.config.hidden_size == 4096:
+        if model.config.hidden_size in [4096, 2048]:
             # baichuan2-7B
             modeling_module_name = model.__class__.__module__
             module = importlib.import_module(modeling_module_name)
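
The only change in this hunk widens the dispatch so a 2048-wide baichuan2 variant takes the same optimization path as the 4096-wide baichuan2-7B. A minimal sketch of the changed condition, assuming a transformers-style config object (FakeConfig and takes_7b_path are hypothetical stand-ins, not part of the patch):

from dataclasses import dataclass

@dataclass
class FakeConfig:
    # hypothetical stand-in for model.config; defaults mirror the hunk
    model_type: str = "baichuan"
    vocab_size: int = 125696
    hidden_size: int = 2048

def takes_7b_path(config) -> bool:
    # mirrors the condition after this change: the new 2048-wide variant
    # now shares the baichuan2-7B optimization path
    return (config.model_type == "baichuan"
            and config.vocab_size == 125696
            and config.hidden_size in [4096, 2048])

assert takes_7b_path(FakeConfig())
assert takes_7b_path(FakeConfig(hidden_size=4096))
assert not takes_7b_path(FakeConfig(hidden_size=5120))  # e.g. a wider model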
@@ -118,7 +118,7 @@ def baichuan_attention_forward_7b_quantized(
     device = hidden_states.device
 
     proj = self.W_pack(hidden_states)
-    proj = proj.unflatten(-1, (3, self.hidden_size)).unsqueeze(0).transpose(0, -2).squeeze(-2)
+    proj = torch.chunk(proj, 3, -1)
     # batch_size x source_len x hidden_size
     query_states = proj[0].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
     # batch_size x target_len x head_size
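
Both the old and the new line split the packed QKV projection into three equal parts along the last dimension, and the downstream indexing is unchanged: proj[0], proj[1], proj[2] give Q, K, V either way. A small runnable check of that equivalence, with toy sizes chosen only for illustration:

import torch

bsz, q_len, hidden_size = 2, 5, 8  # toy sizes, not real model dims

proj = torch.randn(bsz, q_len, 3 * hidden_size)

# old path: (bsz, q_len, 3*h) -> (3, bsz, q_len, h) via reshape gymnastics
old = proj.unflatten(-1, (3, hidden_size)).unsqueeze(0).transpose(0, -2).squeeze(-2)

# new path: a tuple of three (bsz, q_len, h) views, no intermediate dims
new = torch.chunk(proj, 3, -1)

for i in range(3):
    assert torch.equal(old[i], new[i])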
@@ -176,7 +176,7 @@ def baichuan_attention_forward_7b_quantized(
                                                         value_states.transpose(-1, -2))
     attn_output = attn_output.transpose(1, 2)
 
-    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+    attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim)
     attn_output = self.o_proj(attn_output)
 
     if not output_attentions:
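
After transpose(1, 2) the attention output has shape (bsz, q_len, num_heads, head_dim), so flattening to num_heads * head_dim always matches the element count; reshaping to hidden_size only works when the two coincide, which presumably does not hold for the newly supported variant. A toy sketch of the failure mode the new line avoids (dimensions are invented for illustration):

import torch

bsz, q_len, num_heads, head_dim = 2, 5, 4, 8
hidden_size = 64  # deliberately != num_heads * head_dim (32) for this sketch

attn_output = torch.randn(bsz, num_heads, q_len, head_dim).transpose(1, 2)

attn_output.reshape(bsz, q_len, num_heads * head_dim)  # always valid
try:
    attn_output.reshape(bsz, q_len, hidden_size)  # element count mismatch
except RuntimeError as err:
    print("old reshape fails:", err)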
@@ -198,7 +198,7 @@ def baichuan_attention_forward_7b_origin(
     device = hidden_states.device
 
     proj = self.W_pack(hidden_states)
-    proj = proj.unflatten(-1, (3, self.hidden_size)).unsqueeze(0).transpose(0, -2).squeeze(-2)
+    proj = torch.chunk(proj, 3, -1)
     # batch_size x source_len x hidden_size
     query_states = proj[0].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
     # batch_size x target_len x head_size
@@ -283,7 +283,7 @@ def baichuan_attention_forward_7b_origin(
         attn_output = torch.matmul(attn_output, value_states)
 
         attn_output = attn_output.transpose(1, 2)
-    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+    attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim)
     attn_output = self.o_proj(attn_output)
 
     if not output_attentions: