fix custom kernel registration (#12674)

This commit is contained in:
Yishuo Wang 2025-01-08 17:39:17 +08:00 committed by GitHub
parent a22a8c21bb
commit 5c24276fc4
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -20,9 +20,9 @@ import xe_batch
import xe_addons
@torch.library.register_fake("ipex_llm::forward_new")
def _(x, weight, qtype, input_size):
return torch.empty_like(x)
# @torch.library.register_fake("ipex_llm::forward_new")
# def _(x, weight, qtype, input_size):
# return ???
# @torch.library.register_fake("ipex_llm::dequant")
@ -32,32 +32,38 @@ def _(x, weight, qtype, input_size):
@torch.library.register_fake("ipex_llm::mlp_forward_xpu")
def _(x, weight1, weight2, batch_size, state_size, output_size, act_type, qtype):
return torch.empty_like(x)
return torch.empty([batch_size, output_size],
dtype=x.dtype, device=x.device)
# @torch.library.register_fake("ipex_llm::rwkv_linear_attention_v4")
# def _(time_decay, time_first, key, value, num_state, den_state, max_state)
# return ???
@torch.library.register_fake("ipex_llm::rwkv_linear_attention_v4")
def _(time_decay, time_first, key, value, num_state, den_state, max_state):
return torch.empty_like(key)
# @torch.library.register_fake("ipex_llm::rwkv_linear_attention_v5")
# def _(time_decay, time_first, receptance, key, value, state)
# return ???
@torch.library.register_fake("ipex_llm::rwkv_linear_attention_v5")
def _(time_decay, time_first, receptance, key, value, state):
bsz, n_heads, seq_len, head_dim = key.shape
return torch.empty([bsz, seq_len, n_heads, head_dim],
dtype=key.dtype, device=key.device)
# @torch.library.register_fake("ipex_llm::rwkv_time_shift")
# def _(hidden, shifted, mix):
# return ???
@torch.library.register_fake("ipex_llm::rwkv_time_shift")
def _(hidden, shifted, mix):
bsz, seq_len, hidden_size = hidden.shape
return torch.empty([mix.size(0), bsz, seq_len, hidden_size],
dtype=hidden.dtype, device=hidden.device)
# @torch.library.register_fake("ipex_llm::dequantize_rows")
# def _(x, weight, qtype, state_size, output_size):
# return ???
@torch.library.register_fake("ipex_llm::dequantize_rows")
def _(x, weight, qtype, state_size, output_size):
return torch.empty([x.size(0), x.size(1), state_size],
dtype=torch.float, device=weight.device)
@torch.library.register_fake("ipex_llm::batch_forward")
def _(x, weight, qtype):
return torch.empty_like(x)
# @torch.library.register_fake("ipex_llm::batch_forward")
# def _(x, weight, qtype):
# return ???
@torch.library.register_fake("ipex_llm::sdp")