Update npu baichuan2 (#11939)

This commit is contained in:
Zijie Li 2024-08-27 16:56:26 +08:00 committed by GitHub
parent 7f7f6c89f5
commit 90f692937d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 35 additions and 18 deletions

View file

@ -82,6 +82,7 @@ The example below shows how to run the **_optimized model implementations_** on
- [Llama3-8B](./llama.py) - [Llama3-8B](./llama.py)
- [Qwen2-1.5B](./qwen2.py) - [Qwen2-1.5B](./qwen2.py)
- [MiniCPM-1B](./minicpm.py) - [MiniCPM-1B](./minicpm.py)
- [Baichuan2-7B](./baichuan2.py)
```bash ```bash
# to run Llama-2-7b-chat-hf # to run Llama-2-7b-chat-hf
@ -95,6 +96,9 @@ python qwen2.py
# to run MiniCPM-1B-sft-bf16 # to run MiniCPM-1B-sft-bf16
python minicpm.py python minicpm.py
# to run Baichuan2-7B-Chat
python baichuan2.py
``` ```
Arguments info: Arguments info:

View file

@ -46,15 +46,15 @@ if __name__ == "__main__":
parser.add_argument( parser.add_argument(
"--repo-id-or-model-path", "--repo-id-or-model-path",
type=str, type=str,
default="meta-llama/Llama-2-7b-chat-hf", default="baichuan-inc/Baichuan2-7B-Chat",
help="The huggingface repo id for the Llama2 model to be downloaded" help="The huggingface repo id for the Baichuan2 model to be downloaded"
", or the path to the huggingface checkpoint folder", ", or the path to the huggingface checkpoint folder",
) )
parser.add_argument('--prompt', type=str, default="What is AI?", parser.add_argument('--prompt', type=str, default="What is AI?",
help='Prompt to infer') help='Prompt to infer')
parser.add_argument("--n-predict", type=int, default=32, help="Max tokens to predict") parser.add_argument("--n-predict", type=int, default=32, help="Max tokens to predict")
parser.add_argument("--max-output-len", type=int, default=1024) parser.add_argument("--max-output-len", type=int, default=1024)
parser.add_argument("--max-prompt-len", type=int, default=768) parser.add_argument("--max-prompt-len", type=int, default=512)
parser.add_argument("--disable-transpose-value-cache", action="store_true", default=False) parser.add_argument("--disable-transpose-value-cache", action="store_true", default=False)
parser.add_argument("--intra-pp", type=int, default=2) parser.add_argument("--intra-pp", type=int, default=2)
parser.add_argument("--inter-pp", type=int, default=2) parser.add_argument("--inter-pp", type=int, default=2)
@ -68,7 +68,7 @@ if __name__ == "__main__":
trust_remote_code=True, trust_remote_code=True,
attn_implementation="eager", attn_implementation="eager",
load_in_low_bit="sym_int4", load_in_low_bit="sym_int4",
enable_mp=True, optimize_model=True,
max_output_len=args.max_output_len, max_output_len=args.max_output_len,
max_prompt_len=args.max_prompt_len, max_prompt_len=args.max_prompt_len,
intra_pp=args.intra_pp, intra_pp=args.intra_pp,

View file

@ -77,6 +77,8 @@ class _BaseAutoModelClass:
:param load_in_low_bit: str value, options are ``'sym_int4'``, ``'sym_int8'``, :param load_in_low_bit: str value, options are ``'sym_int4'``, ``'sym_int8'``,
``'fp16'``, ``'fp32'``. ``'fp16'``, ``'fp32'``.
Relevant low bit optimizations will be applied to the model. Relevant low bit optimizations will be applied to the model.
:param optimize_model: boolean value, Whether to further optimize the low_bit llm model.
Default to be ``False``.
:return: a model instance :return: a model instance
""" """
if kwargs.get("device_map", None) not in [None, "cpu", "auto"]: if kwargs.get("device_map", None) not in [None, "cpu", "auto"]:

View file

@ -272,7 +272,8 @@ class LowBitLlamaMultiDecoderlayer(NNFactory):
input_2d = self.convert_to_fp16(input_2d) input_2d = self.convert_to_fp16(input_2d)
# attention # attention
proj = self.linear(input_2d, 3 * self.hidden_size, self.hidden_size, bias=False, wt_dtype=self.dtype) proj = self.linear(input_2d, 3 * self.hidden_size,
self.hidden_size, bias=False, wt_dtype=self.dtype)
# proj = proj.unflatten(-1, (3, self.hidden_size)).unsqueeze(0).transpose(0, -2).squeeze(-2) # proj = proj.unflatten(-1, (3, self.hidden_size)).unsqueeze(0).transpose(0, -2).squeeze(-2)
proj = self.reshape(proj, [-1, 3, self.hidden_size]) # b*s, 3, h proj = self.reshape(proj, [-1, 3, self.hidden_size]) # b*s, 3, h
proj = self.unsqueeze(proj, [0]) # b, s, 3, h proj = self.unsqueeze(proj, [0]) # b, s, 3, h
@ -282,13 +283,16 @@ class LowBitLlamaMultiDecoderlayer(NNFactory):
proj = self.unsqueeze(proj, [1]) proj = self.unsqueeze(proj, [1])
print("proj shape after unsqueeze", proj.shape) print("proj shape after unsqueeze", proj.shape)
# query_states = proj[0].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) # query_states = proj[0].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
query_states = self.reshape(proj[0, ...], [self.batch_size, self.seq_len, self.num_heads, self.head_dim]) query_states = self.reshape(proj[0, ...], [self.batch_size,
self.seq_len, self.num_heads, self.head_dim])
query_states = self.transpose(query_states, [0, 2, 1, 3]) query_states = self.transpose(query_states, [0, 2, 1, 3])
# key_states = proj[1].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) # key_states = proj[1].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = self.reshape(proj[1, ...], [self.batch_size, self.seq_len, self.num_heads, self.head_dim]) key_states = self.reshape(proj[1, ...], [self.batch_size,
self.seq_len, self.num_heads, self.head_dim])
key_states = self.transpose(key_states, [0, 2, 1, 3]) key_states = self.transpose(key_states, [0, 2, 1, 3])
# value_states = proj[2].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) # value_states = proj[2].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
value_states = self.reshape(proj[2, ...], [self.batch_size, self.seq_len, self.num_heads, self.head_dim]) value_states = self.reshape(proj[2, ...], [self.batch_size,
self.seq_len, self.num_heads, self.head_dim])
if self.transpose_value: if self.transpose_value:
value_states = self.transpose(value_states, [0, 2, 3, 1]) value_states = self.transpose(value_states, [0, 2, 3, 1])
else: else:
@ -309,7 +313,8 @@ class LowBitLlamaMultiDecoderlayer(NNFactory):
else: else:
value_states = self.concat(past_value, value_states, axis=-2) value_states = self.concat(past_value, value_states, axis=-2)
attn_weight = self.matmul(query_states, key_states, False, True) / (math.sqrt(self.head_dim)) attn_weight = self.matmul(query_states, key_states, False, True) / (
math.sqrt(self.head_dim))
attn_weight = self.eltwise_add(attn_weight, attention_mask) attn_weight = self.eltwise_add(attn_weight, attention_mask)
attn_weight = self.convert_to_fp32(attn_weight) attn_weight = self.convert_to_fp32(attn_weight)
attn_weight = self.softmax(attn_weight, -1) attn_weight = self.softmax(attn_weight, -1)
@ -349,7 +354,8 @@ class LowBitLlamaMultiDecoderlayer(NNFactory):
bias=False, wt_dtype=self.dtype) # type: ignore[attr-defined] bias=False, wt_dtype=self.dtype) # type: ignore[attr-defined]
mm1 = self.eltwise_mul(self.swish(mm1), mm2) # type: ignore[attr-defined] mm1 = self.eltwise_mul(self.swish(mm1), mm2) # type: ignore[attr-defined]
# down proj # down proj
hidden_states = self.linear(mm1, self.hidden_size, self.intermediate_size, bias=False, wt_dtype=self.dtype) hidden_states = self.linear(mm1, self.hidden_size,
self.intermediate_size, bias=False, wt_dtype=self.dtype)
hidden_states = self.eltwise_add(residual, hidden_states) hidden_states = self.eltwise_add(residual, hidden_states)
hidden_states = self.convert_to_fp16(hidden_states) hidden_states = self.convert_to_fp16(hidden_states)
@ -374,7 +380,7 @@ class LowBitLlamaMultiDecoderlayer(NNFactory):
cos = self.squeeze(cos) # [seq_len, dim] cos = self.squeeze(cos) # [seq_len, dim]
sin = self.squeeze(sin) # [seq_len, dim] sin = self.squeeze(sin) # [seq_len, dim]
# cos = cos[position_ids] # cos = cos[position_ids]
cos = self.unsqueeze(cos, [0, 1]) # [bs, 1, seq_len, dim] cos = self.unsqueeze(cos, [0, 1]) # [bs, 1, seq_len, dim]
# sin = sin[position_ids] # sin = sin[position_ids]
sin = self.unsqueeze(sin, [0, 1]) # [bs, 1, seq_len, dim] sin = self.unsqueeze(sin, [0, 1]) # [bs, 1, seq_len, dim]
@ -1098,13 +1104,15 @@ def gen_baichuan_fused_model_forward(prefill_runner, decode_runner):
# retrieve input_ids and inputs_embeds # retrieve input_ids and inputs_embeds
if input_ids is not None and inputs_embeds is not None: if input_ids is not None and inputs_embeds is not None:
invalidInputError(False, "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") invalidInputError(False, "You cannot specify both decoder_input_ids\
and decoder_inputs_embeds at the same time")
elif input_ids is not None: elif input_ids is not None:
batch_size, seq_length = input_ids.shape batch_size, seq_length = input_ids.shape
elif inputs_embeds is not None: elif inputs_embeds is not None:
batch_size, seq_length, _ = inputs_embeds.shape batch_size, seq_length, _ = inputs_embeds.shape
else: else:
invalidInputError(False, "You have to specify either decoder_input_ids or decoder_inputs_embeds") invalidInputError(False, "You have to specify either decoder_input_ids\
or decoder_inputs_embeds")
seq_length_with_past = seq_length seq_length_with_past = seq_length
past_key_values_length = 0 past_key_values_length = 0
@ -1120,7 +1128,8 @@ def gen_baichuan_fused_model_forward(prefill_runner, decode_runner):
if position_ids is None: if position_ids is None:
device = input_ids.device if input_ids is not None else inputs_embeds.device device = input_ids.device if input_ids is not None else inputs_embeds.device
position_ids = torch.arange( position_ids = torch.arange(
past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device past_key_values_length, seq_length + past_key_values_length,
dtype=torch.long, device=device
) )
position_ids = position_ids.unsqueeze(0).view(-1, seq_length) position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
else: else:
@ -1146,7 +1155,8 @@ def gen_baichuan_fused_model_forward(prefill_runner, decode_runner):
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
if use_cache: if use_cache:
logger.warning_once( logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." "`use_cache=True` is incompatible with gradient checkpointing.\
Setting `use_cache=False`..."
) )
use_cache = False use_cache = False

View file

@ -124,7 +124,8 @@ def optimize_llm(
prefill_runner=prefill_runner, decode_runner=decode_runner prefill_runner=prefill_runner, decode_runner=decode_runner
) )
convert_forward(model, module.MiniCPMModel, minicpm_model_forward) convert_forward(model, module.MiniCPMModel, minicpm_model_forward)
elif model.config.model_type == "baichuan": elif model.config.model_type == "baichuan" and model.config.num_hidden_layers == 32:
# for Baichuan2-7B
from ipex_llm.transformers.npu_models.baichuan_mp import gen_baichuan_fused_model_forward from ipex_llm.transformers.npu_models.baichuan_mp import gen_baichuan_fused_model_forward
from ipex_llm.transformers.npu_models.baichuan_mp import DecodeRunner, PrefillRunner from ipex_llm.transformers.npu_models.baichuan_mp import DecodeRunner, PrefillRunner
decode_runner = DecodeRunner( decode_runner = DecodeRunner(
@ -141,8 +142,8 @@ def optimize_llm(
transpose_value_cache=transpose_value_cache, transpose_value_cache=transpose_value_cache,
) )
baichuan_model_forward = gen_baichuan_fused_model_forward( baichuan_model_forward = gen_baichuan_fused_model_forward(
prefill_runner=prefill_runner, decode_runner=decode_runner prefill_runner=prefill_runner, decode_runner=decode_runner
) )
modeling_module_name = model.__class__.__module__ modeling_module_name = model.__class__.__module__
module = importlib.import_module(modeling_module_name) module = importlib.import_module(modeling_module_name)
convert_forward(model, module.BaichuanModel, baichuan_model_forward) convert_forward(model, module.BaichuanModel, baichuan_model_forward)