LLM: fix mpt load_low_bit issue (#10075)

* fix

* retry

* retry
Jin Qiao 2024-02-05 10:17:07 +08:00 committed by GitHub
parent f8dcaff7f4
commit ad050107b3


@@ -361,7 +361,11 @@ class _BaseAutoModelClass:
                                               cpu_embedding=cpu_embedding, lightweight_bmm=lightweight_bmm,
                                               torch_dtype=kwargs.get("torch_dtype", 'auto'))
         model.config.update({"bigdl_transformers_low_bit": q_k})
-        model.config.update({"tie_word_embeddings": False})
+        # enable tie_word_embeddings for MPT
+        # refer to https://huggingface.co/mosaicml/mpt-7b-chat/blob/main/modeling_mpt.py#L232
+        if model.config.architectures[0] != 'MPTForCausalLM':
+            model.config.update({"tie_word_embeddings": False})
         # add save_low_bit to pretrained model dynamically
         import types
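
For context, the referenced modeling_mpt.py line only supports tied word embeddings, so a config saved with tie_word_embeddings set to False cannot be re-instantiated, which is what broke load_low_bit for MPT. Below is a minimal sketch of the save/load round trip this change affects; it assumes the BigDL-LLM AutoModelForCausalLM API with save_low_bit/load_low_bit (names taken from this diff and the documented low-bit workflow), and the model id and output directory are placeholders.

# Sketch of the round trip fixed by this commit (assumed API, placeholder paths).
from bigdl.llm.transformers import AutoModelForCausalLM

model_path = "mosaicml/mpt-7b-chat"   # MPT reports architectures == ['MPTForCausalLM']
save_dir = "./mpt-7b-chat-low-bit"    # placeholder output directory

# Quantize on load; save_low_bit() writes the low-bit weights plus the updated config.
model = AutoModelForCausalLM.from_pretrained(model_path,
                                             load_in_4bit=True,
                                             trust_remote_code=True)
model.save_low_bit(save_dir)

# Before this fix the saved config carried "tie_word_embeddings": False, which MPT's
# modeling code rejects, so this reload failed. With the guard above, the MPT config
# keeps tie_word_embeddings untouched and the reload succeeds.
model = AutoModelForCausalLM.load_low_bit(save_dir, trust_remote_code=True)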