LLM: add enable_xetla parameter for optimize_model API (#10753)

This commit is contained in:
binbin Deng 2024-04-15 12:18:25 +08:00 committed by GitHub
parent 3590e1be83
commit 3d561b60ac
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@@ -253,7 +253,8 @@ def optimize_model(model, low_bit='sym_int4', optimize_llm=True, modules_to_not_
optimize_model=optimize_llm,
modules_to_not_convert=modules_to_not_convert,
cpu_embedding=cpu_embedding,
-lightweight_bmm=lightweight_bmm)
+lightweight_bmm=lightweight_bmm,
+enable_xetla=kwargs.pop("enable_xetla", False))
# add save_low_bit to pretrained model dynamically
import types
model._bigdl_config = dict()