add fp6 mlp fusion (#11032)
* add fp6 fusion * add qkv fusion for fp6 * remove qkv first
This commit is contained in:
parent
2084ebe4ee
commit
ac384e0f45
1 changed files with 2 additions and 2 deletions
|
|
@ -20,7 +20,7 @@ import warnings
|
||||||
from ipex_llm.utils.common import invalidInputError
|
from ipex_llm.utils.common import invalidInputError
|
||||||
from ipex_llm.ggml.quantize import ggml_tensor_qtype
|
from ipex_llm.ggml.quantize import ggml_tensor_qtype
|
||||||
from ipex_llm.transformers.utils import get_ipex_version, get_xpu_device_type
|
from ipex_llm.transformers.utils import get_ipex_version, get_xpu_device_type
|
||||||
from ipex_llm.transformers.low_bit_linear import SYM_INT4, SYM_INT8, FP8E5, IQ2_XXS, FP4, FP8E4
|
from ipex_llm.transformers.low_bit_linear import SYM_INT4, SYM_INT8, FP8E5, IQ2_XXS, FP4, FP8E4, FP6
|
||||||
from ipex_llm.transformers.convert import is_deepspeed_available
|
from ipex_llm.transformers.convert import is_deepspeed_available
|
||||||
|
|
||||||
FP8_KV_ALLOC_LENGTH = 512
|
FP8_KV_ALLOC_LENGTH = 512
|
||||||
|
|
@ -410,7 +410,7 @@ def mlp_fusion_check(x, qtype, training):
|
||||||
return False
|
return False
|
||||||
if x.device.type != 'xpu':
|
if x.device.type != 'xpu':
|
||||||
return False
|
return False
|
||||||
if qtype not in [SYM_INT4, FP8E5, FP4, IQ2_XXS]:
|
if qtype not in [SYM_INT4, FP8E5, FP4, IQ2_XXS, FP6]:
|
||||||
return False
|
return False
|
||||||
if training or x.requires_grad:
|
if training or x.requires_grad:
|
||||||
return False
|
return False
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue