copy fused rms norm's result to avoid <unk> (#9909)
Parent: 05ea0ecd70
Commit: dee32f7d15
3 changed files with 6 additions and 0 deletions
The same two lines are inserted after the fused-norm success check in each of the three patched RMSNorm forwards:

@@ -54,6 +54,8 @@ def baichuan_13b_rms_norm_forward(self, hidden_states):
                                             self.epsilon)
         # if nelement == 0, means fused norm failed, go back to python implement.
         if result.nelement != 0:
+            # We should copy this result to avoid <unk> by unknown reason on Arc GPUs.
+            result = result.clone()
             return result
     input_dtype = hidden_states.dtype
     hidden_states = hidden_states.to(torch.float32)
@@ -88,6 +88,8 @@ def chatglm_rms_norm_forward(self, hidden_states):
                                             self.eps)
         # if nelement == 0, means fused norm failed, go back to python implement.
         if result.nelement != 0:
+            # We should copy this result to avoid <unk> by unknown reason on Arc GPUs.
+            result = result.clone()
             return result
     input_dtype = hidden_states.dtype
     hidden_states = hidden_states.to(torch.float32)
@@ -91,6 +91,8 @@ def llama_rms_norm_forward(self, hidden_states):
                                             self.variance_epsilon)
         # if nelement == 0, means fused norm failed, go back to python implement.
         if result.nelement != 0:
+            # We should copy this result to avoid <unk> by unknown reason on Arc GPUs.
+            result = result.clone()
             return result
     input_dtype = hidden_states.dtype
     hidden_states = hidden_states.to(torch.float32)
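For context, the workaround is the same in all three hunks: when the fused RMS-norm kernel succeeds, its output is cloned before being returned; when it returns an empty tensor, the function falls through to the plain PyTorch path. Below is a minimal, self-contained sketch of that pattern, not the repository's code. The class name, the `fused_kernel` callable, and its `(hidden_states, weight, eps)` signature are assumptions standing in for the XPU binding the real forwards import, and the sketch calls `nelement()` as a method, which is how the `nelement != 0` check above is presumably meant to behave. The clone presumably decouples the returned tensor from a buffer the kernel reuses; the commit's own comment only says the cause of the <unk> output on Arc GPUs is unknown.

import torch
from torch import nn


class RMSNormWithFusedFallback(nn.Module):
    # Hypothetical stand-in for the three patched *_rms_norm_forward functions;
    # the real commit edits module-level forward functions, not a new class.

    def __init__(self, hidden_size, eps=1e-6, fused_kernel=None):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps
        # `fused_kernel` is an assumed callable standing in for the XPU kernel
        # binding; it is taken to return an empty tensor when the fused path fails.
        self.fused_kernel = fused_kernel

    def forward(self, hidden_states):
        if self.fused_kernel is not None and hidden_states.device.type == "xpu":
            result = self.fused_kernel(hidden_states, self.weight,
                                       self.variance_epsilon)
            # nelement() == 0 signals that the fused norm failed; fall back below.
            if result.nelement() != 0:
                # The commit's workaround: return a copy of the fused result so
                # later reuse of the kernel's buffer cannot corrupt it (the
                # <unk>-token symptom observed on Arc GPUs).
                return result.clone()
        # Plain PyTorch RMSNorm fallback (standard HuggingFace-style formula).
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return (self.weight * hidden_states).to(input_dtype)

On a non-XPU device the fused branch is skipped entirely, so `RMSNormWithFusedFallback(4096)(torch.randn(1, 8, 4096))` exercises only the pure-PyTorch fallback path.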