Fix qwen2 1.5B NPU load error (#12049)

commit dc4af02b2a (parent abc370728c)
Author: Ruonan Wang
Date:   2024-09-09 23:41:18 -07:00
Committer: GitHub

@@ -202,9 +202,6 @@ class _BaseAutoModelClass:
     @classmethod
     @patch("transformers.dynamic_module_utils.get_imports", patch_flash_attn_import)
     def load_low_bit(cls, pretrained_model_name_or_path: str, *model_args, **kwargs):
-        if kwargs.pop("torch_dtype", None) not in [None, "auto", torch.float]:
-            warnings.warn("`torch_dtype` will be ignored, `torch.float` will be used")
-
         # ignore following arguments
         ignore_argument(kwargs, "model_hub")
         ignore_argument(kwargs, "lightweight_bmm")
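
This first hunk drops the guard that popped `torch_dtype` out of kwargs and warned that it would be ignored. For reference, a minimal sketch of the save/load round trip that goes through this `load_low_bit` classmethod, assuming the usual ipex-llm NPU `AutoModelForCausalLM` API; the model id, low-bit format, and save path below are illustrative, not part of the commit:

    # Round-trip sketch for the load_low_bit path patched above.
    from ipex_llm.transformers.npu_model import AutoModelForCausalLM

    # First run: convert the HF checkpoint to a low-bit format and save it.
    model = AutoModelForCausalLM.from_pretrained(
        "Qwen/Qwen2-1.5B-Instruct",     # illustrative model id
        load_in_low_bit="sym_int4",     # illustrative low-bit format
        trust_remote_code=True,
    )
    model.save_low_bit("./qwen2-1.5b-npu")

    # Later runs: reload the converted weights via the patched classmethod.
    model = AutoModelForCausalLM.load_low_bit("./qwen2-1.5b-npu")
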
@@ -402,6 +399,10 @@ class _BaseAutoModelClass:
         if dtype_orig is not None:
             torch.set_default_dtype(dtype_orig)
+        # set tie_word_embeddings to False to avoid possible lm_head error
+        if hasattr(model.config, "tie_word_embeddings"):
+            model.config.tie_word_embeddings = False
+
         (
             model,
             missing_keys,
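
The second hunk is the substance of the fix. The smaller Qwen2 checkpoints (including 1.5B) ship with `tie_word_embeddings=True`, and transformers' post-load `tie_weights()` re-points `lm_head.weight` at the input embedding, which can clobber an `lm_head` the NPU path has already converted. A toy sketch of that failure mode, using plain torch stand-ins rather than the real NPU conversion machinery:

    # Toy illustration of why re-tying word embeddings can break a
    # converted lm_head. Everything here is a stand-in; transformers'
    # tie_weights() behavior is only mimicked.
    import torch
    import torch.nn as nn

    embed = nn.Embedding(10, 4)             # input embedding (vocab=10, dim=4)
    lm_head = nn.Linear(4, 10, bias=False)  # output projection

    # With tie_word_embeddings=True, tying makes both modules share one tensor:
    lm_head.weight = embed.weight

    # The NPU path then replaces lm_head.weight with its converted weights...
    converted = nn.Parameter(torch.zeros(10, 4))  # stand-in for low-bit weights
    lm_head.weight = converted

    # ...but any later re-tie silently discards that conversion:
    lm_head.weight = embed.weight
    assert lm_head.weight is embed.weight   # converted weights are gone

Setting `model.config.tie_word_embeddings = False` before the post-load hooks run turns the re-tie into a no-op, so the converted `lm_head` survives loading, which is what the added guard accomplishes.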