Support latest transformer version (#8923)
* Support latest transformer version
* fix style
This commit is contained in:
parent 25428b22b4
commit ee98cdd85c

1 changed file with 11 additions and 10 deletions
			
@@ -85,23 +85,23 @@ def llama_attention_forward_4_31(
     bsz, q_len, _ = hidden_states.size()
     device = hidden_states.device

-    if self.pretraining_tp > 1:
-        key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.pretraining_tp
+    if self.config.pretraining_tp > 1:
+        key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
         query_slices = self.q_proj.weight.split((self.num_heads * self.head_dim)
-                                                // self.pretraining_tp, dim=0)
+                                                // self.config.pretraining_tp, dim=0)
         key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
         value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)

         query_states = [F.linear(hidden_states, query_slices[i])
-                        for i in range(self.pretraining_tp)]
+                        for i in range(self.config.pretraining_tp)]
         query_states = torch.cat(query_states, dim=-1)

         key_states = [F.linear(hidden_states, key_slices[i])
-                      for i in range(self.pretraining_tp)]
+                      for i in range(self.config.pretraining_tp)]
         key_states = torch.cat(key_states, dim=-1)

         value_states = [F.linear(hidden_states, value_slices[i])
-                        for i in range(self.pretraining_tp)]
+                        for i in range(self.config.pretraining_tp)]
         value_states = torch.cat(value_states, dim=-1)

     else:
@@ -194,11 +194,12 @@ def llama_attention_forward_4_31(
     attn_output = attn_output.transpose(1, 2).contiguous()
     attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

-    if self.pretraining_tp > 1:
-        attn_output = attn_output.split(self.hidden_size // self.pretraining_tp, dim=2)
-        o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.pretraining_tp, dim=1)
+    if self.config.pretraining_tp > 1:
+        attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
+        o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp,
+                                                 dim=1)
         attn_output = sum([F.linear(attn_output[i], o_proj_slices[i])
-                           for i in range(self.pretraining_tp)])
+                           for i in range(self.config.pretraining_tp)])
     else:
         attn_output = self.o_proj(attn_output)

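The whole diff is one API change: newer transformers releases read the tensor-parallel degree from the model config as config.pretraining_tp, whereas the 4.31-era attention module copied it onto itself as self.pretraining_tp. A minimal compatibility sketch, not part of this patch (the helper name get_pretraining_tp is hypothetical), that tolerates both layouts:

# Hypothetical compatibility helper (not part of this commit): resolve
# pretraining_tp whether it lives on the module's config (newer transformers)
# or directly on the module (older 4.31-era code), defaulting to 1 (no slicing).
def get_pretraining_tp(module, default=1):
    config = getattr(module, "config", None)
    if config is not None and getattr(config, "pretraining_tp", None) is not None:
        return config.pretraining_tp
    return getattr(module, "pretraining_tp", default)

With such a helper the forward patch could call get_pretraining_tp(self) once instead of renaming every reference; the commit instead mirrors upstream's attribute layout directly, which keeps the patched code line-for-line comparable with the library source.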