use new rms norm (#10384)

This commit is contained in:
Yishuo Wang 2024-03-12 17:29:51 +08:00 committed by GitHub
parent 0ded0b4b13
commit 741c2bf1df
4 changed files with 24 additions and 40 deletions

View file

@ -50,16 +50,12 @@ KV_CACHE_ALLOC_BLOCK_LENGTH = 256
def baichuan_13b_rms_norm_forward(self, hidden_states):
if hidden_states.device.type == "xpu" and not (self.training and hidden_states.requires_grad):
import linear_q4_0
result = linear_q4_0.fused_rms_norm(hidden_states,
[self.weight.size(0)],
self.weight,
None,
self.epsilon)
# if nelement == 0, means fused norm failed, go back to python implement.
if result.nelement != 0:
# We should copy this result to avoid <unk> by unknown reason on Arc GPUs.
result = result.clone()
return result
x_2d = hidden_states.reshape(-1, hidden_states.size(-1)).contiguous()
output = linear_q4_0.rms_norm(self.weight, x_2d, self.epsilon)
if 1 < x_2d.size(0) <= 64: # may use XMX, need copy
output = output.clone()
return output.reshape(hidden_states.shape)
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)

View file

@ -97,16 +97,12 @@ def repeat_kv(key: torch.Tensor, value: torch.Tensor, n_head: int) -> (torch.Ten
def chatglm_rms_norm_forward(self, hidden_states):
if hidden_states.device.type == "xpu" and not (self.training and hidden_states.requires_grad):
import linear_q4_0
result = linear_q4_0.fused_rms_norm(hidden_states,
[self.weight.size(0)],
self.weight,
None,
self.eps)
# if nelement == 0, means fused norm failed, go back to python implement.
if result.nelement != 0:
# We should copy this result to avoid <unk> by unknown reason on Arc GPUs.
result = result.clone()
return result
x_2d = hidden_states.reshape(-1, hidden_states.size(-1)).contiguous()
output = linear_q4_0.rms_norm(self.weight, x_2d, self.eps)
if 1 < x_2d.size(0) <= 64: # may use XMX, need copy
output = output.clone()
return output.reshape(hidden_states.shape)
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)

View file

@ -82,16 +82,12 @@ def use_decoding_fast_path(q_type, use_fuse_rope, enough_kv_room, bs):
def gemma_rms_norm_forward(self, hidden_states):
if hidden_states.device.type == "xpu" and not (self.training and hidden_states.requires_grad):
import linear_q4_0
result = linear_q4_0.fused_rms_norm(hidden_states,
[self.weight.size(0)],
self.weight + 1,
None,
self.eps)
# if nelement == 0, means fused norm failed, go back to python implement.
if result.nelement != 0:
# We should copy this result to avoid <unk> by unknown reason on Arc GPUs.
result = result.clone()
return result
x_2d = hidden_states.reshape(-1, hidden_states.size(-1)).contiguous()
output = linear_q4_0.rms_norm(self.weight + 1, x_2d, self.eps)
if 1 < x_2d.size(0) <= 64: # may use XMX, need copy
output = output.clone()
return output.reshape(hidden_states.shape)
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)

View file

@ -123,16 +123,12 @@ def llama_model_forward_4_36(
def llama_rms_norm_forward(self, hidden_states):
if hidden_states.device.type == "xpu" and not (self.training and hidden_states.requires_grad):
import linear_q4_0
result = linear_q4_0.fused_rms_norm(hidden_states,
[self.weight.size(0)],
self.weight,
None,
self.variance_epsilon)
# if nelement == 0, means fused norm failed, go back to python implement.
if result.nelement != 0:
# We should copy this result to avoid <unk> by unknown reason on Arc GPUs.
result = result.clone()
return result
x_2d = hidden_states.reshape(-1, hidden_states.size(-1)).contiguous()
output = linear_q4_0.rms_norm(self.weight, x_2d, self.variance_epsilon)
if 1 < x_2d.size(0) <= 64: # may use XMX, need copy
output = output.clone()
return output.reshape(hidden_states.shape)
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)