diff --git a/python/llm/src/ipex_llm/transformers/npu_model.py b/python/llm/src/ipex_llm/transformers/npu_model.py
index eeda4ae5..f04f1317 100644
--- a/python/llm/src/ipex_llm/transformers/npu_model.py
+++ b/python/llm/src/ipex_llm/transformers/npu_model.py
@@ -202,9 +202,6 @@ class _BaseAutoModelClass:
     @classmethod
     @patch("transformers.dynamic_module_utils.get_imports", patch_flash_attn_import)
     def load_low_bit(cls, pretrained_model_name_or_path: str, *model_args, **kwargs):
-        if kwargs.pop("torch_dtype", None) not in [None, "auto", torch.float]:
-            warnings.warn("`torch_dtype` will be ignored, `torch.float` will be used")
-
         # ignore following arguments
         ignore_argument(kwargs, "model_hub")
         ignore_argument(kwargs, "lightweight_bmm")
@@ -402,6 +399,10 @@ class _BaseAutoModelClass:
         if dtype_orig is not None:
             torch.set_default_dtype(dtype_orig)
 
+        # set tie_word_embeddings to False to avoid possible lm_head error
+        if hasattr(model.config, "tie_word_embeddings"):
+            model.config.tie_word_embeddings = False
+
         (
             model,
             missing_keys,