Fix multiple get_enable_ipex function error (#10400)

* fix multiple get_enable_ipex function error

* remove get_enable_ipex_low_bit function
This commit is contained in:
ZehuaCao 2024-03-14 10:14:13 +08:00 committed by GitHub
parent 76e30d8ec8
commit f66329e35d
2 changed files with 1 additions and 9 deletions

View file

@ -634,14 +634,6 @@ def _optimize_pre(model):
return model
def get_enable_ipex(low_bit):
_enable_ipex = os.getenv("BIGDL_OPT_IPEX")
_enable_ipex = (_enable_ipex is not None) and (_enable_ipex.lower() == "true")
qtype = ggml_tensor_qtype[low_bit]
_enable_ipex = _enable_ipex and (qtype == ggml_tensor_qtype["bf16"])
return _enable_ipex
def ggml_convert_low_bit(model, qtype, optimize_model=True,
convert_shape_only=False, device="cpu",
modules_to_not_convert=None, cpu_embedding=False,

View file

@ -65,7 +65,7 @@ def load_model(
# Load tokenizer
tokenizer = tokenizer_cls.from_pretrained(model_path, trust_remote_code=True)
model = model_cls.from_pretrained(model_path, **model_kwargs)
if not get_enable_ipex(low_bit):
if not get_enable_ipex():
model = model.eval()
if device == "xpu":