From 828fa01ad350cbfadb7add7cfa15c1e76ac8933a Mon Sep 17 00:00:00 2001 From: Yuwen Hu <54161268+Oscilloscope98@users.noreply.github.com> Date: Fri, 20 Sep 2024 16:36:21 +0800 Subject: [PATCH] [NPU] Add `mixed_precision` for Qwen2 7B (#12098) * Add mix_precision argument to control whether use INT8 lm_head for Qwen2-7B-Instruct * Small fix * Fixed on load low bit with mixed precision * Small fix * Update example accordingly * Update for default prompt * Update base on comments * Final fix --- .../HF-Transformers-AutoModels/LLM/README.md | 42 +++++++++++-------- .../HF-Transformers-AutoModels/LLM/qwen2.py | 4 +- .../src/ipex_llm/transformers/npu_model.py | 12 ++++-- .../transformers/npu_models/convert_mp.py | 7 +++- 4 files changed, 41 insertions(+), 24 deletions(-) diff --git a/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/README.md b/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/README.md index e6519dc7..bd8bc6e8 100644 --- a/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/README.md +++ b/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/README.md @@ -26,11 +26,11 @@ Right click and select **Update Driver** -> **Browse my computer for drivers**. ## 1. Install ### 1.1 Installation on Windows We suggest using conda to manage environment: -```bash +```cmd conda create -n llm python=3.10 conda activate llm -# install ipex-llm with 'npu' option +:: install ipex-llm with 'npu' option pip install --pre --upgrade ipex-llm[npu] ``` @@ -98,26 +98,26 @@ Supported models: Llama2-7B, MiniCPM-1B, Baichuan2-7B Supported models: Llama3-8B, MiniCPM-2B, Qwen2-7B, Qwen2-1.5B ### Run -```bash -# to run Llama-2-7b-chat-hf +```cmd +:: to run Llama-2-7b-chat-hf python llama.py -# to run Meta-Llama-3-8B-Instruct (LNL driver version: 32.0.101.2715) +:: to run Meta-Llama-3-8B-Instruct (LNL driver version: 32.0.101.2715) python llama.py --repo-id-or-model-path meta-llama/Meta-Llama-3-8B-Instruct -# to run Qwen2-1.5B-Instruct LNL driver version: 32.0.101.2715) +:: to run Qwen2-1.5B-Instruct LNL driver version: 32.0.101.2715) python qwen2.py -# to run Qwen2-7B-Instruct LNL driver version: 32.0.101.2715) +:: to run Qwen2-7B-Instruct LNL driver version: 32.0.101.2715) python qwen2.py --repo-id-or-model-path Qwen/Qwen2-7B-Instruct -# to run MiniCPM-1B-sft-bf16 +:: to run MiniCPM-1B-sft-bf16 python minicpm.py -# to run MiniCPM-2B-sft-bf16 (LNL driver version: 32.0.101.2715) +:: to run MiniCPM-2B-sft-bf16 (LNL driver version: 32.0.101.2715) python minicpm.py --repo-id-or-model-path openbmb/MiniCPM-2B-sft-bf16 -# to run Baichuan2-7B-Chat +:: to run Baichuan2-7B-Chat python baichuan2.py ``` @@ -137,29 +137,35 @@ If you encounter `TypeError: can't convert meta device type tensor to numpy. Use #### Output Problem If you encounter output problem, please try to disable the optimization of transposing value cache with following command: -```bash -# to run Llama-2-7b-chat-hf +```cmd +:: to run Llama-2-7b-chat-hf python llama.py --disable-transpose-value-cache -# to run Meta-Llama-3-8B-Instruct (LNL driver version: 32.0.101.2715) +:: to run Meta-Llama-3-8B-Instruct (LNL driver version: 32.0.101.2715) python llama.py --repo-id-or-model-path meta-llama/Meta-Llama-3-8B-Instruct --disable-transpose-value-cache -# to run Qwen2-1.5B-Instruct (LNL driver version: 32.0.101.2715) +:: to run Qwen2-1.5B-Instruct (LNL driver version: 32.0.101.2715) python qwen2.py --disable-transpose-value-cache -# to run Qwen2-7B-Instruct LNL driver version: 32.0.101.2715) +:: to run Qwen2-7B-Instruct LNL driver version: 32.0.101.2715) python qwen2.py --repo-id-or-model-path Qwen/Qwen2-7B-Instruct --disable-transpose-value-cache -# to run MiniCPM-1B-sft-bf16 +:: to run MiniCPM-1B-sft-bf16 python minicpm.py --disable-transpose-value-cache -# to run MiniCPM-2B-sft-bf16 (LNL driver version: 32.0.101.2715) +:: to run MiniCPM-2B-sft-bf16 (LNL driver version: 32.0.101.2715) python minicpm.py --repo-id-or-model-path openbmb/MiniCPM-2B-sft-bf16 --disable-transpose-value-cache -# to run Baichuan2-7B-Chat +:: to run Baichuan2-7B-Chat python baichuan2.py --disable-transpose-value-cache ``` +For [Qwen2-7B](./qwen2.py), you could also try to enable mixed precision optimization when encountering output problems: + +```cmd +python qwen2.py --repo-id-or-model-path Qwen/Qwen2-7B-Instruct --mixed-precision +``` + #### Better Performance with High CPU Utilization You could enable optimization by setting the environment variable with `set IPEX_LLM_CPU_LM_HEAD=1` for better performance. But this will cause high CPU utilization. diff --git a/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/qwen2.py b/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/qwen2.py index 465a9910..9eec34a0 100644 --- a/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/qwen2.py +++ b/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/qwen2.py @@ -43,7 +43,7 @@ if __name__ == "__main__": If path not exists, lowbit model will be saved there. \ Else, lowbit model will be loaded.", ) - parser.add_argument('--prompt', type=str, default="What is AI?", + parser.add_argument('--prompt', type=str, default="AI是什么?", help='Prompt to infer') parser.add_argument("--n-predict", type=int, default=32, help="Max tokens to predict") parser.add_argument("--max-output-len", type=int, default=1024) @@ -51,6 +51,7 @@ if __name__ == "__main__": parser.add_argument("--disable-transpose-value-cache", action="store_true", default=False) parser.add_argument("--intra-pp", type=int, default=None) parser.add_argument("--inter-pp", type=int, default=None) + parser.add_argument("--mixed-precision", action='store_true') args = parser.parse_args() model_path = args.repo_id_or_model_path @@ -68,6 +69,7 @@ if __name__ == "__main__": intra_pp=args.intra_pp, inter_pp=args.inter_pp, transpose_value_cache=not args.disable_transpose_value_cache, + mixed_precision=args.mixed_precision ) else: model = AutoModelForCausalLM.load_low_bit( diff --git a/python/llm/src/ipex_llm/transformers/npu_model.py b/python/llm/src/ipex_llm/transformers/npu_model.py index 4ac5f3a9..56ca664c 100644 --- a/python/llm/src/ipex_llm/transformers/npu_model.py +++ b/python/llm/src/ipex_llm/transformers/npu_model.py @@ -78,6 +78,9 @@ class _BaseAutoModelClass: Relevant low bit optimizations will be applied to the model. :param optimize_model: boolean value, Whether to further optimize the low_bit llm model. Default to be ``False``. + :param mixed_precision: boolean value, Whether to use mixed precision quantization. + Default to be False. If set to ``True``, we will use ``'sym_int8'`` for lm_head when + ``load_in_low_bit`` is '``sym_int4``' for certain models. :return: a model instance """ if kwargs.get("device_map", None) not in [None, "cpu", "auto"]: @@ -108,7 +111,6 @@ class _BaseAutoModelClass: ignore_argument(kwargs, "load_in_4bit") ignore_argument(kwargs, "load_in_8bit") ignore_argument(kwargs, "imatrix") - ignore_argument(kwargs, "mixed_precision") ignore_argument(kwargs, "cpu_embedding") ignore_argument(kwargs, "embedding_qtype") ignore_argument(kwargs, "enable_mp") @@ -123,6 +125,7 @@ class _BaseAutoModelClass: intra_pp = kwargs.pop("intra_pp", None) transpose_value_cache = kwargs.pop("transpose_value_cache", True) modules_to_not_convert = kwargs.pop("modules_to_not_convert", []) + mixed_precision = kwargs.pop('mixed_precision', False) _args = copy.deepcopy(args) _kwargs = copy.deepcopy(kwargs) @@ -158,7 +161,8 @@ class _BaseAutoModelClass: llm = model with torch.no_grad(): - optimize_llm_pre(model, qtype) + model.config.update({"mixed_precision": mixed_precision}) + optimize_llm_pre(model, qtype, mixed_precision) cls.load_convert(qtype, model, "cpu", modules_to_not_convert, *args, **kwargs) create_npu_kernels(llm) model = model.eval() @@ -209,6 +213,7 @@ class _BaseAutoModelClass: ignore_argument(kwargs, "embedding_qtype") ignore_argument(kwargs, "speculative") ignore_argument(kwargs, "pipeline_parallel_stages") + ignore_argument(kwargs, "mixed_precision") optimize_model = kwargs.pop("optimize_model", False) max_output_len = kwargs.pop("max_output_len", 1024) max_prompt_len = kwargs.pop("max_prompt_len", 512) @@ -258,6 +263,7 @@ class _BaseAutoModelClass: config_dict, _ = PretrainedConfig.get_config_dict(pretrained_model_name_or_path) qtype = config_dict.pop("bigdl_transformers_low_bit", False) bigdl_lcmu_enabled = config_dict.pop("bigdl_lcmu_enabled", True) + mixed_precision = config_dict.pop("mixed_precision", False) invalidInputError( qtype, @@ -370,7 +376,7 @@ class _BaseAutoModelClass: llm = model with torch.no_grad(): - optimize_llm_pre(model, qtype) + optimize_llm_pre(model, qtype, mixed_precision) cls.load_convert(qtype, model, quant_device, modules_to_not_convert, *model_args, **kwargs) create_npu_kernels(llm) diff --git a/python/llm/src/ipex_llm/transformers/npu_models/convert_mp.py b/python/llm/src/ipex_llm/transformers/npu_models/convert_mp.py index 781b0c03..47c94782 100644 --- a/python/llm/src/ipex_llm/transformers/npu_models/convert_mp.py +++ b/python/llm/src/ipex_llm/transformers/npu_models/convert_mp.py @@ -29,7 +29,7 @@ def convert_forward(m, target_m, new_forward): convert_forward(sub_m, target_m, new_forward) -def optimize_llm_pre(model: torch.nn.Module, qtype): +def optimize_llm_pre(model: torch.nn.Module, qtype, mixed_precision): if model.config.model_type == "baichuan": # process NormHead module in Baichuan2 7B if hasattr(model, 'lm_head') and model.lm_head is not None: @@ -92,7 +92,10 @@ def optimize_llm_pre(model: torch.nn.Module, qtype): # for Qwen2-7B-Insturct, divide lm_head into 14 parts if model.config.hidden_size == 3584 and model.config.vocab_size == 152064 and \ not cpu_lm_head: - new_lm_head = SlicedLMHead(model.lm_head.weight, split_num=14, + # Do not split lm_head and use sym_int8 instead when mixed_precison is True + is_split = (not mixed_precision) and qtype == "sym_int4_rtn" + split_num = 14 if is_split else 1 + new_lm_head = SlicedLMHead(model.lm_head.weight, split_num=split_num, bias=model.lm_head.bias) del model.lm_head model.lm_head = new_lm_head