diff --git a/python/llm/src/ipex_llm/transformers/models/utils.py b/python/llm/src/ipex_llm/transformers/models/utils.py index 22320a4b..58af7a9c 100644 --- a/python/llm/src/ipex_llm/transformers/models/utils.py +++ b/python/llm/src/ipex_llm/transformers/models/utils.py @@ -358,6 +358,10 @@ def mlp_fusion_check(x, qtype, training): return False if training or x.requires_grad: return False + if qtype == FP6: + device = get_xpu_device_type(x) + if device == "mtl": + return False return True