Add support of llama3.2 for NPU C++ (#12442)
* initial support of llama3.2
* update
* update
* fix style
* fix style
* fix
* small fix
parent cdd41f5e4c
commit 0e23bd779f
7 changed files with 133 additions and 27 deletions
@@ -10,6 +10,7 @@ In this directory, you will find a C++ example on how to run LLM models on Intel
 | Llama2 | [meta-llama/Llama-2-7b-chat-hf](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf) |
 | Llama3 | [meta-llama/Meta-Llama-3-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct) |
 | MiniCPM | [openbmb/MiniCPM-1B-sft-bf16](https://huggingface.co/openbmb/MiniCPM-1B-sft-bf16), [openbmb/MiniCPM-2B-sft-bf16](https://huggingface.co/openbmb/MiniCPM-2B-sft-bf16) |
+| Llama3.2 | [meta-llama/Llama-3.2-1B-Instruct](https://huggingface.co/meta-llama/Llama-3.2-1B-Instruct), [meta-llama/Llama-3.2-3B-Instruct](https://huggingface.co/meta-llama/Llama-3.2-3B-Instruct) |

 ## 0. Requirements
 To run this C++ example with IPEX-LLM on Intel NPUs, make sure to install the newest driver version of Intel NPU.
@@ -55,6 +56,12 @@ python convert.py --repo-id-or-model-path openbmb/MiniCPM-1B-sft-bf16 --save-di

 :: to convert MiniCPM-2B-sft-bf16
 python convert.py --repo-id-or-model-path openbmb/MiniCPM-2B-sft-bf16 --save-directory <converted_model_path>
+
+:: to convert Llama-3.2-1B-Instruct
+python convert.py --repo-id-or-model-path meta-llama/Llama-3.2-1B-Instruct --save-directory <converted_model_path>
+
+:: to convert Llama-3.2-3B-Instruct
+python convert.py --repo-id-or-model-path meta-llama/Llama-3.2-3B-Instruct --save-directory <converted_model_path>
 ```

 Arguments info:
@@ -18,8 +18,12 @@
 import torch
 import argparse
 from ipex_llm.transformers.npu_model import AutoModelForCausalLM
+import transformers
 from transformers import AutoTokenizer
 from transformers.utils import logging
+from packaging import version
+import os
+import shutil

 logger = logging.get_logger(__name__)

@@ -67,7 +71,14 @@ if __name__ == "__main__":
                                              save_directory=save_dir)

     tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
-    tokenizer.save_pretrained(save_dir)
+
+    trans_version = transformers.__version__
+    if version.parse(trans_version) >= version.parse("4.45.0"):
+        tokenizer_json = os.path.join(model_path, "tokenizer.json")
+        dst_path = os.path.join(save_dir, "tokenizer.json")
+        shutil.copy(tokenizer_json, dst_path)
+    else:
+        tokenizer.save_pretrained(save_dir)

     print("-" * 80)
     print(f"finish save model to {save_dir}")
@@ -113,7 +113,7 @@ int main(int argc, char ** argv) {
         for (int i = 1; i < params.n_predict; i++){
             auto logits = run_decode(model, embd[i-1]);
             int32_t token = llm_sample_token(logits, true, model_params);
-            if (token != tok_params.eos_token_id) {
+            if (std::find(tok_params.eos_token_id.begin(), tok_params.eos_token_id.end(), token) == tok_params.eos_token_id.end()){
                 embd.push_back(token);
                 token_nums ++;
             } else {
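Llama 3.x checkpoints define more than one end-of-sequence token (for example `<|end_of_text|>` and `<|eot_id|>`), so the stop check above becomes a membership test over the `eos_token_id` list instead of a comparison against a single id. A minimal sketch of the same stop condition in Python, with hypothetical ids (the real values come from the converted model's config):

```python
# Hypothetical sketch: with Llama 3.x, eos_token_id is a list, so stopping is a membership test.
eos_token_ids = {128001, 128009}  # e.g. <|end_of_text|> / <|eot_id|>-style ids; read them from the config in practice

def keep_generating(token: int) -> bool:
    # Continue only while the sampled token is not one of the EOS ids.
    return token not in eos_token_ids

print(keep_generating(42))       # True
print(keep_generating(128009))   # False
```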
@@ -46,7 +46,7 @@ if __name__ == "__main__":
                         help='Prompt to infer')
     parser.add_argument("--n-predict", type=int, default=32, help="Max tokens to predict")
     parser.add_argument("--max-context-len", type=int, default=1024)
-    parser.add_argument("--max-prompt-len", type=int, default=960)
+    parser.add_argument("--max-prompt-len", type=int, default=512)
     parser.add_argument("--quantization_group_size", type=int, default=0)
     parser.add_argument('--low_bit', type=str, default="sym_int4",
                         help='Low bit precision to quantize the model')
@@ -71,6 +71,7 @@ class LowBitLlamaMultiDecoderlayer(LLMBaseNNFactory):
         n_splits_down_proj: int = 1,
         group_size: int = 0,
         cos_len: int = 1,
+        keep_position_ids=True,
     ):
         super().__init__(max_seq_len=max_seq_len,
                          transpose_value=transpose_value,
@@ -122,7 +123,7 @@ class LowBitLlamaMultiDecoderlayer(LLMBaseNNFactory):
                                                    self.seq_len),
                                                   dtype=np.float16)
         if self.cached_cos is None:
-            if mode == "prefill":
+            if mode == "prefill" and keep_position_ids:
                 position_ids = self.create_input_op((self.batch_size, self.seq_len), dtype=np.int64)
             cos = self.create_input_op((self.batch_size, self.cos_len, self.head_dim),
                                        dtype=np.float32)
@@ -185,12 +186,12 @@ class LowBitLlamaMultiDecoderlayer(LLMBaseNNFactory):
         hidden_states = input

         curr_key_values = []
+        cos_condition = cached_cos is not None or (mode == "prefill" and keep_position_ids)
         for i in range(num_layers):
             hidden_states, new_key_states, new_value_states = self.build_decoder(
                 hidden_states=hidden_states,
                 attention_mask=attention_mask,
-                position_ids=position_ids if (cached_cos is not None
-                                              or mode == "prefill") else None,
+                position_ids=position_ids if cos_condition else None,
                 input_layernorm_weight=input_layernorm_weights[i],
                 post_attention_layernorm_weight=post_attn_layernorm_weights[i],
                 past_key=past_keys[i],
@@ -456,15 +456,27 @@ def convert_llm_for_deploy(model: torch.nn.Module,
                               group_size, layernorm_const, "prefill")
         # save blob of lmhead and bin of embedding
         convert_lm_head_and_embedding(model, n_splits_linear,
-                                      save_directory, weight_dir, True)
+                                      save_directory, weight_dir,
+                                      convert_model=True)
     elif model.config.model_type == "llama":
         layernorm_const = True
+        embedding_post = False
+        cos_sin_input = False
+        use_prefill_sdp = False
         if model.config.vocab_size == 32000:
             # for Llama2-7B
             fused_layers = 4
+            use_prefill_sdp = True
         else:
-            # for Llama3-8B
-            fused_layers = 2
+            if model.config.intermediate_size == 8192:
+                # llama3.2 1B & llama3.2 3B
+                embedding_post = True
+                cos_sin_input = True
+                fused_layers = 2
+            else:
+                # for Llama3-8B
+                fused_layers = 2
+                use_prefill_sdp = True
         update_dict = {"kv_len": kv_len,
                        "num_head": model.model.layers[0].self_attn.num_heads,
                        "head_dim": model.model.layers[0].self_attn.head_dim,
@@ -474,14 +486,21 @@ def convert_llm_for_deploy(model: torch.nn.Module,
                        "group_size": group_size,
                        "fused_layers": fused_layers,
                        "qkv_bias": False,
-                       "use_prefill_sdp": True,
+                       "use_prefill_sdp": use_prefill_sdp,
                        "weight_num": 7,
-                       "weight_idx": 5}
+                       "weight_idx": 5,
+                       "embedding_post": embedding_post,
+                       "cos_sin_input": cos_sin_input}
         model.config.update(update_dict)
         model.config.save_pretrained(save_directory)

         from .llama import convert_llama_layer, convert_fused_llama_layer
         from .llama import convert_lm_head_and_embedding
+        # save blob of lmhead and bin of embedding & (optional) embedding_post
+        convert_lm_head_and_embedding(model, n_splits_linear,
+                                      save_directory, weight_dir,
+                                      convert_model=True,
+                                      max_prompt_len=max_prompt_len)
         # save fused_layers blobs of fused decoder layers
         convert_fused_llama_layer(model, fused_layers, n_splits_linear, n_splits_down_proj,
                                   save_directory, weight_dir, transpose_value_cache, kv_len,
@@ -490,9 +509,6 @@ def convert_llm_for_deploy(model: torch.nn.Module,
         convert_llama_layer(model, 0, n_splits_linear, n_splits_down_proj,
                             save_directory, weight_dir, transpose_value_cache, max_prompt_len,
                             group_size, layernorm_const, "prefill")
-        # save blob of lmhead and bin of embedding
-        convert_lm_head_and_embedding(model, n_splits_linear,
-                                      save_directory, weight_dir, True)
     elif model.config.model_type == "minicpm":
         layernorm_const = True
         fused_layers = 4
@@ -523,6 +539,8 @@ def convert_llm_for_deploy(model: torch.nn.Module,
         convert_minicpm_layer(model, 0, n_splits_linear, n_splits_down_proj,
                               save_directory, weight_dir, transpose_value_cache, max_prompt_len,
                               group_size, layernorm_const, "prefill")
-        # save blob of lmhead and bin of embedding
+        # save blob of lmhead and bin of embedding and embedding_post
         convert_lm_head_and_embedding(model, n_splits_linear,
-                                      save_directory, weight_dir, True, max_prompt_len)
+                                      save_directory, weight_dir,
+                                      convert_model=True,
+                                      max_prompt_len=max_prompt_len)
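In the llama branch above, the checkpoint variant is inferred purely from config fields: `vocab_size == 32000` selects the Llama2-7B path, `intermediate_size == 8192` selects the Llama3.2-1B/3B path (which additionally enables `embedding_post` and `cos_sin_input`), and anything else falls through to the Llama3-8B path. A compact sketch of that dispatch, assuming a transformers `LlamaConfig`-like object (illustrative only, not part of the commit):

```python
def llama_variant(config) -> str:
    """Summarize the config-based dispatch used by convert_llm_for_deploy."""
    if config.vocab_size == 32000:
        return "llama2-7b"        # fused_layers=4, use_prefill_sdp=True
    if config.intermediate_size == 8192:
        return "llama3.2-1b/3b"   # fused_layers=2, embedding_post=True, cos_sin_input=True
    return "llama3-8b"            # fused_layers=2, use_prefill_sdp=True
```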
@@ -83,8 +83,46 @@ class Llama32Embedding(NNFactory):
         self.compile()


+class Llama32PostEmbedding(NNFactory):
+    def __init__(
+        self,
+        inv_freq,
+        attention_scaling,
+        input_len: int = 1,
+        device: str = "NPU",
+    ):
+        super().__init__(False, device)
+        self.attention_scaling = attention_scaling
+
+        # define input
+        position_ids = self.parameter((1, input_len), dtype=np.int64)
+        inv_freq = self.constant(inv_freq)
+
+        # rotary_emb module
+        inv_freq = self.reshape(inv_freq, (1, inv_freq.shape[0], 1))
+        position_ids = self.reshape(position_ids, (1, 1, input_len))
+        freqs = self.eltwise_mul(self.convert_to_fp32(inv_freq),
+                                 self.convert_to_fp32(position_ids))
+        freqs = self.transpose(freqs, [0, 2, 1])
+        emb = self.concat(freqs, freqs, axis=2)
+        cos = self.cos(emb)
+        sin = self.sin(emb)
+        cos = cos * self.attention_scaling
+        sin = sin * self.attention_scaling
+        if input_len > 1:
+            cos = self.unsqueeze(cos, [1])
+            sin = self.unsqueeze(sin, [1])
+
+        # define outputs
+        cos = self.convert_to_fp32(cos)
+        sin = self.convert_to_fp32(sin)
+
+        print("start compiling")
+        self.compile()
+
+
 def convert_lm_head_and_embedding(model, n_splits_linear, temp_dir, weight_dir,
-                                  convert_model=False):
+                                  convert_model=False, max_prompt_len=1):
     num_heads = model.model.layers[0].self_attn.num_heads
     num_key_value_heads = model.model.layers[0].self_attn.num_key_value_heads
     head_dim = model.model.layers[0].self_attn.head_dim
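The `Llama32PostEmbedding` graph added above builds the rotary cos/sin tables on the NPU from the position ids, replacing the host-side `cached_cos`/`cached_sin` tensors used for Llama2/Llama3-8B. As a mental model, it is the standard Llama 3.2 rotary embedding; a small PyTorch reference sketch (not part of the commit, for sanity-checking only):

```python
import torch

def rotary_cos_sin(inv_freq: torch.Tensor, position_ids: torch.Tensor, attention_scaling: float):
    """Reference for what the embedding_post / embedding_post_prefill blobs compute."""
    # inv_freq: (head_dim // 2,), position_ids: (1, input_len)
    inv_freq = inv_freq.float().reshape(1, -1, 1)       # (1, head_dim//2, 1)
    pos = position_ids.float().reshape(1, 1, -1)        # (1, 1, input_len)
    freqs = (inv_freq * pos).transpose(1, 2)            # (1, input_len, head_dim//2)
    emb = torch.cat((freqs, freqs), dim=-1)             # (1, input_len, head_dim)
    cos = emb.cos() * attention_scaling
    sin = emb.sin() * attention_scaling
    # the NPU graph additionally inserts an extra axis when input_len > 1 (prefill case)
    return cos, sin
```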
@@ -145,6 +183,13 @@ def convert_lm_head_and_embedding(model, n_splits_linear, temp_dir, weight_dir,
             padding_idx=model.config.pad_token_id,
             dtype=np.float16,
         )
+        if convert_model:
+            bin_file = os.path.join(weight_dir, f"model_embedding_input_0.bin")
+            embedding_layer.weight.to(torch.float16).detach().numpy().tofile(bin_file)
+            first_blob_path = None
+        else:
+            first_blob_path = update_names_of_IR_and_export_blob(new_embedding, "embedding",
+                                                                 temp_dir, True, False)
     else:
         # llama-3.2-3B & llama-3.2-1B
         embedding_layer = model.model.embed_tokens
@@ -157,13 +202,27 @@ def convert_lm_head_and_embedding(model, n_splits_linear, temp_dir, weight_dir,
             attention_scaling=model.model.rotary_emb.attention_scaling,
             dtype=np.float16,
         )
-        if convert_model:
-            bin_file = os.path.join(weight_dir, f"model_embedding_input_0.bin")
-            embedding_layer.weight.to(torch.float16).detach().numpy().tofile(bin_file)
-            first_blob_path = None
-        else:
-            first_blob_path = update_names_of_IR_and_export_blob(new_embedding, "embedding",
-                                                                 temp_dir)
+        if convert_model:
+            bin_file = os.path.join(weight_dir, f"model_embedding_input_0.bin")
+            embedding_layer.weight.to(torch.float16).detach().numpy().tofile(bin_file)
+            first_blob_path = None
+            # save embedding post module
+            inv_freq = model.model.rotary_emb.inv_freq.to(torch.float16)
+            attention_scaling = model.model.rotary_emb.attention_scaling
+            embedding_post = Llama32PostEmbedding(inv_freq=inv_freq,
+                                                  attention_scaling=attention_scaling,
+                                                  input_len=1)
+            update_names_of_IR_and_export_blob(embedding_post, "embedding_post",
+                                               temp_dir, True, False)
+            embedding_post_prefill = Llama32PostEmbedding(inv_freq=inv_freq,
+                                                          attention_scaling=attention_scaling,
+                                                          input_len=max_prompt_len)
+            update_names_of_IR_and_export_blob(embedding_post_prefill,
+                                               "embedding_post_prefill",
+                                               temp_dir, True, False)
+        else:
+            first_blob_path = update_names_of_IR_and_export_blob(new_embedding, "embedding",
+                                                                 temp_dir)
     return first_blob_path, last_blob_path

@@ -212,10 +271,12 @@ def convert_llama_layer(model, layer_idx, n_splits_linear, n_splits_down_proj,
     if mode == "decode":
         input_len = 1
         decoder_name = f"decoder_layer_{layer_idx}"
+        keep_position_ids = True
     else:
         input_len = kv_len
         decoder_name = "decoder_layer_prefill"
         layernorm_const = False
+        keep_position_ids = False

     single_decoder = LowBitLlamaMultiDecoderlayer(
         [1, input_len, num_heads * head_dim],
@@ -234,7 +295,9 @@ def convert_llama_layer(model, layer_idx, n_splits_linear, n_splits_down_proj,
         dtype=np_dtype,
         n_splits_linear=n_splits_linear,
         n_splits_down_proj=n_splits_down_proj,
-        group_size=group_size
+        group_size=group_size,
+        cos_len=input_len,
+        keep_position_ids=keep_position_ids
     )

     rest_blob_path = update_names_of_IR_and_export_blob(single_decoder,
@@ -309,8 +372,14 @@ def convert_fused_llama_layer(model, fused_layers, n_splits_linear, n_splits_dow
             scales.append(l.scale)
             weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))

-        cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16)
-        cached_sin = curr_layer.self_attn.rotary_emb.sin_cached.to(torch.float16)
+        if hasattr(curr_layer.self_attn.rotary_emb, "cos_cached"):
+            # llama-2-7B & llama-3-8B
+            cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16)
+            cached_sin = curr_layer.self_attn.rotary_emb.sin_cached.to(torch.float16)
+        else:
+            # llama-3.2-3B & llama-3.2-1B
+            cached_cos = None
+            cached_sin = None
         layer_norm_0 = curr_layer.input_layernorm.weight.to(torch.float16)
         layer_norm_1 = curr_layer.post_attention_layernorm.weight.to(torch.float16)
