From 9f8b134889744fce0c487ce715a2d0ac7061a6b6 Mon Sep 17 00:00:00 2001
From: Yishuo Wang
Date: Fri, 3 Jan 2025 16:45:04 +0800
Subject: [PATCH] add ipex-llm custom kernel registration (#12648)

---
 .../llm/src/ipex_llm/transformers/xpu_ops.py  | 155 ++++++++++++++++++
 1 file changed, 155 insertions(+)
 create mode 100644 python/llm/src/ipex_llm/transformers/xpu_ops.py

diff --git a/python/llm/src/ipex_llm/transformers/xpu_ops.py b/python/llm/src/ipex_llm/transformers/xpu_ops.py
new file mode 100644
index 00000000..6ee00e1c
--- /dev/null
+++ b/python/llm/src/ipex_llm/transformers/xpu_ops.py
@@ -0,0 +1,155 @@
+#
+# Copyright 2016 The BigDL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import torch
+import xe_linear
+import xe_batch
+import xe_addons
+
+
+@torch.library.register_fake("ipex_llm::forward_new")
+def _(x, weight, qtype, input_size):
+    return torch.empty_like(x)
+
+
+# @torch.library.register_fake("ipex_llm::dequant")
+# def _(x, weight, qtype):
+#     return ???
+
+
+@torch.library.register_fake("ipex_llm::mlp_forward_xpu")
+def _(x, weight1, weight2, batch_size, state_size, output_size, act_type, qtype):
+    return torch.empty_like(x)
+
+
+# @torch.library.register_fake("ipex_llm::rwkv_linear_attention_v4")
+# def _(time_decay, time_first, key, value, num_state, den_state, max_state):
+#     return ???
+
+
+# @torch.library.register_fake("ipex_llm::rwkv_linear_attention_v5")
+# def _(time_decay, time_first, receptance, key, value, state):
+#     return ???
+
+
+# @torch.library.register_fake("ipex_llm::rwkv_time_shift")
+# def _(hidden, shifted, mix):
+#     return ???
+
+
+# @torch.library.register_fake("ipex_llm::dequantize_rows")
+# def _(x, weight, qtype, state_size, output_size):
+#     return ???
+
+
+@torch.library.register_fake("ipex_llm::batch_forward")
+def _(x, weight, qtype):
+    return torch.empty_like(x)
+
+
+@torch.library.register_fake("ipex_llm::sdp")
+def _(query, key, value, mask):
+    return torch.empty(query.shape, dtype=query.dtype, device=query.device)
+
+
+@torch.library.register_fake("ipex_llm::sdp_fp8")
+def _(query, key, value, mask):
+    return torch.empty(query.shape, dtype=query.dtype, device=query.device)
+
+
+@torch.library.register_fake("ipex_llm::sdp_causal")
+def _(query, key, value, mask, scale):
+    return torch.empty(query.shape, dtype=query.dtype, device=query.device)
+
+
+@torch.library.register_fake("ipex_llm::sdp_fp8_causal")
+def _(query, key, value, mask, scale):
+    return torch.empty(query.shape, dtype=query.dtype, device=query.device)
+
+
+@torch.library.register_fake("ipex_llm::sdp_non_causal")
+def _(query, key, value, mask, scale):
+    return torch.empty(query.shape, dtype=query.dtype, device=query.device)
+
+
+@torch.library.register_fake("ipex_llm::sdp_fp8_non_causal")
+def _(query, key, value, mask, scale):
+    return torch.empty(query.shape, dtype=query.dtype, device=query.device)
+
+
+@torch.library.register_fake("ipex_llm::siglip_sdp_non_causal")
+def _(query, key, value, mask):
+    return torch.empty(query.shape, dtype=query.dtype, device=query.device)
+
+
+@torch.library.register_fake("ipex_llm::gemma2_sdp")
+def _(query, key, value, mask, f1, f2):
+    return torch.empty(query.shape, dtype=query.dtype, device=query.device)
+
+
+@torch.library.register_fake("ipex_llm::gemma2_sdp_causal")
+def _(query, key, value, mask, f1, f2):
+    return torch.empty(query.shape, dtype=query.dtype, device=query.device)
+
+
+@torch.library.register_fake("ipex_llm::rms_norm")
+def _(weight, x, eps):
+    return torch.empty_like(x)
+
+
+@torch.library.register_fake("ipex_llm::layer_norm")
+def _(x, weight, bias, eps):
+    return torch.empty_like(x)
+
+
+@torch.library.register_fake("ipex_llm::rotary_half_inplaced")
+def _(inv_freq, position_ids, query, key):
+    pass
+
+
+@torch.library.register_fake("ipex_llm::rotary_two_inplaced")
+def _(inv_freq, position_ids, query, key):
+    pass
+
+
+@torch.library.register_fake("ipex_llm::rotary_half_with_cache_inplaced")
+def _(query, key, cos, sin):
+    pass
+
+
+@torch.library.register_fake("ipex_llm::rotary_two_with_cache_inplaced")
+def _(query, key, cos, sin, half_layout):
+    pass
+
+
+@torch.library.register_fake("ipex_llm::mlp_silu_mul_inplaced")
+def _(gate, up):
+    pass
+
+
+@torch.library.register_fake("ipex_llm::quantize_key_value")
+def _(key, value, key_output, value_output):
+    pass
+
+
+@torch.library.register_fake("ipex_llm::dequantize_key_value")
+def _(key, value, key_output, value_output):
+    pass
+
+
+@torch.library.register_fake("ipex_llm::attn_softmax_inplaced")
+def _(attn):
+    pass
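
Note on the pattern this patch uses throughout: torch.library.register_fake
(PyTorch 2.4+) attaches a "fake" (meta) implementation to a custom operator,
so tracing machinery such as torch.compile and FakeTensorMode can propagate
output shapes, dtypes, and devices without executing the real XPU kernel.
The registrations whose body is just `pass` correspond to in-place kernels
that mutate their arguments and return nothing, so there is no output
metadata to fabricate. Below is a minimal sketch of how one registration
(ipex_llm::rms_norm) is exercised during tracing; it assumes an XPU build of
ipex-llm is installed (so the real op is defined), and the tensor shapes and
eps value are illustrative assumptions, not values from this patch:

    import torch
    from torch._subclasses.fake_tensor import FakeTensorMode

    # Inside FakeTensorMode (the mode torch.compile traces under), calling
    # the custom op dispatches to the fake kernel registered above: only
    # shape/dtype/device metadata is produced, and no XPU kernel runs.
    with FakeTensorMode():
        x = torch.empty(1, 32, 4096)   # illustrative activation shape
        weight = torch.empty(4096)     # illustrative RMSNorm weight
        out = torch.ops.ipex_llm.rms_norm(weight, x, 1e-6)  # eps assumed
        print(out.shape)  # torch.Size([1, 32, 4096]), inferred via the fake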