This commit is contained in:
Yishuo Wang 2024-12-30 17:14:25 +08:00 committed by GitHub
parent 2d08155513
commit f289f68d57
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 1 addition and 3 deletions

View file

@@ -644,7 +644,6 @@ class LowBitLinear(nn.Linear):
if x0.device.type == "xpu":
# GPU logic
try:
import intel_extension_for_pytorch
import xe_linear
from ipex_llm.transformers.models.utils import use_xmx
except ModuleNotFoundError:

View file

@@ -346,8 +346,7 @@ def use_decoding_fast_path(proj,
def use_xmx(x: torch.Tensor, qtype: int):
device = get_xpu_device_type(x)
return (
os.environ.get("BIGDL_LLM_XMX_DISABLED", "0") != "1"
and device in ["arc", "flex", "pvc"]
device in ["arc", "flex", "pvc"]
and qtype in [SYM_INT4, SYM_INT8, FP8E4, FP8E5]
and (
(device == "pvc" and 1 < x.size(0) <= 16)