Add minicpm-2b in L0 pipeline (#12308)
parent b9853f98b3
commit eda764909c

4 changed files with 90 additions and 20 deletions
@@ -11,7 +11,7 @@ In this directory, you will find examples on how to directly run HuggingFace `tr
 | Qwen2 | [Qwen/Qwen2-1.5B-Instruct](https://huggingface.co/Qwen/Qwen2-1.5B-Instruct) |
 | Qwen2.5 | [Qwen/Qwen2.5-7b-Instruct](https://huggingface.co/Qwen/Qwen2.5-7b-Instruct) |
 | Baichuan2 | [baichuan-inc/Baichuan2-7B-Chat](https://huggingface.co/baichuan-inc/Baichuan-7B-Chat) |
-| MiniCPM | [openbmb/MiniCPM-1B-sft-bf16](https://huggingface.co/openbmb/MiniCPM-1B-sft-bf16) |
+| MiniCPM | [openbmb/MiniCPM-1B-sft-bf16](https://huggingface.co/openbmb/MiniCPM-1B-sft-bf16), [openbmb/MiniCPM-2B-sft-bf16](https://huggingface.co/openbmb/MiniCPM-2B-sft-bf16) |
 
 ## 0. Requirements
 To run these examples with IPEX-LLM on Intel NPUs, make sure to install the newest driver version of Intel NPU.
@@ -59,6 +59,9 @@ python baichuan2.py
 
 :: to run MiniCPM-1B-sft-bf16
 python minicpm.py
+
+:: to run MiniCPM-2B-sft-bf16
+python minicpm.py --repo-id-or-model-path "openbmb/MiniCPM-2B-sft-bf16"
 ```
 
 Arguments info:
@@ -32,7 +32,7 @@ if __name__ == "__main__":
     parser.add_argument(
         "--repo-id-or-model-path",
         type=str,
-        default="openbmb/MiniCPM-1B-sft-bf16",
+        default="openbmb/MiniCPM-1B-sft-bf16", # or "openbmb/MiniCPM-2B-sft-bf16"
         help="The huggingface repo id for the MiniCPM model to be downloaded"
         ", or the path to the huggingface checkpoint folder",
     )
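For context, a minimal sketch of how this argument is typically consumed; only the argparse setup is taken from the diff, and the surrounding variable names are illustrative rather than the example script's actual code:

```python
# Minimal sketch, assuming only the argparse setup shown in the diff;
# `model_path` and the rest of the script structure are illustrative.
import argparse

parser = argparse.ArgumentParser(description="Run MiniCPM with the IPEX-LLM NPU pipeline")
parser.add_argument(
    "--repo-id-or-model-path",
    type=str,
    default="openbmb/MiniCPM-1B-sft-bf16",  # or "openbmb/MiniCPM-2B-sft-bf16"
    help="The huggingface repo id for the MiniCPM model to be downloaded"
    ", or the path to the huggingface checkpoint folder",
)
args = parser.parse_args()
model_path = args.repo_id_or_model_path  # argparse maps dashes to underscores
```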
@@ -362,7 +362,7 @@ def convert_llm(model: torch.nn.Module,
         invalidInputError(False, "Now we only support Llama2 / Llama3 / Baichuan2 / "
                                  "Qwen2 / Qwen2.5 / Minicpm for pipeline running.")
 
-    if isinstance(model.lm_head, SlicedLMHead):
+    if hasattr(model, "lm_head") and isinstance(model.lm_head, SlicedLMHead):
         model.lm_head.get_fused_lm_head()
 
     # patch generate function
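The added `hasattr` guard appears to protect against converted models that no longer expose a top-level `lm_head` attribute; the MiniCPM-2B path further down reads `model.lm_head_0` / `model.lm_head_1` instead of `model.lm_head`, so the unguarded `isinstance` check could otherwise raise an `AttributeError`. A minimal sketch of the guarded pattern, where `sliced_lm_head_cls` stands in for the `SlicedLMHead` class referenced in the diff and the helper name is illustrative:

```python
# Minimal sketch of the guarded check above; everything except the
# getattr/isinstance pattern itself is illustrative.
def fuse_lm_head_if_present(model, sliced_lm_head_cls):
    lm_head = getattr(model, "lm_head", None)  # None when the attribute is absent
    if lm_head is not None and isinstance(lm_head, sliced_lm_head_cls):
        lm_head.get_fused_lm_head()
```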
@@ -18,8 +18,10 @@
 import torch
 import numpy as np
 import os
-from .common import update_names_of_IR_and_export_blob, LowBitLLMLMHead
+from .common import update_names_of_IR_and_export_blob
 from intel_npu_acceleration_library.backend.factory import NNFactory
+from ipex_llm.transformers.npu_models.mp_models_base import LLMBaseNNFactory
+from typing import Sequence
 
 
 class MiniCPMEmbedding(NNFactory):
@@ -65,6 +67,68 @@ class MiniCPMEmbedding(NNFactory):
         self.compile()
 
 
+class MiniCPMLMHead(LLMBaseNNFactory):
+    def __init__(
+        self,
+        hidden_shape: Sequence[int],
+        num_heads: int,
+        rms_norm_eps: float,
+        model_norm_weight,
+        vocab_size: int,
+        mode: str = "decode",
+        dtype: np.dtype = np.int8,
+        max_seq_len: int = 1024,
+        transpose_value: bool = False,
+        profile: bool = False,
+        device: str = "NPU",
+    ):
+        super().__init__(max_seq_len=max_seq_len,
+                         transpose_value=transpose_value,
+                         dtype=dtype,
+                         profile=profile,
+                         device=device)
+        self.max_seq_len = max_seq_len
+        self.dtype = dtype
+        self.batch_size, self.seq_len, self.hidden_size = hidden_shape
+        self.mode = mode
+        self.rms_norm_eps = rms_norm_eps
+        self.transpose_value = transpose_value
+        self.vocab_size = vocab_size
+
+        self.num_heads = num_heads
+        self.head_dim = self.hidden_size // self.num_heads
+
+        # define input, the order self.parameter matters
+        input = self.create_input_op((self.batch_size, self.seq_len, self.hidden_size))
+
+        hidden_states = input
+
+        # model norm and lm head
+        model_norm_weight = self.constant(model_norm_weight)
+        hidden_states = self.layer_norm(hidden_states, model_norm_weight)
+        if vocab_size == 122753:
+            # for MiniCPM-2B-sft-bf16
+            hidden_states_1 = self.linear(
+                hidden_states, 73440, self.hidden_size, bias=False, wt_dtype=self.dtype
+            )
+            hidden_states_2 = self.linear(
+                hidden_states, 73440, self.hidden_size, bias=False, wt_dtype=self.dtype
+            )
+            hidden_states_2 = self.slice(hidden_states_2, begin=[0, 0, 0], end=[1, 1, 49313])
+            hidden_states = self.concat(hidden_states_1, hidden_states_2, axis=2)
+        else:
+            # for MiniCPM-1B-sft-bf16
+            hidden_states = self.linear(
+                hidden_states, self.vocab_size, self.hidden_size, bias=False, wt_dtype=self.dtype
+            )
+
+        # define outputs
+        hidden_states = self.convert_to_fp32(hidden_states)
+
+        print("start compiling")
+        self.compile()
+
+
 def convert_lm_head_and_embedding(model, n_splits_linear, temp_dir, weight_dir):
     num_heads = model.model.layers[0].self_attn.num_heads
     num_key_value_heads = model.model.layers[0].self_attn.num_key_value_heads
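The 2B-specific branch above exists because the MiniCPM-2B vocabulary of 122753 entries is produced by two separate projections: both output 73440 logits, the second is sliced down to its first 49313 entries, and the two pieces are concatenated along the last axis (73440 + 49313 = 122753). A NumPy sketch of that shape bookkeeping follows; the hidden size and random weights are illustrative, only the split widths are taken from the diff:

```python
# Shape check only; hidden size and weights are illustrative,
# the 73440 / 49313 / 122753 widths are the ones used in the diff.
import numpy as np

hidden_size = 64  # illustrative, not the MiniCPM-2B value
hidden = np.random.rand(1, 1, hidden_size).astype(np.float32)
w0 = np.random.rand(73440, hidden_size).astype(np.float32)
w1 = np.random.rand(73440, hidden_size).astype(np.float32)

logits_0 = hidden @ w0.T                   # (1, 1, 73440)
logits_1 = (hidden @ w1.T)[..., :49313]    # keep only the first 49313 columns
logits = np.concatenate([logits_0, logits_1], axis=2)

assert logits.shape == (1, 1, 122753)      # 73440 + 49313 == 122753
```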
@@ -72,24 +136,23 @@ def convert_lm_head_and_embedding(model, n_splits_linear, temp_dir, weight_dir):
     rms_norm_eps = model.config.rms_norm_eps
     vocab_size = model.config.vocab_size
     model_norm = model.model.norm
-    lm_head = model.lm_head
     if n_splits_linear == 1:
-        weights = [(lm_head.weight, lm_head.scale)]
+        if vocab_size == 122753:
+            # for MiniCPM-2B-sft-bf16
+            weights = [(model.lm_head_0.weight, model.lm_head_0.scale),
+                       (model.lm_head_1.weight, model.lm_head_1.scale)]
+        else:
+            # for MiniCPM-1B-sft-bf16
+            weights = [(model.lm_head.weight, model.lm_head.scale)]
     else:
-        lm_heads = lm_head.lm_heads
-        lm_head_weights = []
-        scales = []
-        for i in range(n_splits_linear):
-            lm_head_weights.append(lm_heads[i].weight)
-            scales.append(lm_heads[i].scale)
-        weights = [(torch.stack(lm_head_weights, axis=0),
-                    torch.stack(scales, axis=0))]
+        # TODO
+        pass
     if isinstance(weights[0], tuple):
         np_dtype = np.int8 if weights[0][0].dtype == torch.int8 else np.uint8
     else:  # FP16 Linear
         np_dtype = np.float16
 
-    new_lm_head = LowBitLLMLMHead(
+    new_lm_head = MiniCPMLMHead(
         [1, 1, num_heads * head_dim],
         num_heads=num_heads,
         max_seq_len=1,
@@ -99,17 +162,21 @@ def convert_lm_head_and_embedding(model, n_splits_linear, temp_dir, weight_dir):
         dtype=np_dtype,
         model_norm_weight=model_norm.weight.to(torch.float16),
         vocab_size=vocab_size,
-        n_splits=n_splits_linear
     )
     last_blob_path = update_names_of_IR_and_export_blob(new_lm_head, "lm_head", temp_dir)
 
     # save weights bins files
     if n_splits_linear == 1:
-        weight_numpy = [
-            lm_head.weight.data.numpy(), lm_head.scale.data.numpy(),
-        ]
+        if vocab_size == 122753:
+            weight_numpy = [model.lm_head_0.weight.data.numpy(),
+                            model.lm_head_0.scale.data.numpy(),
+                            model.lm_head_1.weight.data.numpy(),
+                            model.lm_head_1.scale.data.numpy(), ]
+        else:
+            weight_numpy = [model.lm_head.weight.data.numpy(), model.lm_head.scale.data.numpy(), ]
     else:
-        weight_numpy = [v.numpy() for v in weights[0]]
+        # TODO
+        pass
 
     for idx, weight in enumerate(weight_numpy):
         bin_file = os.path.join(weight_dir, f"model_lm_head_input_{1+idx}.bin")
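With the 2B branch above, four arrays (weight and scale for each half of the split head) are exported instead of two, so the existing loop simply writes four binaries under the same `model_lm_head_input_{1+idx}.bin` naming scheme. A purely illustrative sketch of the resulting file names, where the directory name is hypothetical:

```python
# Illustrative only: reproduces the file-name pattern from the loop above
# for the MiniCPM-2B case (4 exported arrays); `weight_dir` is hypothetical.
import os

weight_dir = "converted_model_weights"
num_exported_arrays = 4  # lm_head_0 weight/scale + lm_head_1 weight/scale

for idx in range(num_exported_arrays):
    print(os.path.join(weight_dir, f"model_lm_head_input_{1 + idx}.bin"))
# model_lm_head_input_1.bin ... model_lm_head_input_4.bin
```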