remove rms norm copy (#10793)

This commit is contained in:
Yishuo Wang 2024-04-19 13:57:48 +08:00 committed by GitHub
parent c7235e34a8
commit 08458b4f74
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 0 additions and 8 deletions

View file

@ -54,8 +54,6 @@ def baichuan_13b_rms_norm_forward(self, hidden_states):
import linear_q4_0
x_2d = hidden_states.reshape(-1, hidden_states.size(-1)).contiguous()
output = linear_q4_0.rms_norm(self.weight, x_2d, self.epsilon)
if 1 < x_2d.size(0) <= 64: # may use XMX, need copy
output = output.clone()
return output.reshape(hidden_states.shape)
input_dtype = hidden_states.dtype

View file

@ -101,8 +101,6 @@ def chatglm_rms_norm_forward(self, hidden_states):
import linear_q4_0
x_2d = hidden_states.reshape(-1, hidden_states.size(-1)).contiguous()
output = linear_q4_0.rms_norm(self.weight, x_2d, self.eps)
if 1 < x_2d.size(0) <= 64: # may use XMX, need copy
output = output.clone()
return output.reshape(hidden_states.shape)
input_dtype = hidden_states.dtype

View file

@ -82,8 +82,6 @@ def gemma_rms_norm_forward(self, hidden_states):
import linear_q4_0
x_2d = hidden_states.reshape(-1, hidden_states.size(-1)).contiguous()
output = linear_q4_0.rms_norm(self.weight + 1, x_2d, self.eps)
if 1 < x_2d.size(0) <= 64: # may use XMX, need copy
output = output.clone()
return output.reshape(hidden_states.shape)
input_dtype = hidden_states.dtype

View file

@ -138,8 +138,6 @@ def llama_rms_norm_forward(self, hidden_states):
import linear_q4_0
x_2d = hidden_states.reshape(-1, hidden_states.size(-1)).contiguous()
output = linear_q4_0.rms_norm(self.weight, x_2d, self.variance_epsilon)
if 1 < x_2d.size(0) <= 64: # may use XMX, need copy
output = output.clone()
return output.reshape(hidden_states.shape)
input_dtype = hidden_states.dtype