support Llama2-7B / Llama3-8B for NPU C++ (#12431)
* support llama2
* update
* support fused_layers=4 for Llama2-7B

parent 4ffa6c752c
commit 0819fad34e

4 changed files with 186 additions and 38 deletions

@@ -7,7 +7,8 @@ In this directory, you will find a C++ example on how to run LLM models on Intel
 |------------|----------------------------------------------------------------|
 | Qwen2 | [Qwen/Qwen2-7B-Instruct](https://huggingface.co/Qwen/Qwen2-7B-Instruct), [Qwen/Qwen2-1.5B-Instruct](https://huggingface.co/Qwen/Qwen2-1.5B-Instruct) |
 | Qwen2.5 | [Qwen/Qwen2.5-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-7B-Instruct) |
+| Llama2 | [meta-llama/Llama-2-7b-chat-hf](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf) |
+| Llama3 | [meta-llama/Meta-Llama-3-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct) |
 
 ## 0. Requirements
 To run this C++ example with IPEX-LLM on Intel NPUs, make sure to install the newest driver version of Intel NPU.

@@ -436,7 +436,11 @@ def convert_llm_for_deploy(model: torch.nn.Module,
                        "max_prompt_len": max_prompt_len,
                        "layernorm_const": layernorm_const,
                        "group_size":  group_size,
-                       "fused_layers": fused_layers}
+                       "fused_layers": fused_layers,
+                       "qkv_bias": True,
+                       "use_prefill_sdp": False,
+                       "weight_num": 7,
+                       "weight_idx": 8}
         model.config.update(update_dict)
         model.config.save_pretrained(save_directory)
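
These deployment parameters are merged into the Hugging Face config object and saved with model.config.save_pretrained(save_directory), so they end up as plain keys in config.json. A minimal sketch (a hypothetical helper, not part of this commit) of inspecting them after conversion:

    # --- illustrative sketch (not part of the commit) ---
    import json
    import os

    def read_npu_deploy_params(save_directory):
        """Read back the keys written via model.config.update(update_dict)."""
        with open(os.path.join(save_directory, "config.json")) as f:
            cfg = json.load(f)
        keys = ("kv_len", "max_prompt_len", "group_size", "fused_layers",
                "qkv_bias", "use_prefill_sdp", "weight_num", "weight_idx")
        return {k: cfg.get(k) for k in keys}
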
@@ -453,3 +457,39 @@ def convert_llm_for_deploy(model: torch.nn.Module,
         # save blob of lmhead and bin of embedding
         convert_lm_head_and_embedding(model, n_splits_linear,
                                       save_directory, weight_dir, True)
+    elif model.config.model_type == "llama":
+        layernorm_const = True
+        if model.config.vocab_size == 32000:
+            # for Llama2-7B
+            fused_layers = 4
+        else:
+            # for Llama3-8B
+            fused_layers = 2
+        update_dict = {"kv_len": kv_len,
+                       "num_head": model.model.layers[0].self_attn.num_heads,
+                       "head_dim": model.model.layers[0].self_attn.head_dim,
+                       "transpose_value_cache": transpose_value_cache,
+                       "max_prompt_len": max_prompt_len,
+                       "layernorm_const": layernorm_const,
+                       "group_size":  group_size,
+                       "fused_layers": fused_layers,
+                       "qkv_bias": False,
+                       "use_prefill_sdp": True,
+                       "weight_num": 7,
+                       "weight_idx": 5}
+        model.config.update(update_dict)
+        model.config.save_pretrained(save_directory)
+
+        from .llama import convert_llama_layer, convert_fused_llama_layer
+        from .llama import convert_lm_head_and_embedding
+        # save fused_layers blobs of fused decoder layers
+        convert_fused_llama_layer(model, fused_layers, n_splits_linear, n_splits_down_proj,
+                                  save_directory, weight_dir, transpose_value_cache, kv_len,
+                                  group_size, layernorm_const, "decode")
+        # save blob of single prefill layer
+        convert_llama_layer(model, 0, n_splits_linear, n_splits_down_proj,
+                            save_directory, weight_dir, transpose_value_cache, max_prompt_len,
+                            group_size, layernorm_const, "prefill")
+        # save blob of lmhead and bin of embedding
+        convert_lm_head_and_embedding(model, n_splits_linear,
+                                      save_directory, weight_dir, True)
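
Both Llama2-7B and Llama3-8B have 32 decoder layers (num_hidden_layers in their HF configs), so fused_layers = 4 yields four decode blobs of 8 layers each, while fused_layers = 2 yields two blobs of 16 layers. A small sketch reproducing the partitioning done by convert_fused_llama_layer further down in this diff:

    # --- illustrative sketch (not part of the commit) ---
    def fused_layer_ranges(layer_num, fused_layers):
        """Same split as convert_fused_llama_layer: consecutive groups of layers,
        each exported as one decoder_layer_{i} blob."""
        fused_layer_num = layer_num // fused_layers
        return [(i * fused_layer_num, min((i + 1) * fused_layer_num, layer_num))
                for i in range(fused_layers)]

    print(fused_layer_ranges(32, 4))  # Llama2-7B -> [(0, 8), (8, 16), (16, 24), (24, 32)]
    print(fused_layer_ranges(32, 2))  # Llama3-8B -> [(0, 16), (16, 32)]
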
@@ -83,7 +83,8 @@ class Llama32Embedding(NNFactory):
         self.compile()
 
 
-def convert_lm_head_and_embedding(model, n_splits_linear, temp_dir, weight_dir):
+def convert_lm_head_and_embedding(model, n_splits_linear, temp_dir, weight_dir,
+                                  convert_model=False):
     num_heads = model.model.layers[0].self_attn.num_heads
     num_key_value_heads = model.model.layers[0].self_attn.num_key_value_heads
     head_dim = model.model.layers[0].self_attn.head_dim

@@ -119,7 +120,8 @@ def convert_lm_head_and_embedding(model, n_splits_linear, temp_dir, weight_dir):
         vocab_size=vocab_size,
         n_splits=n_splits_linear
     )
-    last_blob_path = update_names_of_IR_and_export_blob(new_lm_head, "lm_head", temp_dir)
+    last_blob_path = update_names_of_IR_and_export_blob(new_lm_head, "lm_head", temp_dir,
+                                                        True, False)
 
     # save weights bins files
     if n_splits_linear == 1:

@@ -154,14 +156,19 @@ def convert_lm_head_and_embedding(model, n_splits_linear, temp_dir, weight_dir):
             attention_scaling=model.model.rotary_emb.attention_scaling,
             dtype=np.float16,
         )
-    first_blob_path = update_names_of_IR_and_export_blob(new_embedding, "embedding",
-                                                         temp_dir)
+    if convert_model:
+        bin_file = os.path.join(weight_dir, f"model_embedding_input_0.bin")
+        embedding_layer.weight.to(torch.float16).detach().numpy().tofile(bin_file)
+        first_blob_path = None
+    else:
+        first_blob_path = update_names_of_IR_and_export_blob(new_embedding, "embedding",
+                                                             temp_dir)
     return first_blob_path, last_blob_path
 
 
 def convert_llama_layer(model, layer_idx, n_splits_linear, n_splits_down_proj,
                         temp_dir, weight_dir, transpose_value_cache, kv_len, group_size,
-                        layernorm_const):
+                        layernorm_const, mode="decode"):
     num_heads = model.model.layers[0].self_attn.num_heads
     num_key_value_heads = model.model.layers[0].self_attn.num_key_value_heads
     head_dim = model.model.layers[0].self_attn.head_dim
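
With convert_model=True the embedding is no longer exported as a blob; its table is dumped straight to model_embedding_input_0.bin in float16 and first_blob_path is returned as None. A hedged sketch of reading such a dump back (shapes come from the model config; the helper name is made up):

    # --- illustrative sketch (not part of the commit) ---
    import os
    import numpy as np

    def load_embedding_table(weight_dir, vocab_size, hidden_size):
        """Reload the float16 table written by embedding_layer.weight...tofile()."""
        path = os.path.join(weight_dir, "model_embedding_input_0.bin")
        return np.fromfile(path, dtype=np.float16).reshape(vocab_size, hidden_size)
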
@@ -201,8 +208,16 @@ def convert_llama_layer(model, layer_idx, n_splits_linear, n_splits_down_proj,
     else:  # FP16 Linear
         np_dtype = np.float16
 
+    if mode == "decode":
+        input_len = 1
+        decoder_name = f"decoder_layer_{layer_idx}"
+    else:
+        input_len = kv_len
+        decoder_name = "decoder_layer_prefill"
+        layernorm_const = False
+
     single_decoder = LowBitLlamaMultiDecoderlayer(
-        [1, 1, num_heads * head_dim],
+        [1, input_len, num_heads * head_dim],
         input_layernorm_weights=[layer_norm_0] if layernorm_const else None,
         post_attn_layernorm_weights=[layer_norm_1] if layernorm_const else None,
         cached_cos=cached_cos,
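
The same decoder builder is now exported in two flavours: per-token decode blobs named by layer index, and a single decoder_layer_prefill blob whose input length is the whole prompt budget (the caller above passes max_prompt_len in the kv_len position for the prefill export). A minimal sketch of that switch:

    # --- illustrative sketch (not part of the commit) ---
    def blob_plan(mode, layer_idx, kv_len, layernorm_const):
        """Mirror of the mode handling in convert_llama_layer."""
        if mode == "decode":
            return 1, f"decoder_layer_{layer_idx}", layernorm_const
        # prefill: whole prompt in one shot, layernorm weights always read from .bin files
        return kv_len, "decoder_layer_prefill", False

    print(blob_plan("decode", 3, 1024, True))   # (1, 'decoder_layer_3', True)
    print(blob_plan("prefill", 0, 1024, True))  # (1024, 'decoder_layer_prefill', False)
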
@@ -213,40 +228,136 @@ def convert_llama_layer(model, layer_idx, n_splits_linear, n_splits_down_proj,
         max_seq_len=kv_len,
         rms_norm_eps=rms_norm_eps,
         intermediate_size=intermediate_size,
-        mode="decode",
+        mode=mode,
         transpose_value=transpose_value_cache,
         dtype=np_dtype,
         n_splits_linear=n_splits_linear,
         n_splits_down_proj=n_splits_down_proj,
         group_size=group_size
     )
-    rest_blob_path = update_names_of_IR_and_export_blob(single_decoder,
-                                                        f"decoder_layer_{layer_idx}",
-                                                        temp_dir)
-
-    if hasattr(curr_layer.self_attn.rotary_emb, "cos_cached"):
-        # llama-2-7B & llama-3-8B
-        if layernorm_const:
-            st_idx = 5
+    rest_blob_path = update_names_of_IR_and_export_blob(single_decoder,
+                                                        decoder_name,
+                                                        temp_dir,
+                                                        True, False)
+
+    if mode == "decode":
+        if hasattr(curr_layer.self_attn.rotary_emb, "cos_cached"):
+            # llama-2-7B & llama-3-8B
+            if layernorm_const:
+                st_idx = 5
+            else:
+                input_lm_bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_3.bin")
+                post_lm_bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_4.bin")
+                layer_norm_0.data.numpy().tofile(input_lm_bin_file)
+                layer_norm_1.data.numpy().tofile(post_lm_bin_file)
+                st_idx = 7
+        else:
+            # llama-3.2-3B & llama-3.2-1B
+            if layernorm_const:
+                st_idx = 6
+            else:
+                input_lm_bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_4.bin")
+                post_lm_bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_5.bin")
+                layer_norm_0.data.numpy().tofile(input_lm_bin_file)
+                layer_norm_1.data.numpy().tofile(post_lm_bin_file)
+                st_idx = 8
+        for idx, (weight, scale) in enumerate(weights):
+            bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_{st_idx+idx*2}.bin")
+            weight.numpy().tofile(bin_file)
+            bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_{st_idx+idx*2+1}.bin")
+            scale.numpy().tofile(bin_file)
+        del single_decoder
+
+
+def convert_fused_llama_layer(model, fused_layers, n_splits_linear, n_splits_down_proj,
+                              save_dir, weight_dir, transpose_value_cache, kv_len, group_size,
+                              layernorm_const, mode="decode"):
+    num_heads = model.model.layers[0].self_attn.num_heads
+    num_key_value_heads = model.model.layers[0].self_attn.num_key_value_heads
+    head_dim = model.model.layers[0].self_attn.head_dim
+    intermediate_size = model.config.intermediate_size
+    rms_norm_eps = model.config.rms_norm_eps
+    layer_num = len(model.model.layers)
+    fused_layer_num = layer_num // fused_layers
+
+    from ipex_llm.transformers.npu_models.llama_mp import LowBitLlamaMultiDecoderlayer
+    for i in range(fused_layers):
+        layer_start = i * fused_layer_num
+        layer_end = min((i + 1) * fused_layer_num, layer_num)
+        layer_weights = []
+        input_layer_norm_weights = []
+        post_attn_layernorm_weights = []
+        layer_indexs = range(layer_start, layer_end)
+        n_splits_linear = len(model.model.layers[0].mlp.gate_proj_dq_list)
+        n_splits_down_proj = len(model.model.layers[0].mlp.down_proj_dq_list)
+        for layer_idx in layer_indexs:
+            curr_layer = model.model.layers[layer_idx]
+            attn_layer = curr_layer.self_attn
+            mlp_layer = curr_layer.mlp
+
+            weights = []
+            for layer_list in [attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list,
+                               attn_layer.v_proj_dq_list, attn_layer.o_proj_dq_list,
+                               mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list,
+                               mlp_layer.down_proj_dq_list]:
+                l_weights = []
+                scales = []
+                for l in layer_list:
+                    l_weights.append(l.weight)
+                    scales.append(l.scale)
+                weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))
+
+            cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16)
+            cached_sin = curr_layer.self_attn.rotary_emb.sin_cached.to(torch.float16)
+            layer_norm_0 = curr_layer.input_layernorm.weight.to(torch.float16)
+            layer_norm_1 = curr_layer.post_attention_layernorm.weight.to(torch.float16)
+
+            layer_weights.extend(weights)
+            input_layer_norm_weights.append(layer_norm_0)
+            post_attn_layernorm_weights.append(layer_norm_1)
+
+            # save weight
-            input_lm_bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_3.bin")
-            post_lm_bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_4.bin")
-            layer_norm_0.data.numpy().tofile(input_lm_bin_file)
-            layer_norm_1.data.numpy().tofile(post_lm_bin_file)
-            st_idx = 7
-    else:
-        # llama-3.2-3B & llama-3.2-1B
-        if layernorm_const:
-            st_idx = 6
-        else:
-            input_lm_bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_4.bin")
-            post_lm_bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_5.bin")
-            layer_norm_0.data.numpy().tofile(input_lm_bin_file)
-            layer_norm_1.data.numpy().tofile(post_lm_bin_file)
-            st_idx = 8
-    for idx, (weight, scale) in enumerate(weights):
-        bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_{st_idx+idx*2}.bin")
-        weight.numpy().tofile(bin_file)
-        bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_{st_idx+idx*2+1}.bin")
-        scale.numpy().tofile(bin_file)
-    del single_decoder
+            input_lm_bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_3.bin")
+            post_lm_bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_4.bin")
+            layer_norm_0.data.numpy().tofile(input_lm_bin_file)
+            layer_norm_1.data.numpy().tofile(post_lm_bin_file)
+            st_idx = 5
+            # 6, 7 are past k/v
+            for idx, (weight, scale) in enumerate(weights):
+                bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_{st_idx+idx*2}.bin")
+                weight.numpy().tofile(bin_file)
+                bin_file = os.path.join(weight_dir,
+                                        f"model_{layer_idx}_input_{st_idx+idx*2+1}.bin")
+                scale.numpy().tofile(bin_file)
+
+        if isinstance(weights[0], tuple):
+            np_dtype = np.int8 if weights[0][0].dtype == torch.int8 else np.uint8
+        else:  # FP16 Linear
+            np_dtype = np.float16
+
+        fused_decoder = LowBitLlamaMultiDecoderlayer(
+            [1, 1, num_heads * head_dim],
+            input_layernorm_weights=input_layer_norm_weights,
+            post_attn_layernorm_weights=post_attn_layernorm_weights,
+            cached_cos=cached_cos,
+            cached_sin=cached_sin,
+            num_heads=num_heads,
+            num_key_value_heads=num_key_value_heads,
+            num_layers=fused_layer_num,
+            max_seq_len=kv_len,
+            rms_norm_eps=rms_norm_eps,
+            intermediate_size=intermediate_size,
+            mode=mode,
+            transpose_value=transpose_value_cache,
+            dtype=np_dtype,
+            n_splits_linear=n_splits_linear,
+            n_splits_down_proj=n_splits_down_proj,
+            group_size=group_size
+        )
+        update_names_of_IR_and_export_blob(fused_decoder,
+                                           f"decoder_layer_{i}",
+                                           save_dir,
+                                           compile_blob=True,
+                                           keep_ir=False)
+    return 0
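
For every decoder layer, seven projection weights (q/k/v/o plus gate/up/down, matching "weight_num": 7 in the exported config) are dumped as alternating weight/scale pairs starting at st_idx: model_{layer_idx}_input_{st_idx + 2*idx}.bin holds the quantized weight of projection idx and the next index holds its scale. A small sketch of the resulting file layout (the helper is illustrative only):

    # --- illustrative sketch (not part of the commit) ---
    PROJECTIONS = ["q_proj", "k_proj", "v_proj", "o_proj",
                   "gate_proj", "up_proj", "down_proj"]  # weight_num == 7

    def weight_bin_names(layer_idx, st_idx):
        """File names produced by the weight/scale dump loops above."""
        return {name: (f"model_{layer_idx}_input_{st_idx + 2 * i}.bin",      # weight
                       f"model_{layer_idx}_input_{st_idx + 2 * i + 1}.bin")  # scale
                for i, name in enumerate(PROJECTIONS)}

    # e.g. the fused Llama path above uses st_idx = 5 ("weight_idx": 5 in the config)
    print(weight_bin_names(0, 5)["q_proj"])  # ('model_0_input_5.bin', 'model_0_input_6.bin')
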
@@ -135,13 +135,9 @@ def convert_qwen_layer(model, layer_idx, n_splits_linear, n_splits_down_proj,
     if mode == "decode":
         input_len = 1
         decoder_name = f"decoder_layer_{layer_idx}"
-        compile = True
-        keep_ir = True
     else:
         input_len = kv_len
         decoder_name = "decoder_layer_prefill"
-        compile = True
-        keep_ir = False
     single_decoder = LowBitQwenMultiDecoderlayer(
         [1, input_len, num_heads * head_dim],
         input_layernorm_weights=None,

@@ -166,7 +162,7 @@ def convert_qwen_layer(model, layer_idx, n_splits_linear, n_splits_down_proj,
     )
     rest_blob_path = update_names_of_IR_and_export_blob(single_decoder,
                                                         decoder_name,
-                                                        temp_dir, compile, keep_ir)
+                                                        temp_dir, True, False)
 
     # 0, 1, 2 are input_embed/attention_mask/position_id
     if mode == "decode":
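
The Qwen path previously kept the intermediate IR for decode blobs (keep_ir = True) but not for prefill; it now always compiles the blob and drops the IR, matching the Llama export above. Assuming the two positional arguments map to compile_blob and keep_ir (the keyword form used in the fused-Llama call), the new call is equivalent to:

    # --- illustrative sketch (not part of the commit) ---
    rest_blob_path = update_names_of_IR_and_export_blob(single_decoder,
                                                        decoder_name,
                                                        temp_dir,
                                                        compile_blob=True,
                                                        keep_ir=False)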