Improve baichuan2 NPU performance (#11942)

Author: binbin Deng, 2024-08-27 18:37:08 +08:00 (committed by GitHub)
parent 90f692937d
commit bec00e2015
3 changed files with 22 additions and 3 deletions

@@ -150,9 +150,10 @@ class _BaseAutoModelClass:
                         " than max_output_len ({max_output_len})"
                     ),
                 )
-            from ipex_llm.transformers.npu_models.convert_mp import optimize_llm
+            from ipex_llm.transformers.npu_models.convert_mp import optimize_llm, optimize_llm_pre
             with torch.no_grad():
+                optimize_llm_pre(model)
                 cls.load_convert(qtype, model, "cpu", *args, **kwargs)
                 create_npu_kernels(model)
             model = model.eval()
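With this change, optimize_llm_pre(model) runs on the CPU-resident model before load_convert quantizes the weights and create_npu_kernels builds the fused graphs, so pre-transforms such as the NormHead folding below are baked into the converted weights. A minimal usage sketch of the NPU loading path (the model id and argument values are illustrative, following ipex-llm's NPU examples):

    from ipex_llm.transformers.npu_model import AutoModelForCausalLM

    # Loading triggers optimize_llm_pre -> load_convert -> create_npu_kernels,
    # in that order, all under torch.no_grad().
    model = AutoModelForCausalLM.from_pretrained(
        "baichuan-inc/Baichuan2-7B-Chat",   # illustrative model id
        load_in_low_bit="sym_int4",          # passed down as `qtype`
        trust_remote_code=True,
    )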

@@ -279,9 +279,7 @@ class LowBitLlamaMultiDecoderlayer(NNFactory):
         proj = self.unsqueeze(proj, [0])  # b, s, 3, h
         proj = self.transpose(proj, [2, 1, 0, 3])  # 3, s, b, h
         proj = self.squeeze(proj)  # 3, b*s, h
-        print("proj shape: ", proj.shape)
         proj = self.unsqueeze(proj, [1])
-        print("proj shape after unsqueeze", proj.shape)
         # query_states = proj[0].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
         query_states = self.reshape(proj[0, ...], [self.batch_size,
                                                    self.seq_len, self.num_heads, self.head_dim])
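The deleted print calls fired during NNFactory graph construction, adding noise (and tracing overhead) on every graph build. The surrounding shuffle splits the fused QKV projection into per-head query/key/value tensors; in eager PyTorch terms (a shapes-only approximation with batch size 1, sizes illustrative) it amounts to:

    import torch

    bsz, seq_len, num_heads, head_dim = 1, 16, 32, 128
    hidden = num_heads * head_dim
    proj = torch.randn(bsz, seq_len, 3 * hidden)       # fused q/k/v projection output
    proj = proj.view(bsz, seq_len, 3, hidden)          # b, s, 3, h
    proj = proj.permute(2, 1, 0, 3).reshape(3, bsz * seq_len, hidden)  # 3, b*s, h
    query_states = proj[0].reshape(bsz, seq_len, num_heads, head_dim)
    key_states = proj[1].reshape(bsz, seq_len, num_heads, head_dim)
    value_states = proj[2].reshape(bsz, seq_len, num_heads, head_dim)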

@@ -25,6 +25,22 @@ def convert_forward(m, target_m, new_forward):
             convert_forward(sub_m, target_m, new_forward)
 
 
+def optimize_llm_pre(model: torch.nn.Module):
+    if model.config.model_type == "baichuan":
+        # process NormHead module in Baichuan2 7B
+        if hasattr(model, 'lm_head') and model.lm_head is not None:
+            vocab_size, hidden_size = model.lm_head.weight.shape
+            lm_head_weight_data = model.lm_head.weight.data
+            model.lm_head = torch.nn.Linear(hidden_size, vocab_size, bias=False,
+                                            device=lm_head_weight_data.device)
+            if model.lm_head.weight.data.device != "meta":
+                norm_weight = torch.nn.functional.normalize(lm_head_weight_data)
+                model.lm_head.weight.data = norm_weight
+        if model.config.hidden_size in [4096, 2048]:
+            from ipex_llm.transformers.models.baichuan import pre_compute_inv_freq
+            model.apply(pre_compute_inv_freq)
+
+
 def optimize_llm(
     model: torch.nn.Module,
     max_output_len=1024,
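Baichuan2's NormHead re-normalizes the lm_head weight on every forward; with frozen inference weights that normalization can be folded in once, replacing NormHead with a plain nn.Linear that the low-bit converter already handles. pre_compute_inv_freq likewise precomputes the rotary inv_freq buffers ahead of conversion. A small equivalence check of the folding (toy sizes; NormHead's inference behavior as in Baichuan2's modeling code):

    import torch
    import torch.nn.functional as F

    vocab_size, hidden_size = 1000, 64
    weight = torch.randn(vocab_size, hidden_size)
    x = torch.randn(2, hidden_size)

    # NormHead at inference: logits = x @ normalize(weight).T
    norm_head_logits = F.linear(x, F.normalize(weight))

    # Folded version, as optimize_llm_pre builds it
    lm_head = torch.nn.Linear(hidden_size, vocab_size, bias=False)
    lm_head.weight.data = F.normalize(weight)
    assert torch.allclose(norm_head_logits, lm_head(x))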
@@ -126,6 +142,10 @@ def optimize_llm(
         convert_forward(model, module.MiniCPMModel, minicpm_model_forward)
     elif model.config.model_type == "baichuan" and model.config.num_hidden_layers == 32:
         # for Baichuan2-7B
+        if intra_pp is None:
+            intra_pp = 2
+        if inter_pp is None:
+            inter_pp = 2
         from ipex_llm.transformers.npu_models.baichuan_mp import gen_baichuan_fused_model_forward
         from ipex_llm.transformers.npu_models.baichuan_mp import DecodeRunner, PrefillRunner
         decode_runner = DecodeRunner(
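The new defaults pick pipeline-parallel degrees for Baichuan2-7B's 32 decoder layers when the caller does not specify them. One plausible reading (illustrative only; stage_slices below is a hypothetical helper, not ipex_llm code) is that inter_pp stages each own a contiguous slice of layers, with intra_pp further chunking each stage's fused kernels:

    def stage_slices(num_layers: int, inter_pp: int, intra_pp: int):
        # Hypothetical partition: inter_pp stages, each split into intra_pp chunks.
        per_stage = num_layers // inter_pp
        per_chunk = per_stage // intra_pp
        return [[list(range(s * per_stage + c * per_chunk,
                            s * per_stage + (c + 1) * per_chunk))
                 for c in range(intra_pp)]
                for s in range(inter_pp)]

    # With the new defaults for Baichuan2-7B: 2 stages x 2 chunks of 8 layers each.
    print(stage_slices(32, inter_pp=2, intra_pp=2))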