copy fused rms norm's result to avoid <unk> (#9909)

This commit is contained in:
Xin Qiu 2024-01-16 16:54:08 +08:00 committed by GitHub
parent 05ea0ecd70
commit dee32f7d15
3 changed files with 6 additions and 0 deletions

View file

@ -54,6 +54,8 @@ def baichuan_13b_rms_norm_forward(self, hidden_states):
self.epsilon) self.epsilon)
# if nelement == 0, means fused norm failed, go back to python implement. # if nelement == 0, means fused norm failed, go back to python implement.
if result.nelement != 0: if result.nelement != 0:
# We should copy this result to avoid <unk> by unknown reason on Arc GPUs.
result = result.clone()
return result return result
input_dtype = hidden_states.dtype input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32) hidden_states = hidden_states.to(torch.float32)

View file

@ -88,6 +88,8 @@ def chatglm_rms_norm_forward(self, hidden_states):
self.eps) self.eps)
# if nelement == 0, means fused norm failed, go back to python implement. # if nelement == 0, means fused norm failed, go back to python implement.
if result.nelement != 0: if result.nelement != 0:
# We should copy this result to avoid <unk> by unknown reason on Arc GPUs.
result = result.clone()
return result return result
input_dtype = hidden_states.dtype input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32) hidden_states = hidden_states.to(torch.float32)

View file

@ -91,6 +91,8 @@ def llama_rms_norm_forward(self, hidden_states):
self.variance_epsilon) self.variance_epsilon)
# if nelement == 0, means fused norm failed, go back to python implement. # if nelement == 0, means fused norm failed, go back to python implement.
if result.nelement != 0: if result.nelement != 0:
# We should copy this result to avoid <unk> by unknown reason on Arc GPUs.
result = result.clone()
return result return result
input_dtype = hidden_states.dtype input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32) hidden_states = hidden_states.to(torch.float32)