diff --git a/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/Pipeline-Models/README.md b/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/Pipeline-Models/README.md
index 8375b105..a59b9ec2 100644
--- a/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/Pipeline-Models/README.md
+++ b/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/Pipeline-Models/README.md
@@ -8,6 +8,7 @@ In this directory, you will find examples on how to directly run HuggingFace `tr
 |------------|----------------------------------------------------------------|
 | Llama2 | [meta-llama/Llama-2-7b-chat-hf](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf) |
 | Llama3 | [meta-llama/Meta-Llama-3-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct) |
+| Qwen2.5 | [Qwen/Qwen2.5-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-7B-Instruct) |
 | Baichuan2 | [baichuan-inc/Baichuan2-7B-Chat](https://huggingface.co/baichuan-inc/Baichuan-7B-Chat) |
 | MiniCPM | [openbmb/MiniCPM-1B-sft-bf16](https://huggingface.co/openbmb/MiniCPM-1B-sft-bf16) |

@@ -30,7 +31,7 @@ pip install --pre --upgrade ipex-llm[npu]

 ## 2. Runtime Configurations

-**Following envrionment variables are required**:
+**The following environment variables are required**:

 ```cmd
 set BIGDL_USE_NPU=1
@@ -46,6 +47,9 @@ python llama2.py
 :: to run Meta-Llama-3-8B-Instruct
 python llama3.py

+:: to run Qwen2.5-7B-Instruct
+python qwen.py
+
 :: to run Baichuan2-7B-Chat
 python baichuan2.py

diff --git a/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/Pipeline-Models/qwen.py b/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/Pipeline-Models/qwen.py
new file mode 100644
index 00000000..c8fd4038
--- /dev/null
+++ b/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/Pipeline-Models/qwen.py
@@ -0,0 +1,108 @@
+#
+# Copyright 2016 The BigDL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+
+import os
+import torch
+import time
+import argparse
+from ipex_llm.transformers.npu_model import AutoModelForCausalLM
+from transformers import AutoTokenizer
+from transformers.utils import logging
+
+logger = logging.get_logger(__name__)
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(
+        description="Predict Tokens using `generate()` API for npu model"
+    )
+    parser.add_argument(
+        "--repo-id-or-model-path",
+        type=str,
+        default="Qwen/Qwen2.5-7B-Instruct",  # Or Qwen2-7B-Instruct
+        help="The huggingface repo id for the Qwen2.5 model to be downloaded"
+        ", or the path to the huggingface checkpoint folder",
+    )
+    parser.add_argument("--lowbit-path", type=str,
+        default="",
+        help="The path to the lowbit model folder, leave blank if you do not want to save. \
+        If path not exists, lowbit model will be saved there.
\ + Else, lowbit model will be loaded.", + ) + parser.add_argument('--prompt', type=str, default="AI是什么?", + help='Prompt to infer') + parser.add_argument("--n-predict", type=int, default=32, help="Max tokens to predict") + parser.add_argument("--max-context-len", type=int, default=1024) + parser.add_argument("--max-prompt-len", type=int, default=960) + parser.add_argument("--disable-transpose-value-cache", action="store_true", default=False) + + args = parser.parse_args() + model_path = args.repo_id_or_model_path + + if not args.lowbit_path or not os.path.exists(args.lowbit_path): + model = AutoModelForCausalLM.from_pretrained(model_path, + optimize_model=True, + pipeline=True, + max_context_len=args.max_context_len, + max_prompt_len=args.max_prompt_len, + torch_dtype=torch.float16, + attn_implementation="eager", + transpose_value_cache=not args.disable_transpose_value_cache, + mixed_precision=True, + trust_remote_code=True) + else: + model = AutoModelForCausalLM.load_low_bit( + args.lowbit_path, + attn_implementation="eager", + torch_dtype=torch.float16, + max_context_len=args.max_context_len, + max_prompt_len=args.max_prompt_len, + pipeline=True, + transpose_value_cache=not args.disable_transpose_value_cache) + + tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) + + if args.lowbit_path and not os.path.exists(args.lowbit_path): + model.save_low_bit(args.lowbit_path) + + print("-" * 80) + print("done") + messages = [{"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": args.prompt}] + text = tokenizer.apply_chat_template(messages, + tokenize=False, + add_generation_prompt=True) + with torch.inference_mode(): + print("finish to load") + for i in range(5): + _input_ids = tokenizer([text], return_tensors="pt").input_ids + print("input length:", len(_input_ids[0])) + st = time.time() + output = model.generate( + _input_ids, max_new_tokens=args.n_predict, do_print=True + ) + end = time.time() + print(f"Inference time: {end-st} s") + input_str = tokenizer.decode(_input_ids[0], skip_special_tokens=False) + print("-" * 20, "Input", "-" * 20) + print(input_str) + output_str = tokenizer.decode(output[0], skip_special_tokens=False) + print("-" * 20, "Output", "-" * 20) + print(output_str) + + print("-" * 80) + print("done") + print("success shut down") diff --git a/python/llm/src/ipex_llm/transformers/npu_models/convert_mp.py b/python/llm/src/ipex_llm/transformers/npu_models/convert_mp.py index fd36a499..c10fbfc3 100644 --- a/python/llm/src/ipex_llm/transformers/npu_models/convert_mp.py +++ b/python/llm/src/ipex_llm/transformers/npu_models/convert_mp.py @@ -267,6 +267,43 @@ def convert_minicpm( convert_forward(model, module.MiniCPMForCausalLM, minicpm_casullm_forward) +def convert_qwen( + model: torch.nn.Module, + max_output_len=1024, + max_prompt_len=1024, + decoder=False, + inter_pp=None, + intra_pp=None, + transpose_value_cache=True, +): + from ipex_llm.transformers.npu_models.qwen2_mp import gen_qwen2_fused_model_forward + from ipex_llm.transformers.npu_models.qwen2_mp import DecodeRunner, PrefillRunner + from transformers.models.qwen2.modeling_qwen2 import Qwen2Model + if decoder: + decode_runner = DecodeRunner( + model, + max_seq_len=max_output_len, + inter_pp=inter_pp, + intra_pp=intra_pp, + transpose_value_cache=transpose_value_cache, + ) + else: + decode_runner = None + prefill_runner = PrefillRunner( + model, + max_output_len=max_output_len, + max_prompt_len=max_prompt_len, + transpose_value_cache=transpose_value_cache, + ) 
+ qwen2_model_forward = gen_qwen2_fused_model_forward( + prefill_runner=prefill_runner, decode_runner=decode_runner + ) + convert_forward(model, Qwen2Model, qwen2_model_forward) + from transformers.models.qwen2.modeling_qwen2 import Qwen2ForCausalLM + from ipex_llm.transformers.npu_models.qwen2_mp import qwen2_casullm_forward + convert_forward(model, Qwen2ForCausalLM, qwen2_casullm_forward) + + def optimize_llm( model: torch.nn.Module, max_context_len=1024, @@ -300,31 +337,13 @@ def optimize_llm( inter_pp = 2 else: inter_pp = 1 - - from ipex_llm.transformers.npu_models.qwen2_mp import gen_qwen2_fused_model_forward - from ipex_llm.transformers.npu_models.qwen2_mp import DecodeRunner, PrefillRunner - from transformers.models.qwen2.modeling_qwen2 import Qwen2Model - - decode_runner = DecodeRunner( - model, - max_seq_len=max_context_len, - inter_pp=inter_pp, - intra_pp=intra_pp, - transpose_value_cache=transpose_value_cache, - ) - prefill_runner = PrefillRunner( - model, - max_output_len=max_context_len, - max_prompt_len=max_prompt_len, - transpose_value_cache=transpose_value_cache, - ) - qwen2_model_forward = gen_qwen2_fused_model_forward( - prefill_runner=prefill_runner, decode_runner=decode_runner - ) - convert_forward(model, Qwen2Model, qwen2_model_forward) - from transformers.models.qwen2.modeling_qwen2 import Qwen2ForCausalLM - from ipex_llm.transformers.npu_models.qwen2_mp import qwen2_casullm_forward - convert_forward(model, Qwen2ForCausalLM, qwen2_casullm_forward) + convert_qwen(model, + max_output_len=max_context_len, + max_prompt_len=max_prompt_len, + inter_pp=inter_pp, + intra_pp=intra_pp, + decoder=True, + transpose_value_cache=transpose_value_cache) elif model.config.model_type == "minicpm": # for minicpm-1b if intra_pp is None: diff --git a/python/llm/src/ipex_llm/transformers/npu_models/qwen2_mp.py b/python/llm/src/ipex_llm/transformers/npu_models/qwen2_mp.py index 9ad99947..88da9e0e 100644 --- a/python/llm/src/ipex_llm/transformers/npu_models/qwen2_mp.py +++ b/python/llm/src/ipex_llm/transformers/npu_models/qwen2_mp.py @@ -140,31 +140,13 @@ class LowBitQwenMultiDecoderlayer(LLMBaseNNFactory): # Self Attention if mode == "decode": - attention_mask = self.create_input_op((self.batch_size, 1, 1, self.max_seq_len + 1)) + attention_mask = self.create_input_op( + (self.batch_size, 1, 1, self.max_seq_len + 1), dtype=np.int64) else: - attention_mask = self.create_input_op((self.batch_size, 1, self.seq_len, self.seq_len)) + attention_mask = self.create_input_op( + (self.batch_size, 1, self.seq_len, self.seq_len), dtype=np.int64) - position_ids = self.create_input_op((self.batch_size, self.seq_len)) - past_keys = [] - past_values = [] - if mode == "decode": - for i in range(num_layers): - past_key = self.create_cache_op( - (self.batch_size, self.num_key_value_heads, self.max_seq_len, self.head_dim) - ) - if transpose_value: - past_value = self.create_cache_op( - (self.batch_size, self.num_key_value_heads, self.head_dim, self.max_seq_len) - ) - else: - past_value = self.create_cache_op( - (self.batch_size, self.num_key_value_heads, self.max_seq_len, self.head_dim) - ) - past_keys.append(past_key) - past_values.append(past_value) - else: - past_keys = [None] * num_layers - past_values = [None] * num_layers + position_ids = self.create_input_op((self.batch_size, self.seq_len), dtype=np.int64) if input_layernorm_weights is None: input_layernorm_weights = [] @@ -203,6 +185,27 @@ class LowBitQwenMultiDecoderlayer(LLMBaseNNFactory): k_biases = [self.constant(w) for w in k_biases] v_biases = 
[self.constant(w) for w in v_biases] + past_keys = [] + past_values = [] + if mode == "decode": + for i in range(num_layers): + past_key = self.create_cache_op( + (self.batch_size, self.num_key_value_heads, self.max_seq_len, self.head_dim) + ) + if transpose_value: + past_value = self.create_cache_op( + (self.batch_size, self.num_key_value_heads, self.head_dim, self.max_seq_len) + ) + else: + past_value = self.create_cache_op( + (self.batch_size, self.num_key_value_heads, self.max_seq_len, self.head_dim) + ) + past_keys.append(past_key) + past_values.append(past_value) + else: + past_keys = [None] * num_layers + past_values = [None] * num_layers + hidden_states = input curr_key_values = [] @@ -396,8 +399,8 @@ class FusedQwenLowBitMultiDecoderlayer(torch.nn.Module): inputs = ( hidden_states.to(torch.float16), - attention_mask, - position_ids.to(torch.float16), + attention_mask.to(torch.int64), + position_ids.to(torch.int64), ) for i in range(self.intra_stages): @@ -514,7 +517,9 @@ class FusedQwenLowBitDecoderlayer(torch.nn.Module): seq_len = hidden_states.shape[1] backend_cls = self.backend_cls_prefill - inputs = (hidden_states.to(torch.float16), attention_mask, position_ids.to(torch.float16)) + inputs = (hidden_states.to(torch.float16), + attention_mask.to(torch.int64), + position_ids.to(torch.int64)) inputs += (self.layer_norm_0, self.layer_norm_1) inputs += (self.q_bias, self.k_bias, self.v_bias) hidden_states, past_key, past_value = run_model( @@ -687,9 +692,9 @@ def run_decode( causal_mask[:, :, :, -1] = torch.finfo(torch.float16).min pad_mask = (0, pad_len) padded_causal_mask = F.pad( - causal_mask.to(torch.float16), pad_mask, value=torch.finfo(torch.float16).min + causal_mask.to(torch.int64), pad_mask, value=torch.iinfo(torch.int64).min ) - padded_causal_mask[:, :, :, -1] = 0.0 + padded_causal_mask[:, :, :, -1] = 0 dist.recv(hidden_states, src=rank - 1) layer_outputs = multi_decoder( hidden_states, @@ -973,9 +978,9 @@ class PrefillRunner: hidden_states = F.pad(hidden_states.to(torch.float16), (0, 0, 0, pad_len), value=0.0) position_ids = F.pad(position_ids, (0, pad_len), value=0) attention_mask = F.pad( - attention_mask.to(torch.float16), + attention_mask.to(torch.int64), (0, pad_len, 0, pad_len), - value=torch.finfo(torch.float16).min, + value=torch.iinfo(torch.int64).min, ) args = (hidden_states, position_ids, attention_mask, past_key_value) diff --git a/python/llm/src/ipex_llm/transformers/npu_pipeline_model/convert_pipeline.py b/python/llm/src/ipex_llm/transformers/npu_pipeline_model/convert_pipeline.py index 48b7a5e4..21603467 100644 --- a/python/llm/src/ipex_llm/transformers/npu_pipeline_model/convert_pipeline.py +++ b/python/llm/src/ipex_llm/transformers/npu_pipeline_model/convert_pipeline.py @@ -196,7 +196,7 @@ def convert_llm(model: torch.nn.Module, group_size: int): if group_size == 0: n_splits_linear = 1 - n_splits_down_proj = 1 + n_splits_down_proj = 2 if model.config.intermediate_size == 18944 else 1 else: n_splits_linear = model.config.hidden_size // group_size n_splits_down_proj = model.config.intermediate_size // group_size @@ -318,9 +318,49 @@ def convert_llm(model: torch.nn.Module, except: invalidInputError(False, "False to InitLLMPipeline.") + elif model.config.model_type == "qwen2": + with tempfile.TemporaryDirectory() as temp_dir: + weight_dir = os.path.join(temp_dir, "model_weights") + os.mkdir(weight_dir) + layer_num = len(model.model.layers) + from .qwen import convert_qwen_layer, convert_lm_head_and_embedding + first_blob_path, last_blob_path = 
convert_lm_head_and_embedding(model, n_splits_linear,
+                                                                            temp_dir, weight_dir)
+
+            param_list = []
+            for layer_idx in range(0, layer_num):
+                param_list.append((model, layer_idx, n_splits_linear, n_splits_down_proj,
+                                  temp_dir, weight_dir, transpose_value_cache, kv_len, group_size))
+            with Pool() as pool:
+                result = pool.starmap(convert_qwen_layer, param_list)
+
+            # Prefill Runner
+            from ipex_llm.transformers.npu_models.convert_mp import convert_qwen
+            convert_qwen(model,
+                         max_output_len=kv_len,
+                         max_prompt_len=max_prompt_len,
+                         decoder=False,
+                         transpose_value_cache=transpose_value_cache)
+
+            # patch attrs for generate
+            model.kv_len = kv_len
+            model.num_head = model.model.layers[0].self_attn.num_key_value_heads
+            model.head_dim = model.model.layers[0].self_attn.head_dim
+            model.num_layers = layer_num
+            model.transpose_value_cache = transpose_value_cache
+            model.vocab_size = model.config.vocab_size
+
+            try:
+                res = InitLLMPipeline("qwen", kv_len, model.num_head, model.head_dim, layer_num,
+                                      model.vocab_size, weight_dir, "model",
+                                      first_blob_path, last_blob_path,
+                                      os.path.join(temp_dir, "decoder_layer"))
+            except:
+                invalidInputError(False,
+                                  "Failed to InitLLMPipeline.")
     else:
-        invalidInputError(False,
-                          "Now we only support Llama2 / Llama3 / Baichuan2 for pipeline running.")
+        invalidInputError(False, "Now we only support Llama2 / Llama3 / Baichuan2 / "
+                                 "Qwen2 / Qwen2.5 / MiniCPM for pipeline running.")

     if isinstance(model.lm_head, SlicedLMHead):
         model.lm_head.get_fused_lm_head()
diff --git a/python/llm/src/ipex_llm/transformers/npu_pipeline_model/qwen.py b/python/llm/src/ipex_llm/transformers/npu_pipeline_model/qwen.py
new file mode 100644
index 00000000..1d514835
--- /dev/null
+++ b/python/llm/src/ipex_llm/transformers/npu_pipeline_model/qwen.py
@@ -0,0 +1,187 @@
+#
+# Copyright 2016 The BigDL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# + + +import torch +import numpy as np +import os +from .common import update_names_of_IR_and_export_blob, LLMEmbedding, LowBitLLMLMHead + + +def convert_lm_head_and_embedding(model, n_splits_linear, temp_dir, weight_dir): + num_heads = model.model.layers[0].self_attn.num_heads + head_dim = model.model.layers[0].self_attn.head_dim + rms_norm_eps = model.config.rms_norm_eps + vocab_size = model.config.vocab_size + model_norm = model.model.norm + lm_heads = model.lm_head.lm_heads # Qwen2 is always SlicedLMHead + if n_splits_linear == 1: + weights = [(lm_heads[0].weight, lm_heads[0].scale)] + else: + lm_head_weights = [] + scales = [] + for i in range(n_splits_linear): + lm_head_weights.append(lm_heads[i].weight) + scales.append(lm_heads[i].scale) + weights = [(torch.stack(lm_head_weights, axis=0), + torch.stack(scales, axis=0))] + if isinstance(weights[0], tuple): + np_dtype = np.int8 if weights[0][0].dtype == torch.int8 else np.uint8 + else: # FP16 Linear + np_dtype = np.float16 + + new_lm_head = LowBitLLMLMHead( + [1, 1, num_heads * head_dim], + num_heads=num_heads, + max_seq_len=1, # seems doesn't matter + rms_norm_eps=rms_norm_eps, + mode="decode", + transpose_value=False, # seems doesn't matter + dtype=np_dtype, + model_norm_weight=model_norm.weight.to(torch.float16), + vocab_size=vocab_size, + n_splits=n_splits_linear + ) + last_blob_path = update_names_of_IR_and_export_blob(new_lm_head, "lm_head", temp_dir) + + # save weights bins files + if n_splits_linear == 1: + weight_numpy = [ + lm_heads[0].weight.data.numpy(), lm_heads[0].scale.data.numpy(), + ] + else: + weight_numpy = [v.numpy() for v in weights[0]] + + for idx, weight in enumerate(weight_numpy): + bin_file = os.path.join(weight_dir, f"model_lm_head_input_{1+idx}.bin") + weight.tofile(bin_file) + + embedding_layer = model.model.embed_tokens + new_embedding = LLMEmbedding( + vocab_size=model.config.vocab_size, + embedding_dim=model.config.hidden_size, + embedding_weight=embedding_layer.weight.to(torch.float16).detach().numpy(), + padding_idx=model.config.pad_token_id, + dtype=np.float16, + ) + first_blob_path = update_names_of_IR_and_export_blob(new_embedding, "embedding", + temp_dir) + return first_blob_path, last_blob_path + + +def convert_qwen_layer(model, layer_idx, n_splits_linear, n_splits_down_proj, + temp_dir, weight_dir, transpose_value_cache, kv_len, group_size): + num_heads = model.model.layers[0].self_attn.num_heads + num_key_value_heads = model.model.layers[0].self_attn.num_key_value_heads + head_dim = model.model.layers[0].self_attn.head_dim + intermediate_size = model.config.intermediate_size + rms_norm_eps = model.config.rms_norm_eps + + from ipex_llm.transformers.npu_models.qwen2_mp import LowBitQwenMultiDecoderlayer + curr_layer = model.model.layers[layer_idx] + attn_layer = curr_layer.self_attn + mlp_layer = curr_layer.mlp + + weights = [] + if n_splits_linear == 1: + for q, k, v, o, g, u in zip(attn_layer.q_proj_dq_list, + attn_layer.k_proj_dq_list, + attn_layer.v_proj_dq_list, + attn_layer.o_proj_dq_list, + mlp_layer.gate_proj_dq_list, + mlp_layer.up_proj_dq_list): + weights.append((q.weight, q.scale)) + weights.append((k.weight, k.scale)) + weights.append((v.weight, v.scale)) + weights.append((o.weight, o.scale)) + weights.append((g.weight, g.scale)) + weights.append((u.weight, u.scale)) + else: + for layer_list in [attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list, + attn_layer.v_proj_dq_list, attn_layer.o_proj_dq_list, + mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list]: + l_weights = [] + 
scales = [] + for l in layer_list: + l_weights.append(l.weight) + scales.append(l.scale) + weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0))) + + if n_splits_down_proj == 1: + for l in mlp_layer.down_proj_dq_list: + weights.append((l.weight, l.scale)) + else: + l_weights = [] + scales = [] + for l in mlp_layer.down_proj_dq_list: + l_weights.append(l.weight) + scales.append(l.scale) + weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0))) + + q_bias = attn_layer.q_proj_dq_list.q_proj_dq_0.bias.to(torch.float16) + k_bias = attn_layer.k_proj_dq_list.k_proj_dq_0.bias.to(torch.float16) + v_bias = attn_layer.v_proj_dq_list.v_proj_dq_0.bias.to(torch.float16) + cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16) + cached_sin = curr_layer.self_attn.rotary_emb.sin_cached.to(torch.float16) + layer_norm_0 = curr_layer.input_layernorm.weight.to(torch.float16) + layer_norm_1 = curr_layer.post_attention_layernorm.weight.to(torch.float16) + + if isinstance(weights[0], tuple): + np_dtype = np.int8 if weights[0][0].dtype == torch.int8 else np.uint8 + else: # FP16 Linear + np_dtype = np.float16 + + single_decoder = LowBitQwenMultiDecoderlayer( + [1, 1, num_heads * head_dim], + input_layernorm_weights=[layer_norm_0], + post_attn_layernorm_weights=[layer_norm_1], + q_biases=None, + k_biases=None, + v_biases=None, + cached_cos=cached_cos, + cached_sin=cached_sin, + num_heads=num_heads, + num_key_value_heads=num_key_value_heads, + num_layers=1, + max_seq_len=kv_len, + rms_norm_eps=rms_norm_eps, + intermediate_size=intermediate_size, + mode="decode", + transpose_value=transpose_value_cache, + dtype=np_dtype, + n_splits_linear=n_splits_linear, + n_splits_down_proj=n_splits_down_proj, + group_size=group_size + ) + rest_blob_path = update_names_of_IR_and_export_blob(single_decoder, + f"decoder_layer_{layer_idx}", + temp_dir) + + # 0, 1, 2 are input_embed/attention_mask/position_id + q_bias_bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_3.bin") + k_bias_bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_4.bin") + v_bias_bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_5.bin") + q_bias.data.numpy().tofile(q_bias_bin_file) + k_bias.data.numpy().tofile(k_bias_bin_file) + v_bias.data.numpy().tofile(v_bias_bin_file) + # 6, 7 are past k/v + for idx, (weight, scale) in enumerate(weights): + bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_{8+idx*2}.bin") + weight.numpy().tofile(bin_file) + bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_{8+idx*2+1}.bin") + scale.numpy().tofile(bin_file) + + del single_decoder
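
As a companion to the patch above, here is a condensed sketch of how the new `qwen.py` example drives the Qwen2.5 NPU pipeline end to end: quantize once with `from_pretrained(..., pipeline=True)`, save the low-bit weights, then reload them with `load_low_bit` on later runs. The keyword arguments mirror the ones used in the example; the save directory, prompt, and generation length are illustrative placeholders, and an Intel NPU machine with `ipex-llm[npu]` installed and `BIGDL_USE_NPU=1` set is assumed.

```python
import torch
from transformers import AutoTokenizer
from ipex_llm.transformers.npu_model import AutoModelForCausalLM

model_id = "Qwen/Qwen2.5-7B-Instruct"
lowbit_dir = "./qwen2.5-7b-instruct-npu-lowbit"  # illustrative save location

# First run: quantize for the NPU pipeline and keep a reusable low-bit copy.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    optimize_model=True,
    pipeline=True,                  # route execution through the NPU pipeline backend
    max_context_len=1024,
    max_prompt_len=960,
    torch_dtype=torch.float16,
    attn_implementation="eager",
    transpose_value_cache=True,
    mixed_precision=True,
    trust_remote_code=True,
)
model.save_low_bit(lowbit_dir)

# Later runs: skip quantization and load the saved low-bit weights directly.
model = AutoModelForCausalLM.load_low_bit(
    lowbit_dir,
    attn_implementation="eager",
    torch_dtype=torch.float16,
    max_context_len=1024,
    max_prompt_len=960,
    pipeline=True,
    transpose_value_cache=True,
)

tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
messages = [{"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "What is AI?"}]
text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

with torch.inference_mode():
    input_ids = tokenizer([text], return_tensors="pt").input_ids
    # do_print=True streams tokens as the pipeline runtime produces them.
    output = model.generate(input_ids, max_new_tokens=32, do_print=True)

print(tokenizer.decode(output[0], skip_special_tokens=True))
```

Note that the example passes the same `max_context_len` / `max_prompt_len` values to both calls; these sizes feed the prefill and decode runners that `convert_qwen` builds, so keeping them consistent between save and reload avoids re-tracing the decoder blobs with mismatched shapes.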