diff --git a/python/llm/src/ipex_llm/transformers/npu_model.py b/python/llm/src/ipex_llm/transformers/npu_model.py
index 3b1de0ab..cc273dcd 100644
--- a/python/llm/src/ipex_llm/transformers/npu_model.py
+++ b/python/llm/src/ipex_llm/transformers/npu_model.py
@@ -153,7 +153,7 @@ class _BaseAutoModelClass:
             from ipex_llm.transformers.npu_models.convert_mp import optimize_llm, optimize_llm_pre
 
             with torch.no_grad():
-                optimize_llm_pre(model)
+                optimize_llm_pre(model, qtype)
                 cls.load_convert(qtype, model, "cpu", *args, **kwargs)
                 create_npu_kernels(model)
             model = model.eval()
diff --git a/python/llm/src/ipex_llm/transformers/npu_models/convert.py b/python/llm/src/ipex_llm/transformers/npu_models/convert.py
index edc76687..ccc1ffca 100644
--- a/python/llm/src/ipex_llm/transformers/npu_models/convert.py
+++ b/python/llm/src/ipex_llm/transformers/npu_models/convert.py
@@ -56,7 +56,7 @@ def replace_with_QuantizedLinear(layer, qtype, device):
     from ipex_llm.transformers.low_bit_linear import ggml_convert_qtype
     from ipex_llm.ggml.quantize import ggml_tensor_qtype
     iqtype = ggml_tensor_qtype[qtype]
-    if isinstance(layer, torch.nn.Linear):
+    if isinstance(layer, torch.nn.Linear) and not hasattr(layer, "qtype"):
         if qtype == "sym_int4_rtn":
             # workaround for qwen2 & int4
             if (layer.in_features == 3584 and layer.out_features == 152064) or \
diff --git a/python/llm/src/ipex_llm/transformers/npu_models/convert_mp.py b/python/llm/src/ipex_llm/transformers/npu_models/convert_mp.py
index 5e755085..1964b754 100644
--- a/python/llm/src/ipex_llm/transformers/npu_models/convert_mp.py
+++ b/python/llm/src/ipex_llm/transformers/npu_models/convert_mp.py
@@ -13,8 +13,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import os
 import torch
 import importlib
+from ipex_llm.transformers.low_bit_linear import LowBitLinear, FP4Params
 
 
 def convert_forward(m, target_m, new_forward):
@@ -25,7 +27,7 @@ def convert_forward(m, target_m, new_forward):
         convert_forward(sub_m, target_m, new_forward)
 
 
-def optimize_llm_pre(model: torch.nn.Module):
+def optimize_llm_pre(model: torch.nn.Module, qtype):
     if model.config.model_type == "baichuan":
         # process NormHead module in Baichuan2 7B
         if hasattr(model, 'lm_head') and model.lm_head is not None:
@@ -40,6 +42,32 @@ def optimize_llm_pre(model: torch.nn.Module):
             from ipex_llm.transformers.models.baichuan import pre_compute_inv_freq
             model.apply(pre_compute_inv_freq)
 
+    # lm_head to cpu optimization
+    if os.environ.get("IPEX_LLM_CPU_LM_HEAD", "1") != "0":
+        is_unsupported_model = (model.config.model_type == "llama"
+                                and model.vocab_size > 32000)
+        if not is_unsupported_model:
+            from ipex_llm.transformers.low_bit_linear import SYM_INT4, SYM_INT8
+            if qtype == "sym_int4_rtn":
+                lm_qtype = SYM_INT4
+            else:
+                lm_qtype = SYM_INT8
+            # lm_head opt to mp opt (llama, qwen2)
+            optimize_lm_head = model.config.model_type not in ["llama", "qwen2"]
+            new_linear = LowBitLinear(model.lm_head.in_features,
+                                      model.lm_head.out_features,
+                                      lm_qtype,
+                                      False,
+                                      optimize_lm_head=optimize_lm_head)
+            paramsLowBit = FP4Params(data=model.lm_head.weight.data,
+                                     requires_grad=False,
+                                     quantized=False,
+                                     _shape=None,
+                                     qtype=lm_qtype,
+                                     in_features=model.lm_head.in_features).to("cpu")
+            new_linear._parameters['weight'] = paramsLowBit
+            model.lm_head = new_linear
+
 
 def optimize_llm(
     model: torch.nn.Module,
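
Usage sketch (not part of the diff above): a minimal example of driving the changed entry point directly. The Hugging Face checkpoint and dtype are illustrative assumptions; the module path, the new optimize_llm_pre(model, qtype) signature, the "sym_int4_rtn" qtype string, and the IPEX_LLM_CPU_LM_HEAD toggle come from the diff itself.

    import os
    import torch
    from transformers import AutoModelForCausalLM
    from ipex_llm.transformers.npu_models.convert_mp import optimize_llm_pre

    # Keep the lm_head-to-CPU optimization enabled (this is the default).
    os.environ["IPEX_LLM_CPU_LM_HEAD"] = "1"

    # Assumed example checkpoint; any causal LM covered by optimize_llm_pre
    # (e.g. a qwen2-style model) follows the same path.
    model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-7B-Instruct",
                                                 torch_dtype=torch.float16,
                                                 trust_remote_code=True)

    with torch.no_grad():
        # qtype is now a required argument; "sym_int4_rtn" selects SYM_INT4
        # for the lm_head, anything else falls back to SYM_INT8.
        optimize_llm_pre(model, "sym_int4_rtn")

    # For supported model types, lm_head has been replaced with a LowBitLinear
    # kept on the CPU, with its weight wrapped in FP4Params.
    print(type(model.lm_head))

Setting IPEX_LLM_CPU_LM_HEAD=0 before loading skips the replacement entirely. The new hasattr(layer, "qtype") guard in replace_with_QuantizedLinear appears to exist so that this already-quantized lm_head (LowBitLinear subclasses torch.nn.Linear and carries a qtype attribute) is not converted a second time during the later NPU pass.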