small fix (#12634)
This commit is contained in:
parent
2d08155513
commit
f289f68d57
2 changed files with 1 addition and 3 deletions
|
|
@@ -644,7 +644,6 @@ class LowBitLinear(nn.Linear):
|
||||||
if x0.device.type == "xpu":
|
if x0.device.type == "xpu":
|
||||||
# GPU logic
|
# GPU logic
|
||||||
try:
|
try:
|
||||||
import intel_extension_for_pytorch
|
|
||||||
import xe_linear
|
import xe_linear
|
||||||
from ipex_llm.transformers.models.utils import use_xmx
|
from ipex_llm.transformers.models.utils import use_xmx
|
||||||
except ModuleNotFoundError:
|
except ModuleNotFoundError:
|
||||||
|
|
|
||||||
|
|
@@ -346,8 +346,7 @@ def use_decoding_fast_path(proj,
|
||||||
def use_xmx(x: torch.Tensor, qtype: int):
|
def use_xmx(x: torch.Tensor, qtype: int):
|
||||||
device = get_xpu_device_type(x)
|
device = get_xpu_device_type(x)
|
||||||
return (
|
return (
|
||||||
os.environ.get("BIGDL_LLM_XMX_DISABLED", "0") != "1"
|
device in ["arc", "flex", "pvc"]
|
||||||
and device in ["arc", "flex", "pvc"]
|
|
||||||
and qtype in [SYM_INT4, SYM_INT8, FP8E4, FP8E5]
|
and qtype in [SYM_INT4, SYM_INT8, FP8E4, FP8E5]
|
||||||
and (
|
and (
|
||||||
(device == "pvc" and 1 < x.size(0) <= 16)
|
(device == "pvc" and 1 < x.size(0) <= 16)
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue