Support latest transformer version (#8923)
* Support latest transformer version
* fix style
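In the latest transformers releases, LlamaAttention reads the tensor-parallel degree from the model config rather than from an attribute on the attention module, so the diff below switches every self.pretraining_tp access to self.config.pretraining_tp. Since self.config also carries the field on transformers 4.31, reading it from the config works across versions. A hypothetical fallback helper (not part of this commit, shown only as a sketch of the compatibility issue) could look like:

# Hypothetical helper, not part of this commit: read the tensor-parallel
# degree from wherever the installed transformers version stores it.
def get_pretraining_tp(attn_module):
    config = getattr(attn_module, "config", None)
    if config is not None and hasattr(config, "pretraining_tp"):
        return config.pretraining_tp                  # newer transformers keep it on the config
    return getattr(attn_module, "pretraining_tp", 1)  # transformers 4.31 also kept a module attribute

The commit itself takes the simpler route and reads the value from self.config directly, as shown in the diff.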
parent 25428b22b4
commit ee98cdd85c

1 changed file with 11 additions and 10 deletions
@@ -85,23 +85,23 @@ def llama_attention_forward_4_31(
     bsz, q_len, _ = hidden_states.size()
     device = hidden_states.device
 
-    if self.pretraining_tp > 1:
-        key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.pretraining_tp
+    if self.config.pretraining_tp > 1:
+        key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
         query_slices = self.q_proj.weight.split((self.num_heads * self.head_dim)
-                                                // self.pretraining_tp, dim=0)
+                                                // self.config.pretraining_tp, dim=0)
         key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
         value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
 
         query_states = [F.linear(hidden_states, query_slices[i])
-                        for i in range(self.pretraining_tp)]
+                        for i in range(self.config.pretraining_tp)]
         query_states = torch.cat(query_states, dim=-1)
 
         key_states = [F.linear(hidden_states, key_slices[i])
-                      for i in range(self.pretraining_tp)]
+                      for i in range(self.config.pretraining_tp)]
         key_states = torch.cat(key_states, dim=-1)
 
         value_states = [F.linear(hidden_states, value_slices[i])
-                        for i in range(self.pretraining_tp)]
+                        for i in range(self.config.pretraining_tp)]
         value_states = torch.cat(value_states, dim=-1)
 
     else:
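For reference, the pretraining_tp > 1 branch above is an exact rewrite of the plain q/k/v projections: splitting a projection weight along its output dimension (dim=0) and concatenating the per-slice F.linear results along the last dimension reproduces a single F.linear call. A standalone check of that identity, with shapes made up purely for illustration:

import torch
import torch.nn.functional as F

# Illustrative check (shapes are arbitrary): slicing a projection weight row-wise
# and concatenating the per-slice outputs matches one full linear projection.
torch.manual_seed(0)
tp, hidden_size, out_features = 2, 8, 12
hidden_states = torch.randn(3, 5, hidden_size)     # (bsz, q_len, hidden_size)
weight = torch.randn(out_features, hidden_size)    # q/k/v projection weight, no bias

slices = weight.split(out_features // tp, dim=0)
sliced_out = torch.cat([F.linear(hidden_states, slices[i]) for i in range(tp)], dim=-1)
assert torch.allclose(sliced_out, F.linear(hidden_states, weight), atol=1e-6)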
@@ -194,11 +194,12 @@ def llama_attention_forward_4_31(
     attn_output = attn_output.transpose(1, 2).contiguous()
     attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
 
-    if self.pretraining_tp > 1:
-        attn_output = attn_output.split(self.hidden_size // self.pretraining_tp, dim=2)
-        o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.pretraining_tp, dim=1)
+    if self.config.pretraining_tp > 1:
+        attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
+        o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp,
+                                                 dim=1)
         attn_output = sum([F.linear(attn_output[i], o_proj_slices[i])
-                           for i in range(self.pretraining_tp)])
+                           for i in range(self.config.pretraining_tp)])
     else:
         attn_output = self.o_proj(attn_output)
 
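The output projection relies on the complementary identity: splitting o_proj's weight along its input dimension (dim=1) and the activations along the hidden dimension (dim=2), then summing the partial F.linear products, matches the single o_proj call in the else branch. An illustrative standalone check, again with arbitrary shapes:

import torch
import torch.nn.functional as F

# Illustrative check (shapes are arbitrary): column-wise weight slices applied to
# the matching activation slices, summed, equal one full linear projection.
torch.manual_seed(0)
tp, hidden_size = 2, 8
attn_output = torch.randn(3, 5, hidden_size)          # (bsz, q_len, hidden_size)
o_proj_weight = torch.randn(hidden_size, hidden_size)

out_slices = attn_output.split(hidden_size // tp, dim=2)
w_slices = o_proj_weight.split(hidden_size // tp, dim=1)
summed = sum(F.linear(out_slices[i], w_slices[i]) for i in range(tp))
assert torch.allclose(summed, F.linear(attn_output, o_proj_weight), atol=1e-5)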