Fix qwen2 1.5B NPU load error (#12049)
This commit is contained in:
parent
abc370728c
commit
dc4af02b2a
1 changed file with 4 additions and 3 deletions
|
|
@@ -202,9 +202,6 @@ class _BaseAutoModelClass:
|
||||||
@classmethod
|
@classmethod
|
||||||
@patch("transformers.dynamic_module_utils.get_imports", patch_flash_attn_import)
|
@patch("transformers.dynamic_module_utils.get_imports", patch_flash_attn_import)
|
||||||
def load_low_bit(cls, pretrained_model_name_or_path: str, *model_args, **kwargs):
|
def load_low_bit(cls, pretrained_model_name_or_path: str, *model_args, **kwargs):
|
||||||
if kwargs.pop("torch_dtype", None) not in [None, "auto", torch.float]:
|
|
||||||
warnings.warn("`torch_dtype` will be ignored, `torch.float` will be used")
|
|
||||||
|
|
||||||
# ignore following arguments
|
# ignore following arguments
|
||||||
ignore_argument(kwargs, "model_hub")
|
ignore_argument(kwargs, "model_hub")
|
||||||
ignore_argument(kwargs, "lightweight_bmm")
|
ignore_argument(kwargs, "lightweight_bmm")
|
||||||
|
|
@@ -402,6 +399,10 @@ class _BaseAutoModelClass:
|
||||||
if dtype_orig is not None:
|
if dtype_orig is not None:
|
||||||
torch.set_default_dtype(dtype_orig)
|
torch.set_default_dtype(dtype_orig)
|
||||||
|
|
||||||
|
# set tie_word_embeddings to False to avoid possible lm_head error
|
||||||
|
if hasattr(model.config, "tie_word_embeddings"):
|
||||||
|
model.config.tie_word_embeddings = False
|
||||||
|
|
||||||
(
|
(
|
||||||
model,
|
model,
|
||||||
missing_keys,
|
missing_keys,
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue