Update npu baichuan2 (#11939)
commit 90f692937d
parent 7f7f6c89f5

5 changed files with 35 additions and 18 deletions

@@ -82,6 +82,7 @@ The example below shows how to run the **_optimized model implementations_** on
 - [Llama3-8B](./llama.py)
 - [Qwen2-1.5B](./qwen2.py)
 - [MiniCPM-1B](./minicpm.py)
+- [Baichuan2-7B](./baichuan2.py)
 
 ```bash
 # to run Llama-2-7b-chat-hf

@@ -95,6 +96,9 @@ python qwen2.py
 
 # to run MiniCPM-1B-sft-bf16
 python minicpm.py
+
+# to run Baichuan2-7B-Chat
+python baichuan2.py
 ```
 
 Arguments info:

@@ -46,15 +46,15 @@ if __name__ == "__main__":
     parser.add_argument(
         "--repo-id-or-model-path",
         type=str,
-        default="meta-llama/Llama-2-7b-chat-hf",
-        help="The huggingface repo id for the Llama2 model to be downloaded"
+        default="baichuan-inc/Baichuan2-7B-Chat",
+        help="The huggingface repo id for the Baichuan2 model to be downloaded"
         ", or the path to the huggingface checkpoint folder",
     )
     parser.add_argument('--prompt', type=str, default="What is AI?",
                         help='Prompt to infer')
     parser.add_argument("--n-predict", type=int, default=32, help="Max tokens to predict")
     parser.add_argument("--max-output-len", type=int, default=1024)
-    parser.add_argument("--max-prompt-len", type=int, default=768)
+    parser.add_argument("--max-prompt-len", type=int, default=512)
     parser.add_argument("--disable-transpose-value-cache", action="store_true", default=False)
     parser.add_argument("--intra-pp", type=int, default=2)
     parser.add_argument("--inter-pp", type=int, default=2)

@@ -68,7 +68,7 @@ if __name__ == "__main__":
         trust_remote_code=True,
         attn_implementation="eager",
         load_in_low_bit="sym_int4",
         enable_mp=True,
         optimize_model=True,
         max_output_len=args.max_output_len,
         max_prompt_len=args.max_prompt_len,
         intra_pp=args.intra_pp,
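
For orientation, the following is a rough, hypothetical sketch of how baichuan2.py may load the model with the settings shown in this hunk and run a single generation. The `ipex_llm.transformers.npu_model` import path and the tokenizer/generate usage are assumptions based on typical ipex-llm NPU examples, not lines from this commit; the keyword arguments mirror the example's command-line options.

```python
# Hypothetical usage sketch; the import path, tokenizer handling and generate()
# call are assumptions, only the from_pretrained keyword names come from the diff.
import torch
from transformers import AutoTokenizer
from ipex_llm.transformers.npu_model import AutoModelForCausalLM  # assumed import path

model_path = "baichuan-inc/Baichuan2-7B-Chat"
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    trust_remote_code=True,
    attn_implementation="eager",
    load_in_low_bit="sym_int4",
    enable_mp=True,
    optimize_model=True,
    max_output_len=1024,
    max_prompt_len=512,
    intra_pp=2,
    inter_pp=2,
)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

inputs = tokenizer("What is AI?", return_tensors="pt")
with torch.inference_mode():
    output = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```
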
@@ -77,6 +77,8 @@ class _BaseAutoModelClass:
         :param load_in_low_bit: str value, options are ``'sym_int4'``, ``'sym_int8'``,
                                 ``'fp16'``, ``'fp32'``.
                                 Relevant low bit optimizations will be applied to the model.
+        :param optimize_model: boolean value, Whether to further optimize the low_bit llm model.
+                               Default to be ``False``.
         :return: a model instance
         """
         if kwargs.get("device_map", None) not in [None, "cpu", "auto"]:

@@ -272,7 +272,8 @@ class LowBitLlamaMultiDecoderlayer(NNFactory):
         input_2d = self.convert_to_fp16(input_2d)
 
         # attention
-        proj = self.linear(input_2d, 3 * self.hidden_size, self.hidden_size, bias=False, wt_dtype=self.dtype)
+        proj = self.linear(input_2d, 3 * self.hidden_size,
+                           self.hidden_size, bias=False, wt_dtype=self.dtype)
         # proj = proj.unflatten(-1, (3, self.hidden_size)).unsqueeze(0).transpose(0, -2).squeeze(-2)
         proj = self.reshape(proj, [-1, 3, self.hidden_size])  # b*s, 3, h
         proj = self.unsqueeze(proj, [0])  # b, s, 3, h

@@ -282,13 +283,16 @@ class LowBitLlamaMultiDecoderlayer(NNFactory):
         proj = self.unsqueeze(proj, [1])
         print("proj shape after unsqueeze", proj.shape)
         # query_states = proj[0].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-        query_states = self.reshape(proj[0, ...], [self.batch_size, self.seq_len, self.num_heads, self.head_dim])
+        query_states = self.reshape(proj[0, ...], [self.batch_size,
+                                                   self.seq_len, self.num_heads, self.head_dim])
         query_states = self.transpose(query_states, [0, 2, 1, 3])
         # key_states = proj[1].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-        key_states = self.reshape(proj[1, ...], [self.batch_size, self.seq_len, self.num_heads, self.head_dim])
+        key_states = self.reshape(proj[1, ...], [self.batch_size,
+                                                 self.seq_len, self.num_heads, self.head_dim])
         key_states = self.transpose(key_states, [0, 2, 1, 3])
         # value_states = proj[2].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-        value_states = self.reshape(proj[2, ...], [self.batch_size, self.seq_len, self.num_heads, self.head_dim])
+        value_states = self.reshape(proj[2, ...], [self.batch_size,
+                                                   self.seq_len, self.num_heads, self.head_dim])
         if self.transpose_value:
             value_states = self.transpose(value_states, [0, 2, 3, 1])
         else:
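
As an aside, the reshape/transpose sequence in this hunk mirrors the usual eager-mode split of Baichuan's packed QKV projection. Below is a minimal PyTorch sketch of that pattern, with made-up shapes and tensor names; it is an illustration only, not the NNFactory graph API used here.

```python
# Illustrative eager-mode equivalent of splitting a fused QKV projection into
# per-head query/key/value tensors; all shapes below are example values.
import torch
import torch.nn as nn

batch_size, seq_len, num_heads, head_dim = 1, 8, 32, 128
hidden_size = num_heads * head_dim

hidden_states = torch.randn(batch_size, seq_len, hidden_size)
W_pack = nn.Linear(hidden_size, 3 * hidden_size, bias=False)  # fused q/k/v weights

proj = W_pack(hidden_states)      # (b, s, 3*h)
q, k, v = proj.chunk(3, dim=-1)   # three (b, s, h) tensors

# (b, s, h) -> (b, num_heads, s, head_dim), matching the reshape + transpose above
query_states = q.view(batch_size, seq_len, num_heads, head_dim).transpose(1, 2)
key_states = k.view(batch_size, seq_len, num_heads, head_dim).transpose(1, 2)
value_states = v.view(batch_size, seq_len, num_heads, head_dim).transpose(1, 2)
```
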
@@ -309,7 +313,8 @@ class LowBitLlamaMultiDecoderlayer(NNFactory):
         else:
             value_states = self.concat(past_value, value_states, axis=-2)
 
-        attn_weight = self.matmul(query_states, key_states, False, True) / (math.sqrt(self.head_dim))
+        attn_weight = self.matmul(query_states, key_states, False, True) / (
+            math.sqrt(self.head_dim))
         attn_weight = self.eltwise_add(attn_weight, attention_mask)
         attn_weight = self.convert_to_fp32(attn_weight)
         attn_weight = self.softmax(attn_weight, -1)
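
In eager PyTorch terms, the re-wrapped `matmul` line and the lines after it compute standard scaled dot-product attention with an additive mask and an fp32 softmax. A self-contained sketch with random tensors (illustrative only, not the graph ops):

```python
# Rough eager-mode counterpart of the attn_weight computation above.
import math
import torch

batch_size, num_heads, seq_len, head_dim = 1, 32, 8, 128
query_states = torch.randn(batch_size, num_heads, seq_len, head_dim)
key_states = torch.randn(batch_size, num_heads, seq_len, head_dim)
value_states = torch.randn(batch_size, num_heads, seq_len, head_dim)
attention_mask = torch.zeros(batch_size, 1, seq_len, seq_len)  # additive mask, 0 = visible

attn_weight = query_states @ key_states.transpose(-2, -1) / math.sqrt(head_dim)
attn_weight = attn_weight + attention_mask
attn_weight = torch.softmax(attn_weight.float(), dim=-1)          # softmax in fp32, as above
attn_output = attn_weight.to(value_states.dtype) @ value_states   # (b, heads, s, head_dim)
```
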
@@ -349,7 +354,8 @@ class LowBitLlamaMultiDecoderlayer(NNFactory):
                           bias=False, wt_dtype=self.dtype)  # type: ignore[attr-defined]
         mm1 = self.eltwise_mul(self.swish(mm1), mm2)  # type: ignore[attr-defined]
         # down proj
-        hidden_states = self.linear(mm1, self.hidden_size, self.intermediate_size, bias=False, wt_dtype=self.dtype)
+        hidden_states = self.linear(mm1, self.hidden_size,
+                                    self.intermediate_size, bias=False, wt_dtype=self.dtype)
 
         hidden_states = self.eltwise_add(residual, hidden_states)
         hidden_states = self.convert_to_fp16(hidden_states)
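
The surrounding lines are the standard gated MLP (SwiGLU) followed by the residual add: down_proj(silu(gate) * up) + residual. Below is a hedged eager-mode sketch with example layer names and Baichuan2-7B-like sizes; the layer names are illustrative, not identifiers from this file.

```python
# Illustrative SwiGLU MLP + residual, mirroring the swish/eltwise_mul/linear
# sequence above; layer names and sizes are examples only.
import torch
import torch.nn as nn
import torch.nn.functional as F

hidden_size, intermediate_size = 4096, 11008   # Baichuan2-7B-like dimensions
gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)

hidden_states = torch.randn(1, 8, hidden_size)
residual = hidden_states

mm1 = gate_proj(hidden_states)            # "mm1" before the activation
mm2 = up_proj(hidden_states)              # "mm2"
mm1 = F.silu(mm1) * mm2                   # swish(mm1) * mm2
hidden_states = down_proj(mm1)            # down proj
hidden_states = residual + hidden_states  # residual add
```
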
@@ -1098,13 +1104,15 @@ def gen_baichuan_fused_model_forward(prefill_runner, decode_runner):
 
         # retrieve input_ids and inputs_embeds
         if input_ids is not None and inputs_embeds is not None:
-            invalidInputError(False, "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
+            invalidInputError(False, "You cannot specify both decoder_input_ids\
+                              and decoder_inputs_embeds at the same time")
         elif input_ids is not None:
             batch_size, seq_length = input_ids.shape
         elif inputs_embeds is not None:
             batch_size, seq_length, _ = inputs_embeds.shape
         else:
-            invalidInputError(False, "You have to specify either decoder_input_ids or decoder_inputs_embeds")
+            invalidInputError(False, "You have to specify either decoder_input_ids\
+                              or decoder_inputs_embeds")
 
         seq_length_with_past = seq_length
         past_key_values_length = 0
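
This block follows the usual Hugging Face convention that `input_ids` and `inputs_embeds` are mutually exclusive. A minimal standalone sketch of the same check, using a plain `ValueError` in place of ipex-llm's `invalidInputError` helper (the function name below is made up for illustration):

```python
# Minimal sketch of the mutually-exclusive input check above; a plain ValueError
# stands in for ipex_llm's invalidInputError helper.
import torch

def resolve_batch_and_seq(input_ids=None, inputs_embeds=None):
    if input_ids is not None and inputs_embeds is not None:
        raise ValueError("You cannot specify both decoder_input_ids "
                         "and decoder_inputs_embeds at the same time")
    elif input_ids is not None:
        batch_size, seq_length = input_ids.shape
    elif inputs_embeds is not None:
        batch_size, seq_length, _ = inputs_embeds.shape
    else:
        raise ValueError("You have to specify either decoder_input_ids "
                         "or decoder_inputs_embeds")
    return batch_size, seq_length

print(resolve_batch_and_seq(input_ids=torch.ones(1, 16, dtype=torch.long)))  # (1, 16)
```
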
@@ -1120,7 +1128,8 @@ def gen_baichuan_fused_model_forward(prefill_runner, decode_runner):
         if position_ids is None:
             device = input_ids.device if input_ids is not None else inputs_embeds.device
             position_ids = torch.arange(
-                past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
+                past_key_values_length, seq_length + past_key_values_length,
+                dtype=torch.long, device=device
             )
             position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
         else:
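
The re-wrapped `torch.arange` call builds position ids that continue from the current KV-cache length. A small standalone example of what it produces (the lengths are example values):

```python
# Standalone example of the position_ids construction above.
import torch

past_key_values_length = 4   # tokens already held in the KV cache (example value)
seq_length = 3               # new tokens processed in this step (example value)
device = torch.device("cpu")

position_ids = torch.arange(
    past_key_values_length, seq_length + past_key_values_length,
    dtype=torch.long, device=device
)
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
print(position_ids)  # tensor([[4, 5, 6]])
```
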
@@ -1146,7 +1155,8 @@ def gen_baichuan_fused_model_forward(prefill_runner, decode_runner):
         if self.gradient_checkpointing and self.training:
             if use_cache:
                 logger.warning_once(
-                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+                    "`use_cache=True` is incompatible with gradient checkpointing.\
+                    Setting `use_cache=False`..."
                 )
                 use_cache = False
 

@@ -124,7 +124,8 @@ def optimize_llm(
             prefill_runner=prefill_runner, decode_runner=decode_runner
         )
         convert_forward(model, module.MiniCPMModel, minicpm_model_forward)
-    elif model.config.model_type == "baichuan":
+    elif model.config.model_type == "baichuan" and model.config.num_hidden_layers == 32:
+        # for Baichuan2-7B
         from ipex_llm.transformers.npu_models.baichuan_mp import gen_baichuan_fused_model_forward
         from ipex_llm.transformers.npu_models.baichuan_mp import DecodeRunner, PrefillRunner
         decode_runner = DecodeRunner(
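
The tightened `elif` routes only 32-layer Baichuan checkpoints to the fused pipeline, which corresponds to Baichuan2-7B (Baichuan2-13B has 40 hidden layers and is excluded). Below is a hedged sketch of checking that condition from a model config; the helper name is made up, and only the two config fields used by the `elif` are inspected.

```python
# Hypothetical helper illustrating the dispatch condition above.
from transformers import AutoConfig

def takes_fused_baichuan_path(model_path: str) -> bool:
    config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
    # Baichuan2-7B reports model_type == "baichuan" with 32 hidden layers;
    # Baichuan2-13B has 40 layers, so it would not take this path.
    return config.model_type == "baichuan" and config.num_hidden_layers == 32

print(takes_fused_baichuan_path("baichuan-inc/Baichuan2-7B-Chat"))  # expected: True
```
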