From bec00e2015c1797a08e0a575d8d73551fb404e88 Mon Sep 17 00:00:00 2001
From: binbin Deng <108676127+plusbang@users.noreply.github.com>
Date: Tue, 27 Aug 2024 18:37:08 +0800
Subject: [PATCH] Improve baichuan2 NPU performance (#11942)

---
 .../src/ipex_llm/transformers/npu_model.py    |  3 ++-
 .../transformers/npu_models/baichuan_mp.py    |  2 --
 .../transformers/npu_models/convert_mp.py     | 20 +++++++++++++++++++
 3 files changed, 22 insertions(+), 3 deletions(-)

diff --git a/python/llm/src/ipex_llm/transformers/npu_model.py b/python/llm/src/ipex_llm/transformers/npu_model.py
index 96551d48..3b1de0ab 100644
--- a/python/llm/src/ipex_llm/transformers/npu_model.py
+++ b/python/llm/src/ipex_llm/transformers/npu_model.py
@@ -150,9 +150,10 @@ class _BaseAutoModelClass:
                 " than max_output_len ({max_output_len})"
             ),
         )
-        from ipex_llm.transformers.npu_models.convert_mp import optimize_llm
+        from ipex_llm.transformers.npu_models.convert_mp import optimize_llm, optimize_llm_pre
 
         with torch.no_grad():
+            optimize_llm_pre(model)
             cls.load_convert(qtype, model, "cpu", *args, **kwargs)
             create_npu_kernels(model)
         model = model.eval()
diff --git a/python/llm/src/ipex_llm/transformers/npu_models/baichuan_mp.py b/python/llm/src/ipex_llm/transformers/npu_models/baichuan_mp.py
index 37767402..b5a09dcb 100644
--- a/python/llm/src/ipex_llm/transformers/npu_models/baichuan_mp.py
+++ b/python/llm/src/ipex_llm/transformers/npu_models/baichuan_mp.py
@@ -279,9 +279,7 @@ class LowBitLlamaMultiDecoderlayer(NNFactory):
         proj = self.unsqueeze(proj, [0])  # b, s, 3, h
         proj = self.transpose(proj, [2, 1, 0, 3])  # 3, s, b, h
         proj = self.squeeze(proj)  # 3, b*s, h
-        print("proj shape: ", proj.shape)
         proj = self.unsqueeze(proj, [1])
-        print("proj shape after unsqueeze", proj.shape)
         # query_states = proj[0].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
         query_states = self.reshape(proj[0, ...], [self.batch_size, self.seq_len, self.num_heads,
                                                    self.head_dim])
diff --git a/python/llm/src/ipex_llm/transformers/npu_models/convert_mp.py b/python/llm/src/ipex_llm/transformers/npu_models/convert_mp.py
index 3d74880b..5e755085 100644
--- a/python/llm/src/ipex_llm/transformers/npu_models/convert_mp.py
+++ b/python/llm/src/ipex_llm/transformers/npu_models/convert_mp.py
@@ -25,6 +25,22 @@ def convert_forward(m, target_m, new_forward):
         convert_forward(sub_m, target_m, new_forward)
 
 
+def optimize_llm_pre(model: torch.nn.Module):
+    if model.config.model_type == "baichuan":
+        # process NormHead module in Baichuan2 7B
+        if hasattr(model, 'lm_head') and model.lm_head is not None:
+            vocab_size, hidden_size = model.lm_head.weight.shape
+            lm_head_weight_data = model.lm_head.weight.data
+            model.lm_head = torch.nn.Linear(hidden_size, vocab_size, bias=False,
+                                            device=lm_head_weight_data.device)
+            if model.lm_head.weight.data.device != "meta":
+                norm_weight = torch.nn.functional.normalize(lm_head_weight_data)
+                model.lm_head.weight.data = norm_weight
+        if model.config.hidden_size in [4096, 2048]:
+            from ipex_llm.transformers.models.baichuan import pre_compute_inv_freq
+            model.apply(pre_compute_inv_freq)
+
+
 def optimize_llm(
     model: torch.nn.Module,
     max_output_len=1024,
@@ -126,6 +142,10 @@ def optimize_llm(
         convert_forward(model, module.MiniCPMModel, minicpm_model_forward)
     elif model.config.model_type == "baichuan" and model.config.num_hidden_layers == 32:
         # for Baichuan2-7B
+        if intra_pp is None:
+            intra_pp = 2
+        if inter_pp is None:
+            inter_pp = 2
         from ipex_llm.transformers.npu_models.baichuan_mp import gen_baichuan_fused_model_forward
         from ipex_llm.transformers.npu_models.baichuan_mp import DecodeRunner, PrefillRunner
         decode_runner = DecodeRunner(
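Note: below is a minimal sketch, separate from the patch itself, of the NormHead pre-processing that optimize_llm_pre performs for Baichuan2. It uses plain PyTorch with illustrative toy sizes and variable names (not taken from ipex-llm). Baichuan2's NormHead normalizes its weight on every forward pass; the patch folds that normalization into an ordinary nn.Linear once up front, so the NPU path only has to run a standard matmul.

    import torch

    torch.manual_seed(0)
    hidden_size, vocab_size = 8, 16              # toy sizes, not Baichuan2's real dims
    hidden_states = torch.randn(2, hidden_size)

    # Reference behaviour: NormHead-style weight normalization inside the forward pass.
    norm_head_weight = torch.randn(vocab_size, hidden_size)
    reference_logits = torch.nn.functional.linear(
        hidden_states, torch.nn.functional.normalize(norm_head_weight))

    # Pre-processed behaviour: normalize the weight once and store it in a regular Linear.
    lm_head = torch.nn.Linear(hidden_size, vocab_size, bias=False)
    lm_head.weight.data = torch.nn.functional.normalize(norm_head_weight)

    # Both paths produce the same logits, but the second needs no per-step normalization.
    assert torch.allclose(reference_logits, lm_head(hidden_states))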