diff --git a/python/llm/src/bigdl/llm/transformers/convert.py b/python/llm/src/bigdl/llm/transformers/convert.py
index 64e18511..b12f4db5 100644
--- a/python/llm/src/bigdl/llm/transformers/convert.py
+++ b/python/llm/src/bigdl/llm/transformers/convert.py
@@ -309,9 +309,13 @@ def _optimize_pre(model):
         if hasattr(model, 'lm_head') and model.lm_head is not None:
             # do we need to check the class instance?
             vocab_size, hidden_size = model.lm_head.weight.shape
-            norm_weight = nn.functional.normalize(model.lm_head.weight.data)
-            model.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
-            model.lm_head.weight.data = norm_weight
+            lm_head_weight_data = model.lm_head.weight.data
+            model.lm_head = nn.Linear(hidden_size, vocab_size, bias=False,
+                                      device=lm_head_weight_data.device)
+            # Skip normalization when the weights are still on the meta device, i.e. not yet loaded
+            if model.lm_head.weight.data.device.type != "meta":
+                norm_weight = nn.functional.normalize(lm_head_weight_data)
+                model.lm_head.weight.data = norm_weight
     return model
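
For reference, a minimal runnable sketch of the transformation this hunk performs, outside the diff context. The helper name `replace_norm_head` is hypothetical (not BigDL API); the point is the meta-device guard: a `torch.device` never compares equal to a plain string, so the check has to go through `.device.type`, and weights that sit on the meta device carry no storage to normalize until they are actually materialized.

```python
import torch
import torch.nn as nn

def replace_norm_head(lm_head: nn.Linear) -> nn.Linear:
    # Mirror _optimize_pre: swap a Baichuan2-style NormHead for a plain
    # nn.Linear whose weight rows are L2-normalized ahead of time.
    # `replace_norm_head` is a hypothetical helper, not part of BigDL.
    vocab_size, hidden_size = lm_head.weight.shape
    weight_data = lm_head.weight.data
    new_head = nn.Linear(hidden_size, vocab_size, bias=False,
                         device=weight_data.device)
    # A torch.device never equals a str, so compare .device.type instead;
    # on "meta" the tensor has no storage, so normalization must wait
    # until the real weights are loaded.
    if new_head.weight.data.device.type != "meta":
        new_head.weight.data = nn.functional.normalize(weight_data)
    return new_head

# Real weights are normalized immediately: every row has unit L2 norm.
head = replace_norm_head(nn.Linear(128, 1000, bias=False))
print(head.weight.norm(dim=-1)[:3])   # ~tensor([1., 1., 1.])

# Meta-device weights (deferred/low-memory loading) are left untouched.
with torch.device("meta"):
    meta_head = replace_norm_head(nn.Linear(128, 1000, bias=False))
print(meta_head.weight.device)        # meta
```

Normalizing once up front mirrors what NormHead would otherwise do on every forward pass at inference time, and converting to a plain `nn.Linear` lets later optimization passes treat the head like any other linear layer.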