fix custom kernel registration (#12674)
This commit is contained in:
		
							parent
							
								
									a22a8c21bb
								
							
						
					
					
						commit
						5c24276fc4
					
				
					 1 changed files with 25 additions and 19 deletions
				
			
		| 
						 | 
				
			
			@ -20,9 +20,9 @@ import xe_batch
 | 
			
		|||
import xe_addons
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@torch.library.register_fake("ipex_llm::forward_new")
 | 
			
		||||
def _(x, weight, qtype, input_size):
 | 
			
		||||
    return torch.empty_like(x)
 | 
			
		||||
# @torch.library.register_fake("ipex_llm::forward_new")
 | 
			
		||||
# def _(x, weight, qtype, input_size):
 | 
			
		||||
#     return ???
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# @torch.library.register_fake("ipex_llm::dequant")
 | 
			
		||||
| 
						 | 
				
			
			@ -32,32 +32,38 @@ def _(x, weight, qtype, input_size):
 | 
			
		|||
 | 
			
		||||
@torch.library.register_fake("ipex_llm::mlp_forward_xpu")
 | 
			
		||||
def _(x, weight1, weight2, batch_size, state_size, output_size, act_type, qtype):
 | 
			
		||||
    return torch.empty_like(x)
 | 
			
		||||
    return torch.empty([batch_size, output_size],
 | 
			
		||||
                       dtype=x.dtype, device=x.device)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# @torch.library.register_fake("ipex_llm::rwkv_linear_attention_v4")
 | 
			
		||||
# def _(time_decay, time_first, key, value, num_state, den_state, max_state)
 | 
			
		||||
    # return ???
 | 
			
		||||
@torch.library.register_fake("ipex_llm::rwkv_linear_attention_v4")
 | 
			
		||||
def _(time_decay, time_first, key, value, num_state, den_state, max_state):
 | 
			
		||||
    return torch.empty_like(key)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# @torch.library.register_fake("ipex_llm::rwkv_linear_attention_v5")
 | 
			
		||||
# def _(time_decay, time_first, receptance, key, value, state)
 | 
			
		||||
    # return ???
 | 
			
		||||
@torch.library.register_fake("ipex_llm::rwkv_linear_attention_v5")
 | 
			
		||||
def _(time_decay, time_first, receptance, key, value, state):
 | 
			
		||||
    bsz, n_heads, seq_len, head_dim = key.shape
 | 
			
		||||
    return torch.empty([bsz, seq_len, n_heads, head_dim],
 | 
			
		||||
                       dtype=key.dtype, device=key.device)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# @torch.library.register_fake("ipex_llm::rwkv_time_shift")
 | 
			
		||||
# def _(hidden, shifted, mix):
 | 
			
		||||
    # return ???
 | 
			
		||||
@torch.library.register_fake("ipex_llm::rwkv_time_shift")
 | 
			
		||||
def _(hidden, shifted, mix):
 | 
			
		||||
    bsz, seq_len, hidden_size = hidden.shape
 | 
			
		||||
    return torch.empty([mix.size(0), bsz, seq_len, hidden_size],
 | 
			
		||||
                       dtype=hidden.dtype, device=hidden.device)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# @torch.library.register_fake("ipex_llm::dequantize_rows")
 | 
			
		||||
# def _(x, weight, qtype, state_size, output_size):
 | 
			
		||||
    # return ???
 | 
			
		||||
@torch.library.register_fake("ipex_llm::dequantize_rows")
 | 
			
		||||
def _(x, weight, qtype, state_size, output_size):
 | 
			
		||||
    return torch.empty([x.size(0), x.size(1), state_size],
 | 
			
		||||
                       dtype=torch.float, device=weight.device)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@torch.library.register_fake("ipex_llm::batch_forward")
 | 
			
		||||
def _(x, weight, qtype):
 | 
			
		||||
    return torch.empty_like(x)
 | 
			
		||||
# @torch.library.register_fake("ipex_llm::batch_forward")
 | 
			
		||||
# def _(x, weight, qtype):
 | 
			
		||||
#     return ???
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@torch.library.register_fake("ipex_llm::sdp")
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in a new issue