handle empty fused norm result (#9688)
* handle empty fused norm result
* remove fast_rms_norm
* fix style
parent a5c481fedd
commit 320110d158

4 changed files with 51 additions and 69 deletions
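
All four files get the same treatment: call the fused kernel unconditionally, and if it returns an empty tensor, fall back to the eager implementation. A minimal, self-contained sketch of the pattern (the fused kernel is stubbed out here; fake_fused_rms_norm and rms_norm_with_fallback are illustrative names, not part of this change):

import torch

def fake_fused_rms_norm(x, shape, weight, bias, eps):
    # Stand-in for a fused kernel (e.g. linear_q4_0.fused_rms_norm) that
    # signals failure by returning an empty tensor instead of raising.
    return torch.empty(0)

def rms_norm_with_fallback(hidden_states, weight, eps=1e-6):
    result = fake_fused_rms_norm(hidden_states, [weight.size(0)], weight, None, eps)
    # nelement() counts all elements; 0 means the fused path produced nothing.
    if result.nelement() != 0:
        return result
    # Eager fallback, mirroring the patched functions below:
    input_dtype = hidden_states.dtype
    x = hidden_states.to(torch.float32)
    variance = x.pow(2).mean(-1, keepdim=True)
    return weight * (x * torch.rsqrt(variance + eps)).to(input_dtype)

x = torch.randn(2, 4, 8)
w = torch.ones(8)
print(rms_norm_with_fallback(x, w).shape)  # torch.Size([2, 4, 8]) via the fallback
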
@@ -49,26 +49,20 @@ KV_CACHE_ALLOC_BLOCK_LENGTH = 256

 def baichuan_13b_rms_norm_forward(self, hidden_states):
     if hidden_states.device.type == "xpu" and not (self.training and hidden_states.requires_grad):
-        if get_ipex_version() <= "2.0.110+xpu":
-            import linear_q4_0
-            hidden_states = linear_q4_0.fused_rms_norm(hidden_states,
-                                                       [self.weight.size(0)],
-                                                       self.weight,
-                                                       None,
-                                                       self.epsilon)
-        else:
-            hidden_states = torch.ops.torch_ipex.fast_rms_norm(hidden_states,
-                                                               [self.weight.size(0)],
-                                                               self.weight,
-                                                               None,
-                                                               self.epsilon)
-        return hidden_states
-    else:
-        input_dtype = hidden_states.dtype
-        hidden_states = hidden_states.to(torch.float32)
-        variance = hidden_states.pow(2).mean(-1, keepdim=True)
-        hidden_states = hidden_states * torch.rsqrt(variance + self.epsilon)
-        return self.weight * hidden_states.to(input_dtype)
+        import linear_q4_0
+        result = linear_q4_0.fused_rms_norm(hidden_states,
+                                            [self.weight.size(0)],
+                                            self.weight,
+                                            None,
+                                            self.epsilon)
+        # nelement() == 0 means the fused norm failed; fall back to the Python implementation.
+        if result.nelement() != 0:
+            return result
+    input_dtype = hidden_states.dtype
+    hidden_states = hidden_states.to(torch.float32)
+    variance = hidden_states.pow(2).mean(-1, keepdim=True)
+    hidden_states = hidden_states * torch.rsqrt(variance + self.epsilon)
+    return self.weight * hidden_states.to(input_dtype)


 def baichuan_mlp_forward(
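
For reference, the eager branch above is plain RMSNorm. On PyTorch builds that ship torch.nn.functional.rms_norm (2.4 and later; treat availability as an assumption to verify), the fallback should agree with the built-in up to rounding:

import torch
import torch.nn.functional as F

x = torch.randn(2, 16)
w = torch.randn(16)
eps = 1e-6

# Eager path from the patch: upcast to float32, normalize by RMS, scale, downcast.
xf = x.to(torch.float32)
ref = w * (xf * torch.rsqrt(xf.pow(2).mean(-1, keepdim=True) + eps)).to(x.dtype)

# Built-in RMSNorm; assumed available on this PyTorch build.
out = F.rms_norm(x, [16], weight=w, eps=eps)
print(torch.allclose(ref, out, atol=1e-5))
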
@@ -65,14 +65,15 @@ def dropout_add(x: torch.Tensor, residual: torch.Tensor, prob: float, training:
 def bloom_layer_norm_forward(self, hidden_states):
     if hidden_states.device.type == "xpu" and not (self.training and hidden_states.requires_grad):
         import linear_q4_0
-        hidden_states = linear_q4_0.fused_layer_norm(hidden_states,
-                                                     [self.weight.size(0)],
-                                                     self.weight,
-                                                     self.bias,
-                                                     self.eps)
-        return hidden_states
-    else:
-        return F.layer_norm(hidden_states, self.normalized_shape, self.weight, self.bias, self.eps)
+        result = linear_q4_0.fused_layer_norm(hidden_states,
+                                              [self.weight.size(0)],
+                                              self.weight,
+                                              self.bias,
+                                              self.eps)
+        # nelement() == 0 means the fused norm failed; fall back to the Python implementation.
+        if result.nelement() != 0:
+            return result
+    return F.layer_norm(hidden_states, self.normalized_shape, self.weight, self.bias, self.eps)


 def bloom_attention_forward(
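
The LayerNorm variant falls back to the stock functional call, which is exactly equivalent to running the module itself (standard PyTorch, nothing IPEX-specific):

import torch
import torch.nn.functional as F

ln = torch.nn.LayerNorm(8, eps=1e-5)
x = torch.randn(3, 8)

out_module = ln(x)
out_functional = F.layer_norm(x, ln.normalized_shape, ln.weight, ln.bias, ln.eps)
print(torch.allclose(out_module, out_functional))  # True
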
@@ -78,27 +78,20 @@ def apply_rotary_pos_emb_chatglm(x: torch.Tensor, rope_cache: torch.Tensor) -> t

 def chatglm_rms_norm_forward(self, hidden_states):
     if hidden_states.device.type == "xpu" and not (self.training and hidden_states.requires_grad):
-        if get_ipex_version() <= "2.0.110+xpu":
-            import linear_q4_0
-            hidden_states = linear_q4_0.fused_rms_norm(hidden_states,
-                                                       [self.weight.size(0)],
-                                                       self.weight,
-                                                       None,
-                                                       self.eps)
-        else:
-            # for ipex >= 2.1
-            hidden_states = torch.ops.torch_ipex.fast_rms_norm(hidden_states,
-                                                               [self.weight.size(0)],
-                                                               self.weight,
-                                                               None,  # bias
-                                                               self.eps)
-        return hidden_states
-    else:
-        input_dtype = hidden_states.dtype
-        hidden_states = hidden_states.to(torch.float32)
-        variance = hidden_states.pow(2).mean(-1, keepdim=True)
-        hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
-        return self.weight * hidden_states.to(input_dtype)
+        import linear_q4_0
+        result = linear_q4_0.fused_rms_norm(hidden_states,
+                                            [self.weight.size(0)],
+                                            self.weight,
+                                            None,
+                                            self.eps)
+        # nelement() == 0 means the fused norm failed; fall back to the Python implementation.
+        if result.nelement() != 0:
+            return result
+    input_dtype = hidden_states.dtype
+    hidden_states = hidden_states.to(torch.float32)
+    variance = hidden_states.pow(2).mean(-1, keepdim=True)
+    hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
+    return self.weight * hidden_states.to(input_dtype)


 def chatglm2_model_forward(
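
Aside on the deleted version gate: get_ipex_version() <= "2.0.110+xpu" compares version strings lexicographically, which misorders multi-digit components ("2.0.9+xpu" sorts after "2.0.110+xpu" as a string). If such a gate were ever reintroduced, a numeric comparison would be safer; a sketch using packaging (ipex_at_most is a hypothetical helper, not part of the codebase):

from packaging import version

def ipex_at_most(ver_string, bound="2.0.110"):
    # Strip the local segment ("+xpu") before comparing numerically.
    base = ver_string.split("+")[0]
    return version.parse(base) <= version.parse(bound)

print(ipex_at_most("2.0.110+xpu"))  # True
print(ipex_at_most("2.0.9+xpu"))    # True (string comparison would say False)
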
@@ -75,26 +75,20 @@ def get_ipex_version():

 def llama_rms_norm_forward(self, hidden_states):
     if hidden_states.device.type == "xpu" and not (self.training and hidden_states.requires_grad):
-        if get_ipex_version() <= "2.0.110+xpu":
-            import linear_q4_0
-            hidden_states = linear_q4_0.fused_rms_norm(hidden_states,
-                                                       [self.weight.size(0)],
-                                                       self.weight,
-                                                       None,
-                                                       self.variance_epsilon)
-        else:
-            hidden_states = torch.ops.torch_ipex.fast_rms_norm(hidden_states,
-                                                               [self.weight.size(0)],
-                                                               self.weight,
-                                                               None,
-                                                               self.variance_epsilon)
-        return hidden_states
-    else:
-        input_dtype = hidden_states.dtype
-        hidden_states = hidden_states.to(torch.float32)
-        variance = hidden_states.pow(2).mean(-1, keepdim=True)
-        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
-        return self.weight * hidden_states.to(input_dtype)
+        import linear_q4_0
+        result = linear_q4_0.fused_rms_norm(hidden_states,
+                                            [self.weight.size(0)],
+                                            self.weight,
+                                            None,
+                                            self.variance_epsilon)
+        # nelement() == 0 means the fused norm failed; fall back to the Python implementation.
+        if result.nelement() != 0:
+            return result
+    input_dtype = hidden_states.dtype
+    hidden_states = hidden_states.to(torch.float32)
+    variance = hidden_states.pow(2).mean(-1, keepdim=True)
+    hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+    return self.weight * hidden_states.to(input_dtype)


 def llama_attention_forward_4_31(
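
These functions take self explicitly because they replace forward on existing modules when the model is patched. A sketch of that binding (DummyRMSNorm and patched_rms_norm_forward are illustrative; the real patching lives in the library's convert step):

import types
import torch

class DummyRMSNorm(torch.nn.Module):
    def __init__(self, size, eps=1e-6):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(size))
        self.variance_epsilon = eps

def patched_rms_norm_forward(self, hidden_states):
    # Same eager math as the fallback path in llama_rms_norm_forward above.
    input_dtype = hidden_states.dtype
    x = hidden_states.to(torch.float32)
    variance = x.pow(2).mean(-1, keepdim=True)
    return self.weight * (x * torch.rsqrt(variance + self.variance_epsilon)).to(input_dtype)

module = DummyRMSNorm(8)
module.forward = types.MethodType(patched_rms_norm_forward, module)
print(module(torch.randn(2, 8)).shape)  # torch.Size([2, 8])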