use new rms norm (#10384)
This commit is contained in:
		
							parent
							
								
									0ded0b4b13
								
							
						
					
					
						commit
						741c2bf1df
					
				
					 4 changed files with 24 additions and 40 deletions
				
			
		| 
						 | 
				
			
			@ -50,16 +50,12 @@ KV_CACHE_ALLOC_BLOCK_LENGTH = 256
 | 
			
		|||
def baichuan_13b_rms_norm_forward(self, hidden_states):
 | 
			
		||||
    if hidden_states.device.type == "xpu" and not (self.training and hidden_states.requires_grad):
 | 
			
		||||
        import linear_q4_0
 | 
			
		||||
        result = linear_q4_0.fused_rms_norm(hidden_states,
 | 
			
		||||
                                            [self.weight.size(0)],
 | 
			
		||||
                                            self.weight,
 | 
			
		||||
                                            None,
 | 
			
		||||
                                            self.epsilon)
 | 
			
		||||
        # if nelement == 0, means fused norm failed, go back to python implement.
 | 
			
		||||
        if result.nelement != 0:
 | 
			
		||||
            # We should copy this result to avoid <unk> by unknown reason on Arc GPUs.
 | 
			
		||||
            result = result.clone()
 | 
			
		||||
            return result
 | 
			
		||||
        x_2d = hidden_states.reshape(-1, hidden_states.size(-1)).contiguous()
 | 
			
		||||
        output = linear_q4_0.rms_norm(self.weight, x_2d, self.epsilon)
 | 
			
		||||
        if 1 < x_2d.size(0) <= 64:   # may use XMX, need copy
 | 
			
		||||
            output = output.clone()
 | 
			
		||||
        return output.reshape(hidden_states.shape)
 | 
			
		||||
 | 
			
		||||
    input_dtype = hidden_states.dtype
 | 
			
		||||
    hidden_states = hidden_states.to(torch.float32)
 | 
			
		||||
    variance = hidden_states.pow(2).mean(-1, keepdim=True)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -97,16 +97,12 @@ def repeat_kv(key: torch.Tensor, value: torch.Tensor, n_head: int) -> (torch.Ten
 | 
			
		|||
def chatglm_rms_norm_forward(self, hidden_states):
 | 
			
		||||
    if hidden_states.device.type == "xpu" and not (self.training and hidden_states.requires_grad):
 | 
			
		||||
        import linear_q4_0
 | 
			
		||||
        result = linear_q4_0.fused_rms_norm(hidden_states,
 | 
			
		||||
                                            [self.weight.size(0)],
 | 
			
		||||
                                            self.weight,
 | 
			
		||||
                                            None,
 | 
			
		||||
                                            self.eps)
 | 
			
		||||
        # if nelement == 0, means fused norm failed, go back to python implement.
 | 
			
		||||
        if result.nelement != 0:
 | 
			
		||||
            # We should copy this result to avoid <unk> by unknown reason on Arc GPUs.
 | 
			
		||||
            result = result.clone()
 | 
			
		||||
            return result
 | 
			
		||||
        x_2d = hidden_states.reshape(-1, hidden_states.size(-1)).contiguous()
 | 
			
		||||
        output = linear_q4_0.rms_norm(self.weight, x_2d, self.eps)
 | 
			
		||||
        if 1 < x_2d.size(0) <= 64:   # may use XMX, need copy
 | 
			
		||||
            output = output.clone()
 | 
			
		||||
        return output.reshape(hidden_states.shape)
 | 
			
		||||
 | 
			
		||||
    input_dtype = hidden_states.dtype
 | 
			
		||||
    hidden_states = hidden_states.to(torch.float32)
 | 
			
		||||
    variance = hidden_states.pow(2).mean(-1, keepdim=True)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -82,16 +82,12 @@ def use_decoding_fast_path(q_type, use_fuse_rope, enough_kv_room, bs):
 | 
			
		|||
def gemma_rms_norm_forward(self, hidden_states):
 | 
			
		||||
    if hidden_states.device.type == "xpu" and not (self.training and hidden_states.requires_grad):
 | 
			
		||||
        import linear_q4_0
 | 
			
		||||
        result = linear_q4_0.fused_rms_norm(hidden_states,
 | 
			
		||||
                                            [self.weight.size(0)],
 | 
			
		||||
                                            self.weight + 1,
 | 
			
		||||
                                            None,
 | 
			
		||||
                                            self.eps)
 | 
			
		||||
        # if nelement == 0, means fused norm failed, go back to python implement.
 | 
			
		||||
        if result.nelement != 0:
 | 
			
		||||
            # We should copy this result to avoid <unk> by unknown reason on Arc GPUs.
 | 
			
		||||
            result = result.clone()
 | 
			
		||||
            return result
 | 
			
		||||
        x_2d = hidden_states.reshape(-1, hidden_states.size(-1)).contiguous()
 | 
			
		||||
        output = linear_q4_0.rms_norm(self.weight + 1, x_2d, self.eps)
 | 
			
		||||
        if 1 < x_2d.size(0) <= 64:   # may use XMX, need copy
 | 
			
		||||
            output = output.clone()
 | 
			
		||||
        return output.reshape(hidden_states.shape)
 | 
			
		||||
 | 
			
		||||
    input_dtype = hidden_states.dtype
 | 
			
		||||
    hidden_states = hidden_states.to(torch.float32)
 | 
			
		||||
    variance = hidden_states.pow(2).mean(-1, keepdim=True)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -123,16 +123,12 @@ def llama_model_forward_4_36(
 | 
			
		|||
def llama_rms_norm_forward(self, hidden_states):
 | 
			
		||||
    if hidden_states.device.type == "xpu" and not (self.training and hidden_states.requires_grad):
 | 
			
		||||
        import linear_q4_0
 | 
			
		||||
        result = linear_q4_0.fused_rms_norm(hidden_states,
 | 
			
		||||
                                            [self.weight.size(0)],
 | 
			
		||||
                                            self.weight,
 | 
			
		||||
                                            None,
 | 
			
		||||
                                            self.variance_epsilon)
 | 
			
		||||
        # if nelement == 0, means fused norm failed, go back to python implement.
 | 
			
		||||
        if result.nelement != 0:
 | 
			
		||||
            # We should copy this result to avoid <unk> by unknown reason on Arc GPUs.
 | 
			
		||||
            result = result.clone()
 | 
			
		||||
            return result
 | 
			
		||||
        x_2d = hidden_states.reshape(-1, hidden_states.size(-1)).contiguous()
 | 
			
		||||
        output = linear_q4_0.rms_norm(self.weight, x_2d, self.variance_epsilon)
 | 
			
		||||
        if 1 < x_2d.size(0) <= 64:   # may use XMX, need copy
 | 
			
		||||
            output = output.clone()
 | 
			
		||||
        return output.reshape(hidden_states.shape)
 | 
			
		||||
 | 
			
		||||
    input_dtype = hidden_states.dtype
 | 
			
		||||
    hidden_states = hidden_states.to(torch.float32)
 | 
			
		||||
    variance = hidden_states.pow(2).mean(-1, keepdim=True)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in a new issue