add fp6 mlp fusion (#11032)

* add fp6 fusion

* add qkv fusion for fp6

* remove qkv first
This commit is contained in:
Ruonan Wang 2024-05-15 17:42:50 +08:00 committed by GitHub
parent 2084ebe4ee
commit ac384e0f45
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@@ -20,7 +20,7 @@ import warnings
from ipex_llm.utils.common import invalidInputError
from ipex_llm.ggml.quantize import ggml_tensor_qtype
from ipex_llm.transformers.utils import get_ipex_version, get_xpu_device_type
from ipex_llm.transformers.low_bit_linear import SYM_INT4, SYM_INT8, FP8E5, IQ2_XXS, FP4, FP8E4, FP6
from ipex_llm.transformers.convert import is_deepspeed_available
FP8_KV_ALLOC_LENGTH = 512
@@ -410,7 +410,7 @@ def mlp_fusion_check(x, qtype, training):
        return False
    if x.device.type != 'xpu':
        return False
    if qtype not in [SYM_INT4, FP8E5, FP4, IQ2_XXS, FP6]:
        return False
    if training or x.requires_grad:
        return False