diff --git a/python/llm/src/bigdl/llm/transformers/model.py b/python/llm/src/bigdl/llm/transformers/model.py
index 93afdbe9..1e8e431b 100644
--- a/python/llm/src/bigdl/llm/transformers/model.py
+++ b/python/llm/src/bigdl/llm/transformers/model.py
@@ -361,7 +361,11 @@ class _BaseAutoModelClass:
                                      cpu_embedding=cpu_embedding, lightweight_bmm=lightweight_bmm,
                                      torch_dtype=kwargs.get("torch_dtype", 'auto'))
         model.config.update({"bigdl_transformers_low_bit": q_k})
-        model.config.update({"tie_word_embeddings": False})
+
+        # enable tie_word_embeddings for MPT
+        # refer to https://huggingface.co/mosaicml/mpt-7b-chat/blob/main/modeling_mpt.py#L232
+        if model.config.architectures[0] != 'MPTForCausalLM':
+            model.config.update({"tie_word_embeddings": False})
 
         # add save_low_bit to pretrained model dynamically
         import types
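
For context, a minimal usage sketch of how this change behaves at load time. It assumes bigdl-llm is installed and uses `mosaicml/mpt-7b-chat` (the checkpoint referenced in the added comment) as an illustrative model id; the values noted in the comments are assumptions about that checkpoint's config, not output guaranteed by this patch.

```python
# Sketch: load MPT through BigDL-LLM's low-bit path and inspect the config flags
# the patched code touches. Model id and commented values are assumptions.
from bigdl.llm.transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "mosaicml/mpt-7b-chat",
    load_in_4bit=True,       # routes through the low-bit conversion shown in the diff
    trust_remote_code=True,  # MPT ships custom modeling code on the Hub
)

print(model.config.architectures[0])        # 'MPTForCausalLM' -> the new branch skips the update
print(model.config.tie_word_embeddings)     # left at the checkpoint's original value for MPT
print(model.config.bigdl_transformers_low_bit)  # e.g. 'sym_int4'
```

The intent, as the added comment notes, is that MPT's modeling code ties the output projection to the word embeddings, so forcing `tie_word_embeddings` to `False` (as is still done for other architectures) would break MPT when the low-bit model is saved and reloaded.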