Update npu baichuan2 (#11939)
parent 7f7f6c89f5
commit 90f692937d
5 changed files with 35 additions and 18 deletions
@@ -82,6 +82,7 @@ The example below shows how to run the **_optimized model implementations_** on
 - [Llama3-8B](./llama.py)
 - [Qwen2-1.5B](./qwen2.py)
 - [MiniCPM-1B](./minicpm.py)
+- [Baichuan2-7B](./baichuan2.py)

 ```bash
 # to run Llama-2-7b-chat-hf
@@ -95,6 +96,9 @@ python qwen2.py

 # to run MiniCPM-1B-sft-bf16
 python minicpm.py
+
+# to run Baichuan2-7B-Chat
+python baichuan2.py
 ```

 Arguments info:
@@ -46,15 +46,15 @@ if __name__ == "__main__":
     parser.add_argument(
         "--repo-id-or-model-path",
         type=str,
-        default="meta-llama/Llama-2-7b-chat-hf",
-        help="The huggingface repo id for the Llama2 model to be downloaded"
+        default="baichuan-inc/Baichuan2-7B-Chat",
+        help="The huggingface repo id for the Baichuan2 model to be downloaded"
         ", or the path to the huggingface checkpoint folder",
     )
     parser.add_argument('--prompt', type=str, default="What is AI?",
                         help='Prompt to infer')
     parser.add_argument("--n-predict", type=int, default=32, help="Max tokens to predict")
     parser.add_argument("--max-output-len", type=int, default=1024)
-    parser.add_argument("--max-prompt-len", type=int, default=768)
+    parser.add_argument("--max-prompt-len", type=int, default=512)
     parser.add_argument("--disable-transpose-value-cache", action="store_true", default=False)
     parser.add_argument("--intra-pp", type=int, default=2)
     parser.add_argument("--inter-pp", type=int, default=2)
@@ -68,7 +68,7 @@ if __name__ == "__main__":
         trust_remote_code=True,
         attn_implementation="eager",
         load_in_low_bit="sym_int4",
-        enable_mp=True,
+        optimize_model=True,
         max_output_len=args.max_output_len,
         max_prompt_len=args.max_prompt_len,
         intra_pp=args.intra_pp,
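For context, a minimal sketch of how these arguments typically feed the NPU `from_pretrained` call and generation loop in this example family. The import path, the `torch_dtype` choice, and the `transpose_value_cache` keyword are assumptions inferred from the surrounding diff and the sibling NPU examples, not part of this commit.

```python
# Illustrative sketch only -- argument names follow the parser above;
# transpose_value_cache is assumed from the --disable-transpose-value-cache flag.
import torch
from transformers import AutoTokenizer
from ipex_llm.transformers.npu_model import AutoModelForCausalLM  # assumed import path

args = parser.parse_args()

model = AutoModelForCausalLM.from_pretrained(
    args.repo_id_or_model_path,
    torch_dtype=torch.float16,            # assumption, not shown in this diff
    trust_remote_code=True,
    attn_implementation="eager",
    load_in_low_bit="sym_int4",
    optimize_model=True,
    max_output_len=args.max_output_len,
    max_prompt_len=args.max_prompt_len,
    intra_pp=args.intra_pp,
    inter_pp=args.inter_pp,
    transpose_value_cache=not args.disable_transpose_value_cache,
)

tokenizer = AutoTokenizer.from_pretrained(args.repo_id_or_model_path, trust_remote_code=True)
input_ids = tokenizer.encode(args.prompt, return_tensors="pt")
output = model.generate(input_ids, max_new_tokens=args.n_predict)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```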
@@ -77,6 +77,8 @@ class _BaseAutoModelClass:
         :param load_in_low_bit: str value, options are ``'sym_int4'``, ``'sym_int8'``,
                                 ``'fp16'``, ``'fp32'``.
                                 Relevant low bit optimizations will be applied to the model.
+        :param optimize_model: boolean value, Whether to further optimize the low_bit llm model.
+                               Default to be ``False``.
         :return: a model instance
         """
         if kwargs.get("device_map", None) not in [None, "cpu", "auto"]:
@@ -272,7 +272,8 @@ class LowBitLlamaMultiDecoderlayer(NNFactory):
         input_2d = self.convert_to_fp16(input_2d)

         # attention
-        proj = self.linear(input_2d, 3 * self.hidden_size, self.hidden_size, bias=False, wt_dtype=self.dtype)
+        proj = self.linear(input_2d, 3 * self.hidden_size,
+                           self.hidden_size, bias=False, wt_dtype=self.dtype)
         # proj = proj.unflatten(-1, (3, self.hidden_size)).unsqueeze(0).transpose(0, -2).squeeze(-2)
         proj = self.reshape(proj, [-1, 3, self.hidden_size])  # b*s, 3, h
         proj = self.unsqueeze(proj, [0])  # b, s, 3, h
@@ -282,13 +283,16 @@ class LowBitLlamaMultiDecoderlayer(NNFactory):
         proj = self.unsqueeze(proj, [1])
         print("proj shape after unsqueeze", proj.shape)
         # query_states = proj[0].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-        query_states = self.reshape(proj[0, ...], [self.batch_size, self.seq_len, self.num_heads, self.head_dim])
+        query_states = self.reshape(proj[0, ...], [self.batch_size,
+                                                   self.seq_len, self.num_heads, self.head_dim])
         query_states = self.transpose(query_states, [0, 2, 1, 3])
         # key_states = proj[1].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-        key_states = self.reshape(proj[1, ...], [self.batch_size, self.seq_len, self.num_heads, self.head_dim])
+        key_states = self.reshape(proj[1, ...], [self.batch_size,
+                                                 self.seq_len, self.num_heads, self.head_dim])
         key_states = self.transpose(key_states, [0, 2, 1, 3])
         # value_states = proj[2].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-        value_states = self.reshape(proj[2, ...], [self.batch_size, self.seq_len, self.num_heads, self.head_dim])
+        value_states = self.reshape(proj[2, ...], [self.batch_size,
+                                                   self.seq_len, self.num_heads, self.head_dim])
         if self.transpose_value:
             value_states = self.transpose(value_states, [0, 2, 3, 1])
         else:
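As a reference for what the NNFactory calls above compute, here is an eager-mode PyTorch equivalent of splitting the fused QKV projection; this is an illustrative sketch under the same batch-size-1 assumption as the NPU decoder layer, not code from this commit.

```python
# Illustrative eager-mode counterpart of the fused QKV handling above (sketch only).
import torch

def split_fused_qkv(proj: torch.Tensor, num_heads: int, head_dim: int):
    # proj: [batch * seq, 3 * hidden] from a single fused linear layer
    bsz_seq, three_hidden = proj.shape
    hidden = three_hidden // 3
    proj = proj.view(1, -1, 3, hidden)                       # b, s, 3, h (assumes batch size 1)
    q, k, v = proj.unbind(dim=2)                             # each: b, s, h
    q = q.view(1, -1, num_heads, head_dim).transpose(1, 2)   # b, n_heads, s, head_dim
    k = k.view(1, -1, num_heads, head_dim).transpose(1, 2)
    v = v.view(1, -1, num_heads, head_dim).transpose(1, 2)
    return q, k, v
```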
@@ -309,7 +313,8 @@ class LowBitLlamaMultiDecoderlayer(NNFactory):
         else:
             value_states = self.concat(past_value, value_states, axis=-2)

-        attn_weight = self.matmul(query_states, key_states, False, True) / (math.sqrt(self.head_dim))
+        attn_weight = self.matmul(query_states, key_states, False, True) / (
+            math.sqrt(self.head_dim))
         attn_weight = self.eltwise_add(attn_weight, attention_mask)
         attn_weight = self.convert_to_fp32(attn_weight)
         attn_weight = self.softmax(attn_weight, -1)
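The same attention-score computation in plain PyTorch, for reference; an illustrative sketch only, since the real graph is built through the NNFactory ops above.

```python
# Illustrative sketch: eager-mode counterpart of the attention-weight ops above.
import math
import torch

def attention_weights(query_states, key_states, attention_mask, head_dim):
    # query_states: [b, n_heads, q_len, head_dim]; key_states: [b, n_heads, kv_len, head_dim]
    attn_weight = torch.matmul(query_states, key_states.transpose(-1, -2)) / math.sqrt(head_dim)
    attn_weight = attn_weight + attention_mask   # additive causal / padding mask
    attn_weight = attn_weight.to(torch.float32)  # softmax in fp32 for numerical stability
    return torch.softmax(attn_weight, dim=-1)
```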
@@ -349,7 +354,8 @@ class LowBitLlamaMultiDecoderlayer(NNFactory):
                           bias=False, wt_dtype=self.dtype)  # type: ignore[attr-defined]
         mm1 = self.eltwise_mul(self.swish(mm1), mm2)  # type: ignore[attr-defined]
         # down proj
-        hidden_states = self.linear(mm1, self.hidden_size, self.intermediate_size, bias=False, wt_dtype=self.dtype)
+        hidden_states = self.linear(mm1, self.hidden_size,
+                                    self.intermediate_size, bias=False, wt_dtype=self.dtype)

         hidden_states = self.eltwise_add(residual, hidden_states)
         hidden_states = self.convert_to_fp16(hidden_states)
@@ -374,7 +380,7 @@ class LowBitLlamaMultiDecoderlayer(NNFactory):
         cos = self.squeeze(cos)  # [seq_len, dim]
         sin = self.squeeze(sin)  # [seq_len, dim]
         # cos = cos[position_ids]
         cos = self.unsqueeze(cos, [0, 1])  # [bs, 1, seq_len, dim]
         # sin = sin[position_ids]
         sin = self.unsqueeze(sin, [0, 1])  # [bs, 1, seq_len, dim]
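The broadcast shapes noted above (`[bs, 1, seq_len, dim]`) feed the standard rotate-half rotary-embedding formulation; a generic sketch for reference, not the NNFactory implementation used in this file.

```python
# Generic rotate-half RoPE application (illustrative; cos/sin: [bs, 1, seq_len, dim]
# broadcast over the head dimension, matching the comments above).
import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin):
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
```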
@@ -1098,13 +1104,15 @@ def gen_baichuan_fused_model_forward(prefill_runner, decode_runner):

         # retrieve input_ids and inputs_embeds
         if input_ids is not None and inputs_embeds is not None:
-            invalidInputError(False, "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
+            invalidInputError(False, "You cannot specify both decoder_input_ids\
+                              and decoder_inputs_embeds at the same time")
         elif input_ids is not None:
             batch_size, seq_length = input_ids.shape
         elif inputs_embeds is not None:
             batch_size, seq_length, _ = inputs_embeds.shape
         else:
-            invalidInputError(False, "You have to specify either decoder_input_ids or decoder_inputs_embeds")
+            invalidInputError(False, "You have to specify either decoder_input_ids\
+                              or decoder_inputs_embeds")

         seq_length_with_past = seq_length
         past_key_values_length = 0
@@ -1120,7 +1128,8 @@ def gen_baichuan_fused_model_forward(prefill_runner, decode_runner):
         if position_ids is None:
             device = input_ids.device if input_ids is not None else inputs_embeds.device
             position_ids = torch.arange(
-                past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
+                past_key_values_length, seq_length + past_key_values_length,
+                dtype=torch.long, device=device
             )
             position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
         else:
@@ -1146,7 +1155,8 @@ def gen_baichuan_fused_model_forward(prefill_runner, decode_runner):
         if self.gradient_checkpointing and self.training:
             if use_cache:
                 logger.warning_once(
-                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+                    "`use_cache=True` is incompatible with gradient checkpointing.\
+                    Setting `use_cache=False`..."
                 )
                 use_cache = False
@@ -124,7 +124,8 @@ def optimize_llm(
             prefill_runner=prefill_runner, decode_runner=decode_runner
         )
         convert_forward(model, module.MiniCPMModel, minicpm_model_forward)
-    elif model.config.model_type == "baichuan":
+    elif model.config.model_type == "baichuan" and model.config.num_hidden_layers == 32:
+        # for Baichuan2-7B
         from ipex_llm.transformers.npu_models.baichuan_mp import gen_baichuan_fused_model_forward
         from ipex_llm.transformers.npu_models.baichuan_mp import DecodeRunner, PrefillRunner
         decode_runner = DecodeRunner(
@@ -141,8 +142,8 @@ def optimize_llm(
             transpose_value_cache=transpose_value_cache,
         )
         baichuan_model_forward = gen_baichuan_fused_model_forward(
             prefill_runner=prefill_runner, decode_runner=decode_runner
         )
         modeling_module_name = model.__class__.__module__
         module = importlib.import_module(modeling_module_name)
         convert_forward(model, module.BaichuanModel, baichuan_model_forward)
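For orientation, the dispatch above relies on a forward-replacement helper; below is a simplified stand-in showing the pattern, an illustrative sketch rather than the actual ipex_llm implementation of `convert_forward`.

```python
# Simplified stand-in for the forward-replacement helper used above (sketch only).
import types

def convert_forward(model, target_class, new_forward):
    # Rebind `forward` on every submodule that is an instance of target_class,
    # so the fused NPU forward replaces the original implementation in place.
    for module in model.modules():
        if isinstance(module, target_class):
            module.forward = types.MethodType(new_forward, module)
```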