[NPU L0] Support llama3.2 in L0 pipeline (#12361)

parent 7ef7696956
commit 812d5cc32e

5 changed files with 135 additions and 29 deletions
```diff
@@ -8,6 +8,7 @@ In this directory, you will find examples on how to directly run HuggingFace `tr
 |------------|----------------------------------------------------------------|
 | Llama2 | [meta-llama/Llama-2-7b-chat-hf](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf) |
 | Llama3 | [meta-llama/Meta-Llama-3-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct) |
+| Llama3.2 | [meta-llama/Llama-3.2-1B-Instruct](https://huggingface.co/meta-llama/Llama-3.2-1B-Instruct), [meta-llama/Llama-3.2-3B-Instruct](https://huggingface.co/meta-llama/Llama-3.2-3B-Instruct) |
 | Qwen2 | [Qwen/Qwen2-1.5B-Instruct](https://huggingface.co/Qwen/Qwen2-1.5B-Instruct) |
 | Qwen2.5 | [Qwen/Qwen2.5-7b-Instruct](https://huggingface.co/Qwen/Qwen2.5-7b-Instruct) |
 | Baichuan2 | [baichuan-inc/Baichuan2-7B-Chat](https://huggingface.co/baichuan-inc/Baichuan-7B-Chat) |
```
````diff
@@ -28,6 +29,9 @@ conda activate llm
 
 :: install ipex-llm with 'npu' option
 pip install --pre --upgrade ipex-llm[npu]
+
+:: [optional] for Llama-3.2-1B-Instruct & Llama-3.2-3B-Instruct
+pip install transformers==4.45.0 accelerate==0.33.0
 ```
 
 ## 2. Runtime Configurations
````
```diff
@@ -48,6 +52,12 @@ python llama2.py
 :: to run Meta-Llama-3-8B-Instruct
 python llama3.py
 
+:: to run Llama-3.2-1B-Instruct
+python llama3.py --repo-id-or-model-path "meta-llama/Llama-3.2-1B-Instruct"
+
+:: to run Llama-3.2-3B-Instruct
+python llama3.py --repo-id-or-model-path "meta-llama/Llama-3.2-3B-Instruct"
+
 :: to run Qwen2.5-7b-Instruct
 python qwen.py
 
```
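The `llama3.py` script referenced above is not part of this diff. As a rough sketch of how such an example would load a Llama-3.2 checkpoint through ipex-llm's NPU `AutoModelForCausalLM`, assuming it follows the usual ipex-llm NPU example pattern (the keyword arguments below, such as `pipeline` and `max_context_len`, are illustrative assumptions rather than values taken from this PR):

```python
# Hypothetical sketch of how an example script such as llama3.py might load
# Llama-3.2 on the NPU with ipex-llm. Keyword arguments are assumptions for
# illustration, not taken from this PR.
import argparse

import torch
from ipex_llm.transformers.npu_model import AutoModelForCausalLM
from transformers import AutoTokenizer

parser = argparse.ArgumentParser()
parser.add_argument("--repo-id-or-model-path", type=str,
                    default="meta-llama/Llama-3.2-1B-Instruct")
args = parser.parse_args()

# load_in_low_bit / optimize_model / pipeline / max_context_len are assumed
# to mirror the usual ipex-llm NPU examples.
model = AutoModelForCausalLM.from_pretrained(
    args.repo_id_or_model_path,
    torch_dtype=torch.float16,
    load_in_low_bit="sym_int4",
    optimize_model=True,
    pipeline=True,          # use the L0 pipeline path
    max_context_len=1024,
)
tokenizer = AutoTokenizer.from_pretrained(args.repo_id_or_model_path)

inputs = tokenizer("What is AI?", return_tensors="pt")
output = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```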
```diff
@@ -124,11 +124,12 @@ class LowBitLlamaMultiDecoderlayer(LLMBaseNNFactory):
         if self.cached_cos is None:
             if mode == "prefill":
                 position_ids = self.create_input_op((self.batch_size, self.seq_len), dtype=np.int64)
-                self.cos = self.create_input_op((self.batch_size, self.cos_len, self.head_dim))
-                self.sin = self.create_input_op((self.batch_size, self.cos_len, self.head_dim))
-            else:
-                self.cos = self.create_input_op((self.batch_size, self.cos_len, self.head_dim))
-                self.sin = self.create_input_op((self.batch_size, self.cos_len, self.head_dim))
+            cos = self.create_input_op((self.batch_size, self.cos_len, self.head_dim),
+                                       dtype=np.float32)
+            self.cos = self.convert_to_fp16(cos)
+            sin = self.create_input_op((self.batch_size, self.cos_len, self.head_dim),
+                                       dtype=np.float32)
+            self.sin = self.convert_to_fp16(sin)
         else:
             position_ids = self.create_input_op((self.batch_size, self.seq_len), dtype=np.int64)
             cos = self.constant(self.cached_cos)
```
```diff
@@ -367,7 +368,7 @@ class FusedLlamaLowBitMultiDecoderlayer(torch.nn.Module):
         )
 
         if self.cached_cos is None:
-            inputs += (cos.to(torch.float16), sin.to(torch.float16))
+            inputs += (cos.to(torch.float32), sin.to(torch.float32))
         else:
             inputs += (position_ids.to(torch.int64),)
 
```
```diff
@@ -496,7 +497,7 @@ class FusedLlamaLowBitDecoderlayer(torch.nn.Module):
                       attention_mask.to(torch.int64),
                       position_ids.to(torch.int64))
         if self.cached_cos is None:
-            inputs += (cos.to(torch.float16), sin.to(torch.float16),)
+            inputs += (cos.to(torch.float32), sin.to(torch.float32),)
         inputs += (self.layer_norm_0, self.layer_norm_1)
         hidden_states, past_key, past_value = run_model(
             inputs, self.op_parameters, backend_cls, self.op_id, replica=2
```
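The hunks above work together: when `cached_cos is None` (the Llama 3.2 path), cos and sin are no longer baked into the graph as constants but arrive as runtime inputs, declared as `np.float32` and converted to fp16 inside the graph, so the host-side decoder layers now hand them over in float32 instead of float16. A minimal sketch of that handoff, using plain PyTorch/NumPy and hypothetical variable names:

```python
# Illustrative sketch of the cos/sin dtype handoff after this change.
# rotary_cos / rotary_sin stand in for the values produced at runtime
# (e.g. by the Llama 3.2 rotary embedding); names are hypothetical.
import numpy as np
import torch

rotary_cos = torch.randn(1, 1, 64)   # placeholder for per-step cos
rotary_sin = torch.randn(1, 1, 64)   # placeholder for per-step sin

# Host side: feed float32 (was float16 before this PR).
inputs = (rotary_cos.to(torch.float32), rotary_sin.to(torch.float32))

# Graph side: the input op is declared as np.float32 and then converted to
# fp16 inside the graph (self.convert_to_fp16(cos) in the hunk above),
# which is roughly equivalent to:
cos_np = inputs[0].numpy()            # float32 buffer handed to the NPU graph
cos_fp16 = cos_np.astype(np.float16)  # the conversion now happens in-graph
print(cos_np.dtype, cos_fp16.dtype)   # float32 float16
```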
```diff
@@ -54,8 +54,7 @@ def run_model(
 
     # Reshape input
     input_dtype = x[0].dtype
-    x_np = [set_contiguous(elem).numpy() if elem.dtype == torch.int64 else
-            set_contiguous(elem).to(torch.float16).numpy() for elem in x]
+    x_np = [set_contiguous(elem).numpy() for elem in x]
     op_args = []
     op_args_flatten = []
     for w in weights:
```
```diff
@@ -651,8 +650,7 @@ class LLMBaseNNFactory(NNFactory):
 
     @staticmethod
     def run_decoders(inputs, decoders, models_ptr=None):
-        x_np = [elem.numpy() if elem.dtype == torch.int64 else
-                elem.to(torch.float16).numpy() for elem in inputs]
+        x_np = [elem.numpy() for elem in inputs]
 
         num_decoders = len(decoders)
         num_inputs = len(x_np)
```
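Previously `run_model` and `run_decoders` cast every non-int64 tensor to float16 before converting to NumPy; with cos/sin now arriving in float32, that blanket cast would silently downcast them, so the conversion is dropped and each input keeps its own dtype. A standalone illustration of the before/after behaviour (with a stand-in for `set_contiguous`):

```python
# Stand-alone illustration of the run_model change: inputs keep their dtypes.
import torch

def set_contiguous(t):                 # stand-in for the helper used in run_model
    return t.contiguous()

x = [
    torch.zeros(1, 1, dtype=torch.int64),           # position ids
    torch.zeros(1, 1, 64, dtype=torch.float32),     # cos (Llama 3.2 path)
    torch.zeros(1, 1, 2048, dtype=torch.float16),   # hidden states
]

# Before: everything except int64 was cast to float16.
old = [set_contiguous(e).numpy() if e.dtype == torch.int64 else
       set_contiguous(e).to(torch.float16).numpy() for e in x]

# After: dtypes are preserved, so float32 cos/sin reach the graph intact.
new = [set_contiguous(e).numpy() for e in x]

print([a.dtype for a in old])  # int64, float16, float16
print([a.dtype for a in new])  # int64, float32, float16
```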
```diff
@@ -233,8 +233,12 @@ def convert_llm(model: torch.nn.Module,
             model.num_layers = layer_num
             model.transpose_value_cache = transpose_value_cache
 
+            if hasattr(model.model.layers[0].self_attn.rotary_emb, "cos_cached"):
+                model_type = "llama"
+            else:
+                model_type = "llama_32"
             try:
-                res = InitLLMPipeline("llama", kv_len, model.num_head, model.head_dim, layer_num,
+                res = InitLLMPipeline(model_type, kv_len, model.num_head, model.head_dim, layer_num,
                                       model.vocab_size, weight_dir, "model",
                                       first_blob_path, last_blob_path,
                                       os.path.join(temp_dir, "decoder_layer"), layernorm_const)
```
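The `hasattr(..., "cos_cached")` probe distinguishes the two rotary-embedding layouts: for the Llama 2 / Llama 3 checkpoints used so far, each layer's `rotary_emb` still carries `cos_cached`/`sin_cached` buffers, while the transformers 4.45 models required for Llama 3.2 compute cos/sin at runtime from `inv_freq` and `attention_scaling`, so no cached attribute exists. A hypothetical helper (not part of this PR) wrapping the same check:

```python
# Hypothetical helper mirroring the hasattr check introduced in this PR.
def detect_llama_pipeline_type(model) -> str:
    """Return the pipeline model type string passed to InitLLMPipeline."""
    rotary_emb = model.model.layers[0].self_attn.rotary_emb
    if hasattr(rotary_emb, "cos_cached"):
        # Llama-2-7B / Meta-Llama-3-8B: cos/sin are precomputed buffers.
        return "llama"
    # Llama-3.2-1B / Llama-3.2-3B: cos/sin are computed at runtime.
    return "llama_32"
```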
```diff
@@ -19,6 +19,68 @@ import torch
 import numpy as np
 import os
 from .common import update_names_of_IR_and_export_blob, LLMEmbedding, LowBitLLMLMHead
+from intel_npu_acceleration_library.backend.factory import NNFactory
+
+
+class Llama32Embedding(NNFactory):
+    def __init__(
+        self,
+        vocab_size,
+        embedding_dim,
+        embedding_weight,
+        padding_idx,
+        inv_freq,
+        attention_scaling,
+        dtype,  # fp16
+        device: str = "NPU",
+    ):
+        super().__init__(False, device)
+        self.vocab_size = vocab_size
+        self.embedding_dim = embedding_dim
+        self.padding_idx = padding_idx
+        self.attention_scaling = attention_scaling
+        self.dtype = dtype
+
+        # define input
+        weight = self.constant(embedding_weight)
+        input = self.parameter((1, 1), dtype=np.int32)
+        position_ids = self.parameter((1, 1), dtype=np.int64)
+        inv_freq = self.constant(inv_freq)
+
+        # embed_tokens module
+        if padding_idx == -1:
+            padding_idx += vocab_size
+
+        axis_node = self.constant(np.array([0], dtype=np.int64))
+        if padding_idx is not None:
+            masked_embeddings = np.ones(weight.shape, dtype=np.float16)
+            masked_embeddings[padding_idx, :] = 0.0  # mask
+
+            node_mask = self.constant(masked_embeddings)
+            node_masked_w = self.eltwise_mul(weight, node_mask)
+            res = self.gather(node_masked_w, input, axis_node, 0)
+        else:
+            res = self.gather(weight, input, axis_node, 0)
+
+        # rotary_emb module
+        inv_freq = self.reshape(inv_freq, (1, inv_freq.shape[0], 1))
+        position_ids = self.reshape(position_ids, (1, 1, 1))
+        freqs = self.eltwise_mul(self.convert_to_fp32(inv_freq),
+                                 self.convert_to_fp32(position_ids))
+        freqs = self.transpose(freqs, [0, 2, 1])
+        emb = self.concat(freqs, freqs, axis=2)
+        cos = self.cos(emb)
+        sin = self.sin(emb)
+        cos = cos * self.attention_scaling
+        sin = sin * self.attention_scaling
+
+        # define outputs
+        res = self.convert_to_fp16(res)
+        cos = self.convert_to_fp32(cos)
+        sin = self.convert_to_fp32(sin)
+
+        print("start compiling")
+        self.compile()
 
 
 def convert_lm_head_and_embedding(model, n_splits_linear, temp_dir, weight_dir):
```
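The `rotary_emb` section of `Llama32Embedding` reproduces the Llama 3.2 rotary computation as NPU graph ops. As a plain NumPy cross-check of what those ops compute for a single position (reference only; the concrete `inv_freq`, `head_dim` and `attention_scaling` values below are illustrative):

```python
# NumPy reference for the rotary_emb portion of Llama32Embedding above.
import numpy as np

def rotary_cos_sin(inv_freq: np.ndarray, position_id: int, attention_scaling: float):
    """Compute cos/sin for one position, matching the graph ops above."""
    inv_freq = inv_freq.astype(np.float32).reshape(1, -1, 1)        # (1, dim/2, 1)
    pos = np.array(position_id, dtype=np.float32).reshape(1, 1, 1)  # (1, 1, 1)
    freqs = inv_freq * pos                                           # eltwise_mul
    freqs = freqs.transpose(0, 2, 1)                                 # (1, 1, dim/2)
    emb = np.concatenate([freqs, freqs], axis=2)                     # (1, 1, dim)
    cos = np.cos(emb) * attention_scaling
    sin = np.sin(emb) * attention_scaling
    return cos.astype(np.float32), sin.astype(np.float32)

# Example with illustrative values (head_dim=64, rope theta 500000 as in Llama 3.x):
inv_freq = 1.0 / (500000.0 ** (np.arange(0, 64, 2) / 64))
cos, sin = rotary_cos_sin(inv_freq, position_id=5, attention_scaling=1.0)
print(cos.shape, sin.shape)  # (1, 1, 64) (1, 1, 64)
```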
```diff
@@ -71,14 +133,27 @@ def convert_lm_head_and_embedding(model, n_splits_linear, temp_dir, weight_dir):
         bin_file = os.path.join(weight_dir, f"model_lm_head_input_{1+idx}.bin")
         weight.tofile(bin_file)
 
-    embedding_layer = model.model.embed_tokens
-    new_embedding = LLMEmbedding(
-        vocab_size=model.config.vocab_size,
-        embedding_dim=model.config.hidden_size,
-        embedding_weight=embedding_layer.weight.to(torch.float16).detach().numpy(),
-        padding_idx=model.config.pad_token_id,
-        dtype=np.float16,
-    )
+    if hasattr(model.model.layers[0].self_attn.rotary_emb, "cos_cached"):
+        # llama-2-7B & llama-3-8B
+        embedding_layer = model.model.embed_tokens
+        new_embedding = LLMEmbedding(
+            vocab_size=model.config.vocab_size,
+            embedding_dim=model.config.hidden_size,
+            embedding_weight=embedding_layer.weight.to(torch.float16).detach().numpy(),
+            padding_idx=model.config.pad_token_id,
+            dtype=np.float16,
+        )
+    else:
+        # llama-3.2-3B & llama-3.2-1B
+        new_embedding = Llama32Embedding(
+            vocab_size=model.config.vocab_size,
+            embedding_dim=model.config.hidden_size,
+            embedding_weight=model.model.embed_tokens.weight.to(torch.float16).detach().numpy(),
+            padding_idx=model.config.pad_token_id,
+            inv_freq=model.model.rotary_emb.inv_freq.to(torch.float16),
+            attention_scaling=model.model.rotary_emb.attention_scaling,
+            dtype=np.float16,
+        )
     first_blob_path = update_names_of_IR_and_export_blob(new_embedding, "embedding",
                                                          temp_dir)
     return first_blob_path, last_blob_path
```
```diff
@@ -135,8 +210,14 @@ def convert_llama_layer(model, layer_idx, n_splits_linear, n_splits_down_proj,
             scales.append(l.scale)
         weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))
 
-    cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16)
-    cached_sin = curr_layer.self_attn.rotary_emb.sin_cached.to(torch.float16)
+    if hasattr(curr_layer.self_attn.rotary_emb, "cos_cached"):
+        # llama-2-7B & llama-3-8B
+        cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16)
+        cached_sin = curr_layer.self_attn.rotary_emb.sin_cached.to(torch.float16)
+    else:
+        # llama-3.2-3B & llama-3.2-1B
+        cached_cos = None
+        cached_sin = None
     layer_norm_0 = curr_layer.input_layernorm.weight.to(torch.float16)
     layer_norm_1 = curr_layer.post_attention_layernorm.weight.to(torch.float16)
 
```
```diff
@@ -168,14 +249,26 @@ def convert_llama_layer(model, layer_idx, n_splits_linear, n_splits_down_proj,
                                                         f"decoder_layer_{layer_idx}",
                                                         temp_dir)
 
-    if layernorm_const:
-        st_idx = 5
+    if hasattr(curr_layer.self_attn.rotary_emb, "cos_cached"):
+        # llama-2-7B & llama-3-8B
+        if layernorm_const:
+            st_idx = 5
+        else:
+            input_lm_bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_3.bin")
+            post_lm_bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_4.bin")
+            layer_norm_0.data.numpy().tofile(input_lm_bin_file)
+            layer_norm_1.data.numpy().tofile(post_lm_bin_file)
+            st_idx = 7
     else:
-        input_lm_bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_3.bin")
-        post_lm_bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_4.bin")
-        layer_norm_0.data.numpy().tofile(input_lm_bin_file)
-        layer_norm_1.data.numpy().tofile(post_lm_bin_file)
-        st_idx = 7
+        # llama-3.2-3B & llama-3.2-1B
+        if layernorm_const:
+            st_idx = 6
+        else:
+            input_lm_bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_4.bin")
+            post_lm_bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_5.bin")
+            layer_norm_0.data.numpy().tofile(input_lm_bin_file)
+            layer_norm_1.data.numpy().tofile(post_lm_bin_file)
+            st_idx = 8
     for idx, (weight, scale) in enumerate(weights):
         bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_{st_idx+idx*2}.bin")
         weight.numpy().tofile(bin_file)
```
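On the Llama 3.2 path the exported decoder layer takes cos and sin as runtime inputs where the cached-rotary layout used position ids, which appears to add one input overall; the layer-norm bin files and the starting index of the quantized weights therefore shift by one (5/7 become 6/8, input_3/input_4 become input_4/input_5). A hypothetical helper summarising the resulting layout:

```python
import os

# Hypothetical helpers summarising the st_idx bookkeeping from the hunk above.
def weight_start_index(has_cached_cos: bool, layernorm_const: bool) -> int:
    """Index of the first quantized-weight input bin for a decoder layer."""
    if has_cached_cos:       # llama-2-7B & llama-3-8B
        return 5 if layernorm_const else 7
    # llama-3.2-1B & llama-3.2-3B: cos/sin inputs shift everything by one
    return 6 if layernorm_const else 8

def layer_norm_bin_files(weight_dir: str, layer_idx: int, has_cached_cos: bool):
    """Bin-file names used when layer norms are exported as runtime inputs."""
    base = 3 if has_cached_cos else 4
    return (os.path.join(weight_dir, f"model_{layer_idx}_input_{base}.bin"),
            os.path.join(weight_dir, f"model_{layer_idx}_input_{base + 1}.bin"))
```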