Add support of llama3.2 for NPU C++ (#12442)
* initial support of llama3.2
* update
* update
* fix style
* fix style
* fix
* small fix
This commit is contained in:
parent cdd41f5e4c
commit 0e23bd779f

7 changed files with 133 additions and 27 deletions

@@ -10,6 +10,7 @@ In this directory, you will find a C++ example on how to run LLM models on Intel
 | Llama2 | [meta-llama/Llama-2-7b-chat-hf](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf) |
 | Llama3 | [meta-llama/Meta-Llama-3-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct) |
 | MiniCPM | [openbmb/MiniCPM-1B-sft-bf16](https://huggingface.co/openbmb/MiniCPM-1B-sft-bf16), [openbmb/MiniCPM-2B-sft-bf16](https://huggingface.co/openbmb/MiniCPM-2B-sft-bf16) |
+| Llama3.2 | [meta-llama/Llama-3.2-1B-Instruct](https://huggingface.co/meta-llama/Llama-3.2-1B-Instruct), [meta-llama/Llama-3.2-3B-Instruct](https://huggingface.co/meta-llama/Llama-3.2-3B-Instruct) |
 
 ## 0. Requirements
 To run this C++ example with IPEX-LLM on Intel NPUs, make sure to install the newest driver version of Intel NPU.

@@ -55,6 +56,12 @@ python convert.py --repo-id-or-model-path openbmb/MiniCPM-1B-sft-bf16 --save-di
 
 :: to convert MiniCPM-2B-sft-bf16
 python convert.py --repo-id-or-model-path openbmb/MiniCPM-2B-sft-bf16 --save-directory <converted_model_path>
+
+:: to convert Llama-3.2-1B-Instruct
+python convert.py --repo-id-or-model-path meta-llama/Llama-3.2-1B-Instruct --save-directory <converted_model_path>
+
+:: to convert Llama-3.2-3B-Instruct
+python convert.py --repo-id-or-model-path meta-llama/Llama-3.2-3B-Instruct --save-directory <converted_model_path>
 ```
 
 Arguments info:
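
For reference, the two new conversion commands can also be scripted. The sketch below simply shells out to convert.py with the flags documented above; the output directory layout is an assumption, not part of this commit:

```python
# Illustrative sketch only: batch-run the convert.py commands documented above
# for the newly supported Llama 3.2 checkpoints. The "./converted/..." layout
# is an assumption; the CLI flags are the ones shown in the README.
import subprocess

for repo_id in ("meta-llama/Llama-3.2-1B-Instruct", "meta-llama/Llama-3.2-3B-Instruct"):
    save_dir = f"./converted/{repo_id.split('/')[-1]}"
    subprocess.run(
        ["python", "convert.py",
         "--repo-id-or-model-path", repo_id,
         "--save-directory", save_dir],
        check=True,
    )
```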

@@ -18,8 +18,12 @@
 import torch
 import argparse
 from ipex_llm.transformers.npu_model import AutoModelForCausalLM
+import transformers
 from transformers import AutoTokenizer
 from transformers.utils import logging
+from packaging import version
+import os
+import shutil
 
 logger = logging.get_logger(__name__)

@@ -67,7 +71,14 @@ if __name__ == "__main__":
                                                  save_directory=save_dir)
 
     tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
-    tokenizer.save_pretrained(save_dir)
+
+    trans_version = transformers.__version__
+    if version.parse(trans_version) >= version.parse("4.45.0"):
+        tokenizer_json = os.path.join(model_path, "tokenizer.json")
+        dst_path = os.path.join(save_dir, "tokenizer.json")
+        shutil.copy(tokenizer_json, dst_path)
+    else:
+        tokenizer.save_pretrained(save_dir)
 
     print("-" * 80)
     print(f"finish save model to {save_dir}")

@@ -113,7 +113,7 @@ int main(int argc, char ** argv) {
     for (int i = 1; i < params.n_predict; i++){
         auto logits = run_decode(model, embd[i-1]);
         int32_t token = llm_sample_token(logits, true, model_params);
-        if (token != tok_params.eos_token_id) {
+        if (std::find(tok_params.eos_token_id.begin(), tok_params.eos_token_id.end(), token) == tok_params.eos_token_id.end()){
            embd.push_back(token);
            token_nums ++;
        } else {
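
The stop check in the C++ decode loop changes from comparing against a single eos_token_id to a membership test, since newer Llama checkpoints can declare several stop tokens. A hedged Python sketch of how such a list can be collected from a checkpoint's generation config (illustrative, not part of this commit):

```python
# Illustrative sketch: collect all EOS ids declared by a checkpoint's
# generation_config.json, normalising the scalar and list forms.
from transformers import GenerationConfig

def collect_eos_ids(model_path: str) -> list[int]:
    cfg = GenerationConfig.from_pretrained(model_path)
    ids = cfg.eos_token_id
    if ids is None:
        return []
    return ids if isinstance(ids, list) else [ids]
```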

@@ -46,7 +46,7 @@ if __name__ == "__main__":
                         help='Prompt to infer')
     parser.add_argument("--n-predict", type=int, default=32, help="Max tokens to predict")
     parser.add_argument("--max-context-len", type=int, default=1024)
-    parser.add_argument("--max-prompt-len", type=int, default=960)
+    parser.add_argument("--max-prompt-len", type=int, default=512)
     parser.add_argument("--quantization_group_size", type=int, default=0)
     parser.add_argument('--low_bit', type=str, default="sym_int4",
                         help='Low bit precision to quantize the model')

@@ -71,6 +71,7 @@ class LowBitLlamaMultiDecoderlayer(LLMBaseNNFactory):
         n_splits_down_proj: int = 1,
         group_size: int = 0,
         cos_len: int = 1,
+        keep_position_ids=True,
     ):
         super().__init__(max_seq_len=max_seq_len,
                          transpose_value=transpose_value,

@@ -122,7 +123,7 @@ class LowBitLlamaMultiDecoderlayer(LLMBaseNNFactory):
                                                        self.seq_len),
                                                       dtype=np.float16)
         if self.cached_cos is None:
-            if mode == "prefill":
+            if mode == "prefill" and keep_position_ids:
                 position_ids = self.create_input_op((self.batch_size, self.seq_len), dtype=np.int64)
             cos = self.create_input_op((self.batch_size, self.cos_len, self.head_dim),
                                        dtype=np.float32)

@@ -185,12 +186,12 @@ class LowBitLlamaMultiDecoderlayer(LLMBaseNNFactory):
         hidden_states = input
 
         curr_key_values = []
+        cos_condition = cached_cos is not None or (mode == "prefill" and keep_position_ids)
         for i in range(num_layers):
             hidden_states, new_key_states, new_value_states = self.build_decoder(
                 hidden_states=hidden_states,
                 attention_mask=attention_mask,
-                position_ids=position_ids if (cached_cos is not None
-                                              or mode == "prefill") else None,
+                position_ids=position_ids if cos_condition else None,
                 input_layernorm_weight=input_layernorm_weights[i],
                 post_attention_layernorm_weight=post_attn_layernorm_weights[i],
                 past_key=past_keys[i],
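
The new cos_condition makes the gating explicit: position_ids is wired into the decoder graph only when rotary cos/sin are cached inside the layer, or when building a prefill graph that keeps position_ids. A small illustrative restatement (not part of the commit):

```python
# Illustrative restatement of cos_condition above.
def needs_position_ids(cached_cos, mode: str, keep_position_ids: bool) -> bool:
    return cached_cos is not None or (mode == "prefill" and keep_position_ids)

# position_ids is only a graph input when cos/sin are cached in the layer,
# or for a prefill graph that keeps position_ids:
assert needs_position_ids(None, "decode", True) is False
assert needs_position_ids(None, "prefill", False) is False
assert needs_position_ids(None, "prefill", True) is True
assert needs_position_ids(object(), "decode", False) is True
```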

@@ -456,15 +456,27 @@ def convert_llm_for_deploy(model: torch.nn.Module,
                            group_size, layernorm_const, "prefill")
         # save blob of lmhead and bin of embedding
         convert_lm_head_and_embedding(model, n_splits_linear,
-                                      save_directory, weight_dir, True)
+                                      save_directory, weight_dir,
+                                      convert_model=True)
     elif model.config.model_type == "llama":
         layernorm_const = True
+        embedding_post = False
+        cos_sin_input = False
+        use_prefill_sdp = False
         if model.config.vocab_size == 32000:
             # for Llama2-7B
             fused_layers = 4
+            use_prefill_sdp = True
         else:
-            # for Llama3-8B
-            fused_layers = 2
+            if model.config.intermediate_size == 8192:
+                # llama3.2 1B & llama3.2 3B
+                embedding_post = True
+                cos_sin_input = True
+                fused_layers = 2
+            else:
+                # for Llama3-8B
+                fused_layers = 2
+                use_prefill_sdp = True
         update_dict = {"kv_len": kv_len,
                        "num_head": model.model.layers[0].self_attn.num_heads,
                        "head_dim": model.model.layers[0].self_attn.head_dim,

@@ -474,14 +486,21 @@ def convert_llm_for_deploy(model: torch.nn.Module,
                        "group_size":  group_size,
                        "fused_layers": fused_layers,
                        "qkv_bias": False,
-                       "use_prefill_sdp": True,
+                       "use_prefill_sdp": use_prefill_sdp,
                        "weight_num": 7,
-                       "weight_idx": 5}
+                       "weight_idx": 5,
+                       "embedding_post": embedding_post,
+                       "cos_sin_input": cos_sin_input}
         model.config.update(update_dict)
         model.config.save_pretrained(save_directory)
 
         from .llama import convert_llama_layer, convert_fused_llama_layer
         from .llama import convert_lm_head_and_embedding
+        # save blob of lmhead and bin of embedding & (optional) embedding_post
+        convert_lm_head_and_embedding(model, n_splits_linear,
+                                      save_directory, weight_dir,
+                                      convert_model=True,
+                                      max_prompt_len=max_prompt_len)
         # save fused_layers blobs of fused decoder layers
         convert_fused_llama_layer(model, fused_layers, n_splits_linear, n_splits_down_proj,
                                   save_directory, weight_dir, transpose_value_cache, kv_len,

@@ -490,9 +509,6 @@ def convert_llm_for_deploy(model: torch.nn.Module,
         convert_llama_layer(model, 0, n_splits_linear, n_splits_down_proj,
                             save_directory, weight_dir, transpose_value_cache, max_prompt_len,
                             group_size, layernorm_const, "prefill")
-        # save blob of lmhead and bin of embedding
-        convert_lm_head_and_embedding(model, n_splits_linear,
-                                      save_directory, weight_dir, True)
     elif model.config.model_type == "minicpm":
         layernorm_const = True
         fused_layers = 4

@@ -523,6 +539,8 @@ def convert_llm_for_deploy(model: torch.nn.Module,
         convert_minicpm_layer(model, 0, n_splits_linear, n_splits_down_proj,
                               save_directory, weight_dir, transpose_value_cache, max_prompt_len,
                               group_size, layernorm_const, "prefill")
-        # save blob of lmhead and bin of embedding
+        # save blob of lmhead and bin of embedding and embedding_post
         convert_lm_head_and_embedding(model, n_splits_linear,
-                                      save_directory, weight_dir, True, max_prompt_len)
+                                      save_directory, weight_dir,
+                                      convert_model=True,
+                                      max_prompt_len=max_prompt_len)

@@ -83,8 +83,46 @@ class Llama32Embedding(NNFactory):
         self.compile()
 
 
+class Llama32PostEmbedding(NNFactory):
+    def __init__(
+        self,
+        inv_freq,
+        attention_scaling,
+        input_len: int = 1,
+        device: str = "NPU",
+    ):
+        super().__init__(False, device)
+        self.attention_scaling = attention_scaling
+
+        # define input
+        position_ids = self.parameter((1, input_len), dtype=np.int64)
+        inv_freq = self.constant(inv_freq)
+
+        # rotary_emb module
+        inv_freq = self.reshape(inv_freq, (1, inv_freq.shape[0], 1))
+        position_ids = self.reshape(position_ids, (1, 1, input_len))
+        freqs = self.eltwise_mul(self.convert_to_fp32(inv_freq),
+                                 self.convert_to_fp32(position_ids))
+        freqs = self.transpose(freqs, [0, 2, 1])
+        emb = self.concat(freqs, freqs, axis=2)
+        cos = self.cos(emb)
+        sin = self.sin(emb)
+        cos = cos * self.attention_scaling
+        sin = sin * self.attention_scaling
+        if input_len > 1:
+            cos = self.unsqueeze(cos, [1])
+            sin = self.unsqueeze(sin, [1])
+
+        # define outputs
+        cos = self.convert_to_fp32(cos)
+        sin = self.convert_to_fp32(sin)
+
+        print("start compiling")
+        self.compile()
+
+
 def convert_lm_head_and_embedding(model, n_splits_linear, temp_dir, weight_dir,
-                                  convert_model=False):
+                                  convert_model=False, max_prompt_len=1):
     num_heads = model.model.layers[0].self_attn.num_heads
     num_key_value_heads = model.model.layers[0].self_attn.num_key_value_heads
     head_dim = model.model.layers[0].self_attn.head_dim
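
The new Llama32PostEmbedding graph computes the rotary cos/sin tables from inv_freq and position_ids on the NPU, applying the attention_scaling factor used by Llama 3.2's scaled RoPE. A NumPy reference of the same computation, mirroring the ops above (illustrative, not part of the commit):

```python
# NumPy reference of the cos/sin computation built by Llama32PostEmbedding above
# (illustrative only). Shapes follow the graph: inv_freq has head_dim // 2
# entries, position_ids is (1, input_len).
import numpy as np

def rope_cos_sin(inv_freq: np.ndarray, position_ids: np.ndarray, attention_scaling: float):
    inv_freq = inv_freq.astype(np.float32).reshape(1, -1, 1)   # (1, d/2, 1)
    pos = position_ids.astype(np.float32).reshape(1, 1, -1)    # (1, 1, L)
    freqs = np.transpose(inv_freq * pos, (0, 2, 1))            # (1, L, d/2)
    emb = np.concatenate([freqs, freqs], axis=2)               # (1, L, d)
    return np.cos(emb) * attention_scaling, np.sin(emb) * attention_scaling
```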

@@ -145,6 +183,13 @@ def convert_lm_head_and_embedding(model, n_splits_linear, temp_dir, weight_dir,
             padding_idx=model.config.pad_token_id,
             dtype=np.float16,
         )
+        if convert_model:
+            bin_file = os.path.join(weight_dir, f"model_embedding_input_0.bin")
+            embedding_layer.weight.to(torch.float16).detach().numpy().tofile(bin_file)
+            first_blob_path = None
+        else:
+            first_blob_path = update_names_of_IR_and_export_blob(new_embedding, "embedding",
+                                                                 temp_dir, True, False)
     else:
         # llama-3.2-3B & llama-3.2-1B
         embedding_layer = model.model.embed_tokens

@@ -157,13 +202,27 @@ def convert_lm_head_and_embedding(model, n_splits_linear, temp_dir, weight_dir,
             attention_scaling=model.model.rotary_emb.attention_scaling,
             dtype=np.float16,
         )
-    if convert_model:
-        bin_file = os.path.join(weight_dir, f"model_embedding_input_0.bin")
-        embedding_layer.weight.to(torch.float16).detach().numpy().tofile(bin_file)
-        first_blob_path = None
-    else:
-        first_blob_path = update_names_of_IR_and_export_blob(new_embedding, "embedding",
-                                                             temp_dir)
+        if convert_model:
+            bin_file = os.path.join(weight_dir, f"model_embedding_input_0.bin")
+            embedding_layer.weight.to(torch.float16).detach().numpy().tofile(bin_file)
+            first_blob_path = None
+            # save embedding post module
+            inv_freq = model.model.rotary_emb.inv_freq.to(torch.float16)
+            attention_scaling = model.model.rotary_emb.attention_scaling
+            embedding_post = Llama32PostEmbedding(inv_freq=inv_freq,
+                                                  attention_scaling=attention_scaling,
+                                                  input_len=1)
+            update_names_of_IR_and_export_blob(embedding_post, "embedding_post",
+                                               temp_dir, True, False)
+            embedding_post_prefill = Llama32PostEmbedding(inv_freq=inv_freq,
+                                                          attention_scaling=attention_scaling,
+                                                          input_len=max_prompt_len)
+            update_names_of_IR_and_export_blob(embedding_post_prefill,
+                                               "embedding_post_prefill",
+                                               temp_dir, True, False)
+        else:
+            first_blob_path = update_names_of_IR_and_export_blob(new_embedding, "embedding",
+                                                                 temp_dir)
     return first_blob_path, last_blob_path

@@ -212,10 +271,12 @@ def convert_llama_layer(model, layer_idx, n_splits_linear, n_splits_down_proj,
     if mode == "decode":
         input_len = 1
         decoder_name = f"decoder_layer_{layer_idx}"
+        keep_position_ids = True
     else:
         input_len = kv_len
         decoder_name = "decoder_layer_prefill"
         layernorm_const = False
+        keep_position_ids = False
 
     single_decoder = LowBitLlamaMultiDecoderlayer(
         [1, input_len, num_heads * head_dim],

@@ -234,7 +295,9 @@ def convert_llama_layer(model, layer_idx, n_splits_linear, n_splits_down_proj,
         dtype=np_dtype,
         n_splits_linear=n_splits_linear,
         n_splits_down_proj=n_splits_down_proj,
-        group_size=group_size
+        group_size=group_size,
+        cos_len=input_len,
+        keep_position_ids=keep_position_ids
     )
 
     rest_blob_path = update_names_of_IR_and_export_blob(single_decoder,

@@ -309,8 +372,14 @@ def convert_fused_llama_layer(model, fused_layers, n_splits_linear, n_splits_dow
                     scales.append(l.scale)
                 weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))
 
-            cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16)
-            cached_sin = curr_layer.self_attn.rotary_emb.sin_cached.to(torch.float16)
+            if hasattr(curr_layer.self_attn.rotary_emb, "cos_cached"):
+                # llama-2-7B & llama-3-8B
+                cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16)
+                cached_sin = curr_layer.self_attn.rotary_emb.sin_cached.to(torch.float16)
+            else:
+                # llama-3.2-3B & llama-3.2-1B
+                cached_cos = None
+                cached_sin = None
             layer_norm_0 = curr_layer.input_layernorm.weight.to(torch.float16)
             layer_norm_1 = curr_layer.post_attention_layernorm.weight.to(torch.float16)