Update npu baichuan2 (#11939)

This commit is contained in:
Zijie Li 2024-08-27 16:56:26 +08:00 committed by GitHub
parent 7f7f6c89f5
commit 90f692937d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 35 additions and 18 deletions

View file

@ -82,6 +82,7 @@ The example below shows how to run the **_optimized model implementations_** on
- [Llama3-8B](./llama.py)
- [Qwen2-1.5B](./qwen2.py)
- [MiniCPM-1B](./minicpm.py)
- [Baichuan2-7B](./baichuan2.py)
```bash
# to run Llama-2-7b-chat-hf
@ -95,6 +96,9 @@ python qwen2.py
# to run MiniCPM-1B-sft-bf16
python minicpm.py
# to run Baichuan2-7B-Chat
python baichuan2.py
```
Arguments info:

View file

@ -46,15 +46,15 @@ if __name__ == "__main__":
parser.add_argument(
"--repo-id-or-model-path",
type=str,
default="meta-llama/Llama-2-7b-chat-hf",
help="The huggingface repo id for the Llama2 model to be downloaded"
default="baichuan-inc/Baichuan2-7B-Chat",
help="The huggingface repo id for the Baichuan2 model to be downloaded"
", or the path to the huggingface checkpoint folder",
)
parser.add_argument('--prompt', type=str, default="What is AI?",
help='Prompt to infer')
parser.add_argument("--n-predict", type=int, default=32, help="Max tokens to predict")
parser.add_argument("--max-output-len", type=int, default=1024)
parser.add_argument("--max-prompt-len", type=int, default=768)
parser.add_argument("--max-prompt-len", type=int, default=512)
parser.add_argument("--disable-transpose-value-cache", action="store_true", default=False)
parser.add_argument("--intra-pp", type=int, default=2)
parser.add_argument("--inter-pp", type=int, default=2)
@ -68,7 +68,7 @@ if __name__ == "__main__":
trust_remote_code=True,
attn_implementation="eager",
load_in_low_bit="sym_int4",
enable_mp=True,
optimize_model=True,
max_output_len=args.max_output_len,
max_prompt_len=args.max_prompt_len,
intra_pp=args.intra_pp,

View file

@ -77,6 +77,8 @@ class _BaseAutoModelClass:
:param load_in_low_bit: str value, options are ``'sym_int4'``, ``'sym_int8'``,
``'fp16'``, ``'fp32'``.
Relevant low bit optimizations will be applied to the model.
:param optimize_model: boolean value, Whether to further optimize the low_bit llm model.
Default to be ``False``.
:return: a model instance
"""
if kwargs.get("device_map", None) not in [None, "cpu", "auto"]:

View file

@ -272,7 +272,8 @@ class LowBitLlamaMultiDecoderlayer(NNFactory):
input_2d = self.convert_to_fp16(input_2d)
# attention
proj = self.linear(input_2d, 3 * self.hidden_size, self.hidden_size, bias=False, wt_dtype=self.dtype)
proj = self.linear(input_2d, 3 * self.hidden_size,
self.hidden_size, bias=False, wt_dtype=self.dtype)
# proj = proj.unflatten(-1, (3, self.hidden_size)).unsqueeze(0).transpose(0, -2).squeeze(-2)
proj = self.reshape(proj, [-1, 3, self.hidden_size]) # b*s, 3, h
proj = self.unsqueeze(proj, [0]) # b, s, 3, h
@ -282,13 +283,16 @@ class LowBitLlamaMultiDecoderlayer(NNFactory):
proj = self.unsqueeze(proj, [1])
print("proj shape after unsqueeze", proj.shape)
# query_states = proj[0].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
query_states = self.reshape(proj[0, ...], [self.batch_size, self.seq_len, self.num_heads, self.head_dim])
query_states = self.reshape(proj[0, ...], [self.batch_size,
self.seq_len, self.num_heads, self.head_dim])
query_states = self.transpose(query_states, [0, 2, 1, 3])
# key_states = proj[1].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = self.reshape(proj[1, ...], [self.batch_size, self.seq_len, self.num_heads, self.head_dim])
key_states = self.reshape(proj[1, ...], [self.batch_size,
self.seq_len, self.num_heads, self.head_dim])
key_states = self.transpose(key_states, [0, 2, 1, 3])
# value_states = proj[2].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
value_states = self.reshape(proj[2, ...], [self.batch_size, self.seq_len, self.num_heads, self.head_dim])
value_states = self.reshape(proj[2, ...], [self.batch_size,
self.seq_len, self.num_heads, self.head_dim])
if self.transpose_value:
value_states = self.transpose(value_states, [0, 2, 3, 1])
else:
@ -309,7 +313,8 @@ class LowBitLlamaMultiDecoderlayer(NNFactory):
else:
value_states = self.concat(past_value, value_states, axis=-2)
attn_weight = self.matmul(query_states, key_states, False, True) / (math.sqrt(self.head_dim))
attn_weight = self.matmul(query_states, key_states, False, True) / (
math.sqrt(self.head_dim))
attn_weight = self.eltwise_add(attn_weight, attention_mask)
attn_weight = self.convert_to_fp32(attn_weight)
attn_weight = self.softmax(attn_weight, -1)
@ -349,7 +354,8 @@ class LowBitLlamaMultiDecoderlayer(NNFactory):
bias=False, wt_dtype=self.dtype) # type: ignore[attr-defined]
mm1 = self.eltwise_mul(self.swish(mm1), mm2) # type: ignore[attr-defined]
# down proj
hidden_states = self.linear(mm1, self.hidden_size, self.intermediate_size, bias=False, wt_dtype=self.dtype)
hidden_states = self.linear(mm1, self.hidden_size,
self.intermediate_size, bias=False, wt_dtype=self.dtype)
hidden_states = self.eltwise_add(residual, hidden_states)
hidden_states = self.convert_to_fp16(hidden_states)
@ -1098,13 +1104,15 @@ def gen_baichuan_fused_model_forward(prefill_runner, decode_runner):
# retrieve input_ids and inputs_embeds
if input_ids is not None and inputs_embeds is not None:
invalidInputError(False, "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
invalidInputError(False, "You cannot specify both decoder_input_ids\
and decoder_inputs_embeds at the same time")
elif input_ids is not None:
batch_size, seq_length = input_ids.shape
elif inputs_embeds is not None:
batch_size, seq_length, _ = inputs_embeds.shape
else:
invalidInputError(False, "You have to specify either decoder_input_ids or decoder_inputs_embeds")
invalidInputError(False, "You have to specify either decoder_input_ids\
or decoder_inputs_embeds")
seq_length_with_past = seq_length
past_key_values_length = 0
@ -1120,7 +1128,8 @@ def gen_baichuan_fused_model_forward(prefill_runner, decode_runner):
if position_ids is None:
device = input_ids.device if input_ids is not None else inputs_embeds.device
position_ids = torch.arange(
past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
past_key_values_length, seq_length + past_key_values_length,
dtype=torch.long, device=device
)
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
else:
@ -1146,7 +1155,8 @@ def gen_baichuan_fused_model_forward(prefill_runner, decode_runner):
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
"`use_cache=True` is incompatible with gradient checkpointing.\
Setting `use_cache=False`..."
)
use_cache = False

View file

@ -124,7 +124,8 @@ def optimize_llm(
prefill_runner=prefill_runner, decode_runner=decode_runner
)
convert_forward(model, module.MiniCPMModel, minicpm_model_forward)
elif model.config.model_type == "baichuan":
elif model.config.model_type == "baichuan" and model.config.num_hidden_layers == 32:
# for Baichuan2-7B
from ipex_llm.transformers.npu_models.baichuan_mp import gen_baichuan_fused_model_forward
from ipex_llm.transformers.npu_models.baichuan_mp import DecodeRunner, PrefillRunner
decode_runner = DecodeRunner(