copy fused rms norm's result to avoid <unk> (#9909)
This commit is contained in:
		
							parent
							
								
									05ea0ecd70
								
							
						
					
					
						commit
						dee32f7d15
					
				
					 3 changed files with 6 additions and 0 deletions
				
			
		| 
						 | 
					@ -54,6 +54,8 @@ def baichuan_13b_rms_norm_forward(self, hidden_states):
 | 
				
			||||||
                                            self.epsilon)
 | 
					                                            self.epsilon)
 | 
				
			||||||
        # if nelement == 0, means fused norm failed, go back to python implement.
 | 
					        # if nelement == 0, means fused norm failed, go back to python implement.
 | 
				
			||||||
        if result.nelement != 0:
 | 
					        if result.nelement != 0:
 | 
				
			||||||
 | 
					            # We should copy this result to avoid <unk> by unknown reason on Arc GPUs.
 | 
				
			||||||
 | 
					            result = result.clone()
 | 
				
			||||||
            return result
 | 
					            return result
 | 
				
			||||||
    input_dtype = hidden_states.dtype
 | 
					    input_dtype = hidden_states.dtype
 | 
				
			||||||
    hidden_states = hidden_states.to(torch.float32)
 | 
					    hidden_states = hidden_states.to(torch.float32)
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -88,6 +88,8 @@ def chatglm_rms_norm_forward(self, hidden_states):
 | 
				
			||||||
                                            self.eps)
 | 
					                                            self.eps)
 | 
				
			||||||
        # if nelement == 0, means fused norm failed, go back to python implement.
 | 
					        # if nelement == 0, means fused norm failed, go back to python implement.
 | 
				
			||||||
        if result.nelement != 0:
 | 
					        if result.nelement != 0:
 | 
				
			||||||
 | 
					            # We should copy this result to avoid <unk> by unknown reason on Arc GPUs.
 | 
				
			||||||
 | 
					            result = result.clone()
 | 
				
			||||||
            return result
 | 
					            return result
 | 
				
			||||||
    input_dtype = hidden_states.dtype
 | 
					    input_dtype = hidden_states.dtype
 | 
				
			||||||
    hidden_states = hidden_states.to(torch.float32)
 | 
					    hidden_states = hidden_states.to(torch.float32)
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -91,6 +91,8 @@ def llama_rms_norm_forward(self, hidden_states):
 | 
				
			||||||
                                            self.variance_epsilon)
 | 
					                                            self.variance_epsilon)
 | 
				
			||||||
        # if nelement == 0, means fused norm failed, go back to python implement.
 | 
					        # if nelement == 0, means fused norm failed, go back to python implement.
 | 
				
			||||||
        if result.nelement != 0:
 | 
					        if result.nelement != 0:
 | 
				
			||||||
 | 
					            # We should copy this result to avoid <unk> by unknown reason on Arc GPUs.
 | 
				
			||||||
 | 
					            result = result.clone()
 | 
				
			||||||
            return result
 | 
					            return result
 | 
				
			||||||
    input_dtype = hidden_states.dtype
 | 
					    input_dtype = hidden_states.dtype
 | 
				
			||||||
    hidden_states = hidden_states.to(torch.float32)
 | 
					    hidden_states = hidden_states.to(torch.float32)
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in a new issue