diff --git a/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/README.md b/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/README.md
index 351311b3..2127a34d 100644
--- a/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/README.md
+++ b/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/README.md
@@ -86,11 +86,18 @@ The example below shows how to run the **_optimized model implementations_** on
 - [MiniCPM-2B](./minicpm.py)
 - [Baichuan2-7B](./baichuan2.py)
 
+### Recommended NPU Driver Version for LNL Users
+#### 32.0.100.2625
+Supported models: Llama2-7B, Qwen2-1.5B, Qwen2-7B, MiniCPM-1B, Baichuan2-7B
+#### 32.0.101.2715
+Supported models: Llama3-8B, MiniCPM-2B
+
+### Run Models
 ```bash
 # to run Llama-2-7b-chat-hf
 python llama.py
 
-# to run Meta-Llama-3-8B-Instruct
+# to run Meta-Llama-3-8B-Instruct (LNL driver version: 32.0.101.2715)
 python llama.py --repo-id-or-model-path meta-llama/Meta-Llama-3-8B-Instruct
 
 # to run Qwen2-1.5B-Instruct
@@ -124,7 +131,7 @@ If you encounter output problem, please try to disable the optimization of trans
 # to run Llama-2-7b-chat-hf
 python llama.py --disable-transpose-value-cache
 
-# to run Meta-Llama-3-8B-Instruct
+# to run Meta-Llama-3-8B-Instruct (LNL driver version: 32.0.101.2715)
 python llama.py --repo-id-or-model-path meta-llama/Meta-Llama-3-8B-Instruct --disable-transpose-value-cache
 
 # to run Qwen2-1.5B-Instruct
diff --git a/python/llm/src/ipex_llm/transformers/npu_models/convert_mp.py b/python/llm/src/ipex_llm/transformers/npu_models/convert_mp.py
index 342d9832..a1b07a8c 100644
--- a/python/llm/src/ipex_llm/transformers/npu_models/convert_mp.py
+++ b/python/llm/src/ipex_llm/transformers/npu_models/convert_mp.py
@@ -44,29 +44,26 @@ def optimize_llm_pre(model: torch.nn.Module, qtype):
 
     # lm_head to cpu optimization
     if os.environ.get("IPEX_LLM_CPU_LM_HEAD", "1") != "0":
-        is_unsupported_model = (model.config.model_type == "llama"
-                                and model.vocab_size > 32000)
-        if not is_unsupported_model:
-            from ipex_llm.transformers.low_bit_linear import SYM_INT4, SYM_INT8
-            if qtype == "sym_int4_rtn":
-                lm_qtype = SYM_INT4
-            else:
-                lm_qtype = SYM_INT8
-            # lm_head opt to mp opt (llama, qwen2)
-            optimize_lm_head = model.config.model_type not in ["llama", "qwen2"]
-            new_linear = LowBitLinear(model.lm_head.in_features,
-                                      model.lm_head.out_features,
-                                      lm_qtype,
-                                      False,
-                                      optimize_lm_head=optimize_lm_head)
-            paramsLowBit = FP4Params(data=model.lm_head.weight.data,
-                                     requires_grad=False,
-                                     quantized=False,
-                                     _shape=None,
-                                     qtype=lm_qtype,
-                                     in_features=model.lm_head.in_features).to("cpu")
-            new_linear._parameters['weight'] = paramsLowBit
-            model.lm_head = new_linear
+        from ipex_llm.transformers.low_bit_linear import SYM_INT4, SYM_INT8
+        if qtype == "sym_int4_rtn":
+            lm_qtype = SYM_INT4
+        else:
+            lm_qtype = SYM_INT8
+        # lm_head opt to mp opt (llama, qwen2)
+        optimize_lm_head = model.config.model_type not in ["llama", "qwen2"]
+        new_linear = LowBitLinear(model.lm_head.in_features,
+                                  model.lm_head.out_features,
+                                  lm_qtype,
+                                  False,
+                                  optimize_lm_head=optimize_lm_head)
+        paramsLowBit = FP4Params(data=model.lm_head.weight.data,
+                                 requires_grad=False,
+                                 quantized=False,
+                                 _shape=None,
+                                 qtype=lm_qtype,
+                                 in_features=model.lm_head.in_features).to("cpu")
+        new_linear._parameters['weight'] = paramsLowBit
+        model.lm_head = new_linear
 
 
 def optimize_llm(
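
For readers unfamiliar with the "lm_head to cpu optimization" that `optimize_llm_pre` applies (and that this change now applies to all model types), the following is a minimal standalone sketch of the idea: the output projection is replaced by a quantized linear layer that stays on the CPU, gated by the same `IPEX_LLM_CPU_LM_HEAD` environment variable. `Int8Linear` and `TinyLM` are hypothetical stand-ins for illustration only, not ipex_llm's `LowBitLinear`/`FP4Params` implementation.

```python
# Illustrative sketch only -- NOT ipex_llm code. Shows the pattern of swapping
# model.lm_head for a quantized linear kept on CPU, gated by an env var.
import os
import torch
import torch.nn as nn


class Int8Linear(nn.Module):
    """Toy symmetric int8 linear kept on CPU (stand-in for a low-bit linear)."""

    def __init__(self, linear: nn.Linear):
        super().__init__()
        w = linear.weight.data.to("cpu", dtype=torch.float32)
        # Per-output-row symmetric scale so weights fit in int8.
        scale = w.abs().amax(dim=1, keepdim=True).clamp(min=1e-8) / 127.0
        self.register_buffer("qweight", torch.round(w / scale).to(torch.int8))
        self.register_buffer("scale", scale)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Dequantize on the fly and run the matmul on CPU.
        w = self.qweight.to(torch.float32) * self.scale
        return nn.functional.linear(x.to("cpu", dtype=torch.float32), w)


class TinyLM(nn.Module):
    """Toy model with an lm_head, standing in for a HF causal LM."""

    def __init__(self, hidden=64, vocab=128):
        super().__init__()
        self.body = nn.Linear(hidden, hidden)
        self.lm_head = nn.Linear(hidden, vocab, bias=False)

    def forward(self, x):
        return self.lm_head(self.body(x))


model = TinyLM()
# Same gate as optimize_llm_pre: enabled unless IPEX_LLM_CPU_LM_HEAD is "0".
if os.environ.get("IPEX_LLM_CPU_LM_HEAD", "1") != "0":
    model.lm_head = Int8Linear(model.lm_head)

logits = model(torch.randn(2, 64))
print(logits.shape)  # torch.Size([2, 128])
```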