Improve baichuan2 NPU performance (#11942)
parent 90f692937d
commit bec00e2015

3 changed files with 22 additions and 3 deletions
@@ -150,9 +150,10 @@ class _BaseAutoModelClass:
                 " than max_output_len ({max_output_len})"
             ),
         )
-        from ipex_llm.transformers.npu_models.convert_mp import optimize_llm
+        from ipex_llm.transformers.npu_models.convert_mp import optimize_llm, optimize_llm_pre
 
         with torch.no_grad():
+            optimize_llm_pre(model)
             cls.load_convert(qtype, model, "cpu", *args, **kwargs)
             create_npu_kernels(model)
         model = model.eval()

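With optimize_llm_pre now called ahead of the low-bit conversion, the Baichuan2-specific rewrites are applied automatically when a model is loaded for NPU. A typical load that would exercise this path is sketched below; the checkpoint path is a placeholder and the keyword arguments follow the ipex-llm NPU examples rather than this diff, so treat them as indicative:

    import torch
    from ipex_llm.transformers.npu_model import AutoModelForCausalLM

    # Loading through the NPU AutoModel class runs optimize_llm_pre()
    # before load_convert()/create_npu_kernels() inside _BaseAutoModelClass.
    model = AutoModelForCausalLM.from_pretrained(
        "path/to/Baichuan2-7B-Chat",   # placeholder checkpoint path
        load_in_low_bit="sym_int4",    # low-bit weight format used on NPU
        max_output_len=1024,           # matches the default in optimize_llm
        trust_remote_code=True,
    )
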
@@ -279,9 +279,7 @@ class LowBitLlamaMultiDecoderlayer(NNFactory):
         proj = self.unsqueeze(proj, [0])  # b, s, 3, h
         proj = self.transpose(proj, [2, 1, 0, 3])  # 3, s, b, h
         proj = self.squeeze(proj)  # 3, b*s, h
-        print("proj shape: ", proj.shape)
         proj = self.unsqueeze(proj, [1])
-        print("proj shape after unsqueeze", proj.shape)
         # query_states = proj[0].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
         query_states = self.reshape(proj[0, ...], [self.batch_size,
                                                     self.seq_len, self.num_heads, self.head_dim])

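To make the shape bookkeeping in this hunk easier to follow: proj is the fused QKV projection, and the unsqueeze/transpose/squeeze/reshape chain splits it into per-head query/key/value tensors. Below is a plain-PyTorch sketch of the same shape flow; the real code expresses these steps as NNFactory graph ops, so this is only a shape-level analogy:

    import torch

    batch_size, seq_len, num_heads, head_dim = 1, 8, 32, 128
    hidden = num_heads * head_dim

    # fused QKV output: query, key and value produced by a single projection
    proj = torch.randn(batch_size, seq_len, 3, hidden)    # b, s, 3, h
    proj = proj.permute(2, 1, 0, 3)                       # 3, s, b, h
    proj = proj.reshape(3, batch_size * seq_len, hidden)  # 3, b*s, h

    # slice out Q/K/V and restore [batch, seq, heads, head_dim]
    query_states = proj[0].reshape(batch_size, seq_len, num_heads, head_dim)
    key_states = proj[1].reshape(batch_size, seq_len, num_heads, head_dim)
    value_states = proj[2].reshape(batch_size, seq_len, num_heads, head_dim)

    print(query_states.shape)  # torch.Size([1, 8, 32, 128])
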
@@ -25,6 +25,22 @@ def convert_forward(m, target_m, new_forward):
             convert_forward(sub_m, target_m, new_forward)
 
 
+def optimize_llm_pre(model: torch.nn.Module):
+    if model.config.model_type == "baichuan":
+        # process NormHead module in Baichuan2 7B
+        if hasattr(model, 'lm_head') and model.lm_head is not None:
+            vocab_size, hidden_size = model.lm_head.weight.shape
+            lm_head_weight_data = model.lm_head.weight.data
+            model.lm_head = torch.nn.Linear(hidden_size, vocab_size, bias=False,
+                                            device=lm_head_weight_data.device)
+            if model.lm_head.weight.data.device != "meta":
+                norm_weight = torch.nn.functional.normalize(lm_head_weight_data)
+                model.lm_head.weight.data = norm_weight
+        if model.config.hidden_size in [4096, 2048]:
+            from ipex_llm.transformers.models.baichuan import pre_compute_inv_freq
+            model.apply(pre_compute_inv_freq)
+
+
 def optimize_llm(
     model: torch.nn.Module,
     max_output_len=1024,

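Background on the NormHead handling added above: Baichuan2's output head normalizes each row of its weight matrix at forward time, which the plain NPU linear kernels do not express, so optimize_llm_pre folds that normalization into the weights once and swaps in a standard torch.nn.Linear. A minimal standalone sketch of the equivalence, using a toy NormHead written here purely for illustration (the real module lives in the Baichuan2 modeling code):

    import torch

    class NormHead(torch.nn.Module):
        # toy stand-in: L2-normalize each vocab row of the weight, then matmul
        def __init__(self, hidden_size, vocab_size):
            super().__init__()
            self.weight = torch.nn.Parameter(torch.randn(vocab_size, hidden_size))

        def forward(self, hidden_states):
            norm_weight = torch.nn.functional.normalize(self.weight)
            return torch.nn.functional.linear(hidden_states, norm_weight)

    hidden_size, vocab_size = 64, 100
    head = NormHead(hidden_size, vocab_size)

    # what optimize_llm_pre does at load time: bake the normalization into a Linear
    folded = torch.nn.Linear(hidden_size, vocab_size, bias=False)
    folded.weight.data = torch.nn.functional.normalize(head.weight.data)

    x = torch.randn(2, hidden_size)
    assert torch.allclose(head(x), folded(x))

The pre_compute_inv_freq pass applied for the 4096/2048 hidden-size configurations is imported from the existing ipex_llm.transformers.models.baichuan helpers and, as its name suggests, precomputes the rotary inverse-frequency table up front rather than rebuilding it on the fly.
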
@@ -126,6 +142,10 @@ def optimize_llm(
         convert_forward(model, module.MiniCPMModel, minicpm_model_forward)
     elif model.config.model_type == "baichuan" and model.config.num_hidden_layers == 32:
         # for Baichuan2-7B
+        if intra_pp is None:
+            intra_pp = 2
+        if inter_pp is None:
+            inter_pp = 2
         from ipex_llm.transformers.npu_models.baichuan_mp import gen_baichuan_fused_model_forward
         from ipex_llm.transformers.npu_models.baichuan_mp import DecodeRunner, PrefillRunner
         decode_runner = DecodeRunner(

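The remaining Baichuan2-specific change in optimize_llm is supplying defaults of 2 for intra_pp and inter_pp when the caller passes None; these pipeline-parallel settings feed the DecodeRunner/PrefillRunner constructed just below (their arguments are cut off in this hunk). As a rough illustration of what a group-split degree like inter_pp can mean, here is a hypothetical helper (not part of this commit or of ipex-llm) that divides decoder layers into contiguous groups:

    def split_layer_ranges(num_layers: int, num_groups: int):
        # hypothetical helper: divide decoder layers into contiguous groups,
        # e.g. 32 layers in 2 groups -> [(0, 16), (16, 32)]
        per_group = (num_layers + num_groups - 1) // num_groups
        return [(start, min(start + per_group, num_layers))
                for start in range(0, num_layers, per_group)]

    print(split_layer_ranges(32, 2))  # [(0, 16), (16, 32)]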