remove rms norm copy (#10793)
This commit is contained in:
		
							parent
							
								
									c7235e34a8
								
							
						
					
					
						commit
						08458b4f74
					
				
					 4 changed files with 0 additions and 8 deletions
				
			
		| 
						 | 
					@@ -54,8 +54,6 @@ def baichuan_13b_rms_norm_forward(self, hidden_states):
 | 
				
			||||||
        import linear_q4_0
 | 
					        import linear_q4_0
 | 
				
			||||||
        x_2d = hidden_states.reshape(-1, hidden_states.size(-1)).contiguous()
 | 
					        x_2d = hidden_states.reshape(-1, hidden_states.size(-1)).contiguous()
 | 
				
			||||||
        output = linear_q4_0.rms_norm(self.weight, x_2d, self.epsilon)
 | 
					        output = linear_q4_0.rms_norm(self.weight, x_2d, self.epsilon)
 | 
				
			||||||
        if 1 < x_2d.size(0) <= 64:   # may use XMX, need copy
 | 
					 | 
				
			||||||
            output = output.clone()
 | 
					 | 
				
			||||||
        return output.reshape(hidden_states.shape)
 | 
					        return output.reshape(hidden_states.shape)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    input_dtype = hidden_states.dtype
 | 
					    input_dtype = hidden_states.dtype
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@@ -101,8 +101,6 @@ def chatglm_rms_norm_forward(self, hidden_states):
 | 
				
			||||||
        import linear_q4_0
 | 
					        import linear_q4_0
 | 
				
			||||||
        x_2d = hidden_states.reshape(-1, hidden_states.size(-1)).contiguous()
 | 
					        x_2d = hidden_states.reshape(-1, hidden_states.size(-1)).contiguous()
 | 
				
			||||||
        output = linear_q4_0.rms_norm(self.weight, x_2d, self.eps)
 | 
					        output = linear_q4_0.rms_norm(self.weight, x_2d, self.eps)
 | 
				
			||||||
        if 1 < x_2d.size(0) <= 64:   # may use XMX, need copy
 | 
					 | 
				
			||||||
            output = output.clone()
 | 
					 | 
				
			||||||
        return output.reshape(hidden_states.shape)
 | 
					        return output.reshape(hidden_states.shape)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    input_dtype = hidden_states.dtype
 | 
					    input_dtype = hidden_states.dtype
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@@ -82,8 +82,6 @@ def gemma_rms_norm_forward(self, hidden_states):
 | 
				
			||||||
        import linear_q4_0
 | 
					        import linear_q4_0
 | 
				
			||||||
        x_2d = hidden_states.reshape(-1, hidden_states.size(-1)).contiguous()
 | 
					        x_2d = hidden_states.reshape(-1, hidden_states.size(-1)).contiguous()
 | 
				
			||||||
        output = linear_q4_0.rms_norm(self.weight + 1, x_2d, self.eps)
 | 
					        output = linear_q4_0.rms_norm(self.weight + 1, x_2d, self.eps)
 | 
				
			||||||
        if 1 < x_2d.size(0) <= 64:   # may use XMX, need copy
 | 
					 | 
				
			||||||
            output = output.clone()
 | 
					 | 
				
			||||||
        return output.reshape(hidden_states.shape)
 | 
					        return output.reshape(hidden_states.shape)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    input_dtype = hidden_states.dtype
 | 
					    input_dtype = hidden_states.dtype
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@@ -138,8 +138,6 @@ def llama_rms_norm_forward(self, hidden_states):
 | 
				
			||||||
        import linear_q4_0
 | 
					        import linear_q4_0
 | 
				
			||||||
        x_2d = hidden_states.reshape(-1, hidden_states.size(-1)).contiguous()
 | 
					        x_2d = hidden_states.reshape(-1, hidden_states.size(-1)).contiguous()
 | 
				
			||||||
        output = linear_q4_0.rms_norm(self.weight, x_2d, self.variance_epsilon)
 | 
					        output = linear_q4_0.rms_norm(self.weight, x_2d, self.variance_epsilon)
 | 
				
			||||||
        if 1 < x_2d.size(0) <= 64:   # may use XMX, need copy
 | 
					 | 
				
			||||||
            output = output.clone()
 | 
					 | 
				
			||||||
        return output.reshape(hidden_states.shape)
 | 
					        return output.reshape(hidden_states.shape)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    input_dtype = hidden_states.dtype
 | 
					    input_dtype = hidden_states.dtype
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in a new issue