Improve baichuan2 NPU performance (#11942)
This commit is contained in:
parent
90f692937d
commit
bec00e2015
3 changed files with 22 additions and 3 deletions
|
|
@ -150,9 +150,10 @@ class _BaseAutoModelClass:
|
|||
" than max_output_len ({max_output_len})"
|
||||
),
|
||||
)
|
||||
from ipex_llm.transformers.npu_models.convert_mp import optimize_llm
|
||||
from ipex_llm.transformers.npu_models.convert_mp import optimize_llm, optimize_llm_pre
|
||||
|
||||
with torch.no_grad():
|
||||
optimize_llm_pre(model)
|
||||
cls.load_convert(qtype, model, "cpu", *args, **kwargs)
|
||||
create_npu_kernels(model)
|
||||
model = model.eval()
|
||||
|
|
|
|||
|
|
@ -279,9 +279,7 @@ class LowBitLlamaMultiDecoderlayer(NNFactory):
|
|||
proj = self.unsqueeze(proj, [0]) # b, s, 3, h
|
||||
proj = self.transpose(proj, [2, 1, 0, 3]) # 3, s, b, h
|
||||
proj = self.squeeze(proj) # 3, b*s, h
|
||||
print("proj shape: ", proj.shape)
|
||||
proj = self.unsqueeze(proj, [1])
|
||||
print("proj shape after unsqueeze", proj.shape)
|
||||
# query_states = proj[0].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
query_states = self.reshape(proj[0, ...], [self.batch_size,
|
||||
self.seq_len, self.num_heads, self.head_dim])
|
||||
|
|
|
|||
|
|
@ -25,6 +25,22 @@ def convert_forward(m, target_m, new_forward):
|
|||
convert_forward(sub_m, target_m, new_forward)
|
||||
|
||||
|
||||
def optimize_llm_pre(model: torch.nn.Module):
|
||||
if model.config.model_type == "baichuan":
|
||||
# process NormHead module in Baichuan2 7B
|
||||
if hasattr(model, 'lm_head') and model.lm_head is not None:
|
||||
vocab_size, hidden_size = model.lm_head.weight.shape
|
||||
lm_head_weight_data = model.lm_head.weight.data
|
||||
model.lm_head = torch.nn.Linear(hidden_size, vocab_size, bias=False,
|
||||
device=lm_head_weight_data.device)
|
||||
if model.lm_head.weight.data.device != "meta":
|
||||
norm_weight = torch.nn.functional.normalize(lm_head_weight_data)
|
||||
model.lm_head.weight.data = norm_weight
|
||||
if model.config.hidden_size in [4096, 2048]:
|
||||
from ipex_llm.transformers.models.baichuan import pre_compute_inv_freq
|
||||
model.apply(pre_compute_inv_freq)
|
||||
|
||||
|
||||
def optimize_llm(
|
||||
model: torch.nn.Module,
|
||||
max_output_len=1024,
|
||||
|
|
@ -126,6 +142,10 @@ def optimize_llm(
|
|||
convert_forward(model, module.MiniCPMModel, minicpm_model_forward)
|
||||
elif model.config.model_type == "baichuan" and model.config.num_hidden_layers == 32:
|
||||
# for Baichuan2-7B
|
||||
if intra_pp is None:
|
||||
intra_pp = 2
|
||||
if inter_pp is None:
|
||||
inter_pp = 2
|
||||
from ipex_llm.transformers.npu_models.baichuan_mp import gen_baichuan_fused_model_forward
|
||||
from ipex_llm.transformers.npu_models.baichuan_mp import DecodeRunner, PrefillRunner
|
||||
decode_runner = DecodeRunner(
|
||||
|
|
|
|||
Loading…
Reference in a new issue