parent f8dcaff7f4
commit ad050107b3
1 changed file with 5 additions and 1 deletion
@@ -361,6 +361,10 @@ class _BaseAutoModelClass:
                                          cpu_embedding=cpu_embedding, lightweight_bmm=lightweight_bmm,
                                          torch_dtype=kwargs.get("torch_dtype", 'auto'))
         model.config.update({"bigdl_transformers_low_bit": q_k})
-        model.config.update({"tie_word_embeddings": False})
+
+        # enable tie_word_embeddings for MPT
+        # refer to https://huggingface.co/mosaicml/mpt-7b-chat/blob/main/modeling_mpt.py#L232
+        if model.config.architectures[0] != 'MPTForCausalLM':
+            model.config.update({"tie_word_embeddings": False})
 
         # add save_low_bit to pretrained model dynamically
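In effect, the commit stops forcing tie_word_embeddings to False when the loaded architecture is MPTForCausalLM, since the referenced modeling_mpt.py line suggests MPT ties its word embeddings inside the model itself. Below is a minimal, self-contained sketch of that conditional config update; DummyConfig and apply_low_bit_config are illustrative stand-ins for this note, not names from the BigDL code base.

# Minimal sketch of the conditional shown in the diff.
# DummyConfig is a stand-in for transformers' model config; it is not BigDL code.

class DummyConfig:
    def __init__(self, architectures, tie_word_embeddings=True):
        self.architectures = architectures
        self.tie_word_embeddings = tie_word_embeddings

    def update(self, values):
        # Mirrors the behavior of PretrainedConfig.update: set each key as an attribute.
        for key, value in values.items():
            setattr(self, key, value)


def apply_low_bit_config(config, q_k):
    """Hypothetical helper reproducing the config updates from the diff."""
    config.update({"bigdl_transformers_low_bit": q_k})
    # Per the comment in the diff, MPT relies on tied word embeddings
    # (see the referenced modeling_mpt.py#L232), so the False override is
    # skipped for MPTForCausalLM and applied to every other architecture.
    if config.architectures[0] != 'MPTForCausalLM':
        config.update({"tie_word_embeddings": False})
    return config


llama_cfg = apply_low_bit_config(DummyConfig(["LlamaForCausalLM"]), "sym_int4")
mpt_cfg = apply_low_bit_config(DummyConfig(["MPTForCausalLM"]), "sym_int4")
assert llama_cfg.tie_word_embeddings is False   # overridden as before
assert mpt_cfg.tie_word_embeddings is True      # left as the MPT model expects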