fix custom kernel registration (#12674)
parent a22a8c21bb
commit 5c24276fc4

1 changed file with 25 additions and 19 deletions
@@ -20,9 +20,9 @@ import xe_batch
 import xe_addons


-@torch.library.register_fake("ipex_llm::forward_new")
-def _(x, weight, qtype, input_size):
-    return torch.empty_like(x)
+# @torch.library.register_fake("ipex_llm::forward_new")
+# def _(x, weight, qtype, input_size):
+#     return ???


 # @torch.library.register_fake("ipex_llm::dequant")
@@ -32,34 +32,40 @@ def _(x, weight, qtype, input_size):

 @torch.library.register_fake("ipex_llm::mlp_forward_xpu")
 def _(x, weight1, weight2, batch_size, state_size, output_size, act_type, qtype):
-    return torch.empty_like(x)
+    return torch.empty([batch_size, output_size],
+                       dtype=x.dtype, device=x.device)


-# @torch.library.register_fake("ipex_llm::rwkv_linear_attention_v4")
-# def _(time_decay, time_first, key, value, num_state, den_state, max_state)
-#     return ???
+@torch.library.register_fake("ipex_llm::rwkv_linear_attention_v4")
+def _(time_decay, time_first, key, value, num_state, den_state, max_state):
+    return torch.empty_like(key)


-# @torch.library.register_fake("ipex_llm::rwkv_linear_attention_v5")
-# def _(time_decay, time_first, receptance, key, value, state)
-#     return ???
+@torch.library.register_fake("ipex_llm::rwkv_linear_attention_v5")
+def _(time_decay, time_first, receptance, key, value, state):
+    bsz, n_heads, seq_len, head_dim = key.shape
+    return torch.empty([bsz, seq_len, n_heads, head_dim],
+                       dtype=key.dtype, device=key.device)


-# @torch.library.register_fake("ipex_llm::rwkv_time_shift")
-# def _(hidden, shifted, mix):
-#     return ???
+@torch.library.register_fake("ipex_llm::rwkv_time_shift")
+def _(hidden, shifted, mix):
+    bsz, seq_len, hidden_size = hidden.shape
+    return torch.empty([mix.size(0), bsz, seq_len, hidden_size],
+                       dtype=hidden.dtype, device=hidden.device)


-# @torch.library.register_fake("ipex_llm::dequantize_rows")
-# def _(x, weight, qtype, state_size, output_size):
-#     return ???
+@torch.library.register_fake("ipex_llm::dequantize_rows")
+def _(x, weight, qtype, state_size, output_size):
+    return torch.empty([x.size(0), x.size(1), state_size],
+                       dtype=torch.float, device=weight.device)


-@torch.library.register_fake("ipex_llm::batch_forward")
-def _(x, weight, qtype):
-    return torch.empty_like(x)
+# @torch.library.register_fake("ipex_llm::batch_forward")
+# def _(x, weight, qtype):
+#     return ???


 @torch.library.register_fake("ipex_llm::sdp")
 def _(query, key, value, mask):
     return torch.empty(query.shape, dtype=query.dtype, device=query.device)
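For context (not part of this commit): each register_fake entry above gives PyTorch a "fake" (meta) implementation of the corresponding XPU kernel, so shape/dtype/device propagation can run without executing the real kernel. The fake implementation must produce an output whose metadata matches the real kernel, which is why the fix derives shapes from arguments such as batch_size and output_size instead of reusing torch.empty_like(x). A minimal sketch of the same pattern on a hypothetical op (demo::mlp_forward is illustrative only, not an ipex_llm op), assuming PyTorch 2.4+ where torch.library.custom_op and torch.library.register_fake are available:

    import torch
    from torch._subclasses.fake_tensor import FakeTensorMode

    # Hypothetical op for illustration; a real kernel would be a fused XPU call.
    @torch.library.custom_op("demo::mlp_forward", mutates_args=())
    def mlp_forward(x: torch.Tensor, weight1: torch.Tensor, weight2: torch.Tensor,
                    batch_size: int, state_size: int, output_size: int) -> torch.Tensor:
        # Stand-in eager implementation.
        return (x @ weight1.t()) @ weight2.t()

    @torch.library.register_fake("demo::mlp_forward")
    def _(x, weight1, weight2, batch_size, state_size, output_size):
        # Only output metadata matters here: shape, dtype and device.
        return torch.empty([batch_size, output_size], dtype=x.dtype, device=x.device)

    with FakeTensorMode():
        x = torch.empty(4, 64)
        w1 = torch.empty(128, 64)
        w2 = torch.empty(32, 128)
        out = torch.ops.demo.mlp_forward(x, w1, w2, 4, 64, 32)
        print(out.shape)  # torch.Size([4, 32]), taken from the fake implementation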