fix custom kernel registration (#12674)
This commit is contained in:
parent a22a8c21bb
commit 5c24276fc4
1 changed file with 25 additions and 19 deletions
@@ -20,9 +20,9 @@ import xe_batch
 import xe_addons
 
 
-# @torch.library.register_fake("ipex_llm::forward_new")
-# def _(x, weight, qtype, input_size):
-# return ???
+@torch.library.register_fake("ipex_llm::forward_new")
+def _(x, weight, qtype, input_size):
+    return torch.empty_like(x)
 
 
 # @torch.library.register_fake("ipex_llm::dequant")
@@ -32,34 +32,40 @@ def _(x, weight, qtype, input_size):
 
 @torch.library.register_fake("ipex_llm::mlp_forward_xpu")
 def _(x, weight1, weight2, batch_size, state_size, output_size, act_type, qtype):
-    return torch.empty_like(x)
+    return torch.empty([batch_size, output_size],
+                       dtype=x.dtype, device=x.device)
 
 
-# @torch.library.register_fake("ipex_llm::rwkv_linear_attention_v4")
-# def _(time_decay, time_first, key, value, num_state, den_state, max_state)
+@torch.library.register_fake("ipex_llm::rwkv_linear_attention_v4")
+def _(time_decay, time_first, key, value, num_state, den_state, max_state):
+    return torch.empty_like(key)
 
 
-# @torch.library.register_fake("ipex_llm::batch_forward")
-# def _(x, weight, qtype):
-# return ???
+@torch.library.register_fake("ipex_llm::rwkv_linear_attention_v5")
+def _(time_decay, time_first, receptance, key, value, state):
+    bsz, n_heads, seq_len, head_dim = key.shape
+    return torch.empty([bsz, seq_len, n_heads, head_dim],
+                       dtype=key.dtype, device=key.device)
 
 
-# @torch.library.register_fake("ipex_llm::rwkv_linear_attention_v5")
-# def _(time_decay, time_first, receptance, key, value, state)
-# return ???
+@torch.library.register_fake("ipex_llm::rwkv_time_shift")
+def _(hidden, shifted, mix):
+    bsz, seq_len, hidden_size = hidden.shape
+    return torch.empty([mix.size(0), bsz, seq_len, hidden_size],
+                       dtype=hidden.dtype, device=hidden.device)
 
 
-# @torch.library.register_fake("ipex_llm::rwkv_time_shift")
-# def _(hidden, shifted, mix):
-# return ???
+@torch.library.register_fake("ipex_llm::dequantize_rows")
+def _(x, weight, qtype, state_size, output_size):
+    return torch.empty([x.size(0), x.size(1), state_size],
+                       dtype=torch.float, device=weight.device)
 
 
-# @torch.library.register_fake("ipex_llm::dequantize_rows")
-# def _(x, weight, qtype, state_size, output_size):
-# return ???
+@torch.library.register_fake("ipex_llm::batch_forward")
+def _(x, weight, qtype):
+    return torch.empty_like(x)
 
 
 @torch.library.register_fake("ipex_llm::sdp")
 def _(query, key, value, mask):
     return torch.empty(query.shape, dtype=query.dtype, device=query.device)
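
For context: torch.library.register_fake registers a "fake" (meta) implementation of a custom op that computes only output metadata (shape, dtype, device), so torch.compile and FakeTensor tracing can propagate shapes through the custom XPU kernels without executing them. That is why the mlp_forward_xpu fake above had to change from torch.empty_like(x) to a [batch_size, output_size] tensor: a fake kernel whose output metadata disagrees with the real kernel breaks shape propagation. Below is a minimal, self-contained sketch of the pattern; the op name mylib::scale and its signature are illustrative stand-ins, not part of ipex_llm.

import torch
from torch._subclasses.fake_tensor import FakeTensorMode

# A toy custom op standing in for a real device kernel.
@torch.library.custom_op("mylib::scale", mutates_args=())
def scale(x: torch.Tensor, factor: float) -> torch.Tensor:
    return x * factor

# The fake impl mirrors the registrations in this diff: it allocates
# no real storage and only describes the output's shape/dtype/device.
@torch.library.register_fake("mylib::scale")
def _(x, factor):
    return torch.empty_like(x)

# Under FakeTensorMode the fake impl runs instead of the real kernel,
# so shape propagation can be checked without any device compute.
with FakeTensorMode():
    x = torch.empty(4, 8)
    y = torch.ops.mylib.scale(x, 2.0)
    assert y.shape == (4, 8) and y.dtype == x.dtype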