[NPU] Support qwen models with cos_sin_input=True (#12788)
				
					
				
parent 6ff7faa781
commit ca1d7b7c2c

5 changed files with 238 additions and 200 deletions
@@ -98,6 +98,8 @@ class LowBitQwenMultiDecoderlayer(LLMBaseNNFactory):
         n_splits_linear: int = 1,
         n_splits_down_proj: int = 1,
         group_size: int = 0,
+        cos_len: int = 1,
+        keep_position_ids=True,
         asym: bool = False,
     ):
         super().__init__(max_seq_len=max_seq_len,
@@ -114,18 +116,13 @@ class LowBitQwenMultiDecoderlayer(LLMBaseNNFactory):
         self.dtype = dtype
         self.cached_cos = cached_cos
         self.cached_sin = cached_sin
+        self.cos_len = cos_len
         self.batch_size, self.seq_len, self.hidden_size = hidden_shape
         self.mode = mode
         self.rms_norm_eps = rms_norm_eps
         self.transpose_value = transpose_value
         self.num_layers = num_layers

-        cos = self.constant(self.cached_cos)
-        self.cos = self.unsqueeze(cos, axis=0)
-
-        sin = self.constant(self.cached_sin)
-        self.sin = self.unsqueeze(sin, axis=0)
-
         if mode == "decode":
             self.kv_seq_len = self.max_seq_len + 1
         else:
@@ -148,7 +145,21 @@ class LowBitQwenMultiDecoderlayer(LLMBaseNNFactory):
             attention_mask = self.create_input_op(
                 (self.batch_size, 1, self.seq_len, self.seq_len), dtype=np.float16)

-        position_ids = self.create_input_op((self.batch_size, self.seq_len), dtype=np.int64)
+        if self.cached_cos is None:
+            if mode == "prefill" and keep_position_ids:
+                position_ids = self.create_input_op((self.batch_size, self.seq_len), dtype=np.int64)
+            cos = self.create_input_op((self.batch_size, self.cos_len, self.head_dim),
+                                       dtype=np.float32)
+            self.cos = self.convert_to_fp16(cos)
+            sin = self.create_input_op((self.batch_size, self.cos_len, self.head_dim),
+                                       dtype=np.float32)
+            self.sin = self.convert_to_fp16(sin)
+        else:
+            position_ids = self.create_input_op((self.batch_size, self.seq_len), dtype=np.int64)
+            cos = self.constant(self.cached_cos)
+            self.cos = self.unsqueeze(cos, axis=0)
+            sin = self.constant(self.cached_sin)
+            self.sin = self.unsqueeze(sin, axis=0)

         if input_layernorm_weights is None:
             input_layernorm_weights = []
@@ -211,11 +222,12 @@ class LowBitQwenMultiDecoderlayer(LLMBaseNNFactory):
         hidden_states = input

         curr_key_values = []
+        cos_condition = cached_cos is not None or (mode == "prefill" and keep_position_ids)
         for i in range(num_layers):
             hidden_states, new_key_states, new_value_states = self.build_decoder(
                 hidden_states=hidden_states,
                 attention_mask=attention_mask,
-                position_ids=position_ids,
+                position_ids=position_ids if cos_condition else None,
                 input_layernorm_weight=input_layernorm_weights[i],
                 post_attention_layernorm_weight=post_attn_layernorm_weights[i],
                 q_bias=q_biases[i],
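Note: taken together, the decoder-graph changes above make the set of runtime inputs depend on whether the model still carries cached cos/sin tables. The following is a minimal illustrative sketch (plain NumPy, not the exported blob interface; the dict keys and default sizes are assumptions) of that input layout:

import numpy as np

def decoder_runtime_inputs(cached_cos, mode, keep_position_ids,
                           batch_size=1, seq_len=1, cos_len=1,
                           head_dim=64, hidden_size=1536):
    # Illustrative only: which inputs a single decoder blob expects.
    inputs = {
        "input_embed": np.zeros((batch_size, seq_len, hidden_size), dtype=np.float16),
        "attention_mask": np.zeros((batch_size, 1, seq_len, seq_len), dtype=np.float16),
    }
    if cached_cos is None:
        # transformers >= 4.45.0: cos/sin come from the post-embedding blob
        # and are fed to the decoder at runtime.
        if mode == "prefill" and keep_position_ids:
            inputs["position_ids"] = np.zeros((batch_size, seq_len), dtype=np.int64)
        inputs["cos"] = np.zeros((batch_size, cos_len, head_dim), dtype=np.float32)
        inputs["sin"] = np.zeros((batch_size, cos_len, head_dim), dtype=np.float32)
    else:
        # older transformers: cos/sin tables are compiled into the graph as
        # constants, so only position_ids is needed to index them.
        inputs["position_ids"] = np.zeros((batch_size, seq_len), dtype=np.int64)
    return inputs

In decode mode with cos/sin supplied at runtime, the first four graph inputs are therefore input_embed, attention_mask, cos and sin, which is why the per-layer weight start index changes in the Qwen converter further below.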
@@ -173,6 +173,105 @@ class LLMEmbedding(NNFactory):
         self.compile()


+class Llama32Embedding(NNFactory):
+    def __init__(
+        self,
+        vocab_size,
+        embedding_dim,
+        embedding_weight,
+        padding_idx,
+        inv_freq,
+        attention_scaling,
+        dtype,  # fp16
+        device: str = "NPU",
+    ):
+        super().__init__(False, device)
+        self.vocab_size = vocab_size
+        self.embedding_dim = embedding_dim
+        self.padding_idx = padding_idx
+        self.attention_scaling = attention_scaling
+        self.dtype = dtype
+
+        # define input
+        weight = self.constant(embedding_weight)
+        input = self.parameter((1, 1), dtype=np.int32)
+        position_ids = self.parameter((1, 1), dtype=np.int64)
+        inv_freq = self.constant(inv_freq)
+
+        # embed_tokens module
+        if padding_idx == -1:
+            padding_idx += vocab_size
+
+        axis_node = self.constant(np.array([0], dtype=np.int64))
+        if padding_idx is not None:
+            masked_embeddings = np.ones(weight.shape, dtype=np.float16)
+            masked_embeddings[padding_idx, :] = 0.0  # mask
+
+            node_mask = self.constant(masked_embeddings)
+            node_masked_w = self.eltwise_mul(weight, node_mask)
+            res = self.gather(node_masked_w, input, axis_node, 0)
+        else:
+            res = self.gather(weight, input, axis_node, 0)
+
+        # rotary_emb module
+        inv_freq = self.reshape(inv_freq, (1, inv_freq.shape[0], 1))
+        position_ids = self.reshape(position_ids, (1, 1, 1))
+        freqs = self.eltwise_mul(self.convert_to_fp32(inv_freq),
+                                 self.convert_to_fp32(position_ids))
+        freqs = self.transpose(freqs, [0, 2, 1])
+        emb = self.concat(freqs, freqs, axis=2)
+        cos = self.cos(emb)
+        sin = self.sin(emb)
+        cos = cos * self.attention_scaling
+        sin = sin * self.attention_scaling
+
+        # define outputs
+        res = self.convert_to_fp16(res)
+        cos = self.convert_to_fp32(cos)
+        sin = self.convert_to_fp32(sin)
+
+        print("start compiling")
+        self.compile()
+
+
+class Llama32PostEmbedding(NNFactory):
+    def __init__(
+        self,
+        inv_freq,
+        attention_scaling,
+        input_len: int = 1,
+        device: str = "NPU",
+    ):
+        super().__init__(False, device)
+        self.attention_scaling = attention_scaling
+
+        # define input
+        position_ids = self.parameter((1, input_len), dtype=np.int64)
+        inv_freq = self.constant(inv_freq)
+
+        # rotary_emb module
+        inv_freq = self.reshape(inv_freq, (1, inv_freq.shape[0], 1))
+        position_ids = self.reshape(position_ids, (1, 1, input_len))
+        freqs = self.eltwise_mul(self.convert_to_fp32(inv_freq),
+                                 self.convert_to_fp32(position_ids))
+        freqs = self.transpose(freqs, [0, 2, 1])
+        emb = self.concat(freqs, freqs, axis=2)
+        cos = self.cos(emb)
+        sin = self.sin(emb)
+        cos = cos * self.attention_scaling
+        sin = sin * self.attention_scaling
+        if input_len > 1:
+            cos = self.unsqueeze(cos, [1])
+            sin = self.unsqueeze(sin, [1])
+
+        # define outputs
+        cos = self.convert_to_fp32(cos)
+        sin = self.convert_to_fp32(sin)
+
+        print("start compiling")
+        self.compile()
+
+
 def obtain_weight_from_single_layer(attn_layer, mlp_layer):
     weights = []
     if hasattr(attn_layer, "q_proj_dq_list"):
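The two graphs above reproduce, as NPU ops, the rotary-embedding computation used by newer transformers: multiply inv_freq by the positions, duplicate the result along the last axis, then take cosine and sine and scale by attention_scaling. A plain-NumPy reference of the same math, useful only for checking shapes (the function name is illustrative, not part of the converter):

import numpy as np

def rotary_cos_sin(inv_freq, position_ids, attention_scaling=1.0):
    # inv_freq: (head_dim // 2,), position_ids: (input_len,)
    inv_freq = np.asarray(inv_freq, dtype=np.float32).reshape(1, -1, 1)   # (1, head_dim//2, 1)
    pos = np.asarray(position_ids, dtype=np.float32).reshape(1, 1, -1)    # (1, 1, input_len)
    freqs = (inv_freq * pos).transpose(0, 2, 1)                           # (1, input_len, head_dim//2)
    emb = np.concatenate([freqs, freqs], axis=-1)                         # (1, input_len, head_dim)
    return np.cos(emb) * attention_scaling, np.sin(emb) * attention_scaling

# Example: head_dim = 64, so 32 inverse frequencies, 4 positions.
inv_freq = 1.0 / (10000.0 ** (np.arange(32) / 32.0))
cos, sin = rotary_cos_sin(inv_freq, np.arange(4))
assert cos.shape == (1, 4, 64) and sin.shape == (1, 4, 64)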
@@ -216,3 +315,65 @@ def obtain_qkv_bias_from_single_layer(attn_layer):
         k_bias = attn_layer.k_proj.bias.to(torch.float16)
         v_bias = attn_layer.v_proj.bias.to(torch.float16)
     return q_bias, k_bias, v_bias
+
+
+def obtain_embedding_from_model(model, convert_model, temp_dir, weight_dir,
+                                max_prompt_len, keep_ir, compile_blob):
+    if hasattr(model.model.layers[0].self_attn.rotary_emb, "cos_cached"):
+        # llama-2-7B & llama-3-8B
+        embedding_layer = model.model.embed_tokens
+        new_embedding = LLMEmbedding(
+            vocab_size=model.config.vocab_size,
+            embedding_dim=model.config.hidden_size,
+            embedding_weight=embedding_layer.weight.to(torch.float16).detach().numpy(),
+            padding_idx=model.config.pad_token_id,
+            dtype=np.float16,
+        )
+        if convert_model:
+            bin_file = os.path.join(weight_dir, f"model_embedding_input_0.bin")
+            embedding_layer.weight.to(torch.float16).detach().numpy().tofile(bin_file)
+            first_blob_path = None
+        else:
+            first_blob_path = update_names_of_IR_and_export_blob(new_embedding, "embedding",
+                                                                 temp_dir, keep_ir=keep_ir,
+                                                                 compile_blob=compile_blob)
+            os.remove(os.path.join(temp_dir, "embedding.bin"))
+    else:
+        # llama-3.2-3B & llama-3.2-1B
+        # for transformers >= 4.45.0
+        embedding_layer = model.model.embed_tokens
+        new_embedding = Llama32Embedding(
+            vocab_size=model.config.vocab_size,
+            embedding_dim=model.config.hidden_size,
+            embedding_weight=embedding_layer.weight.to(torch.float16).detach().numpy(),
+            padding_idx=model.config.pad_token_id,
+            inv_freq=model.model.rotary_emb.inv_freq.to(torch.float16),
+            attention_scaling=model.model.rotary_emb.attention_scaling,
+            dtype=np.float16,
+        )
+        if convert_model:
+            bin_file = os.path.join(weight_dir, f"model_embedding_input_0.bin")
+            embedding_layer.weight.to(torch.float16).detach().numpy().tofile(bin_file)
+            first_blob_path = None
+            # save embedding post module
+            inv_freq = model.model.rotary_emb.inv_freq.to(torch.float16)
+            attention_scaling = model.model.rotary_emb.attention_scaling
+            embedding_post = Llama32PostEmbedding(inv_freq=inv_freq,
+                                                  attention_scaling=attention_scaling,
+                                                  input_len=1)
+            update_names_of_IR_and_export_blob(embedding_post, "embedding_post",
+                                               temp_dir, keep_ir=keep_ir, compile_blob=compile_blob)
+            embedding_post_prefill = Llama32PostEmbedding(inv_freq=inv_freq,
+                                                          attention_scaling=attention_scaling,
+                                                          input_len=max_prompt_len)
+            update_names_of_IR_and_export_blob(embedding_post_prefill,
+                                               "embedding_post_prefill",
+                                               temp_dir, keep_ir=keep_ir, compile_blob=compile_blob)
+            os.remove(os.path.join(temp_dir, "embedding_post.bin"))
+            os.remove(os.path.join(temp_dir, "embedding_post_prefill.bin"))
+        else:
+            first_blob_path = update_names_of_IR_and_export_blob(new_embedding, "embedding",
+                                                                 temp_dir, keep_ir=keep_ir,
+                                                                 compile_blob=compile_blob)
+            os.remove(os.path.join(temp_dir, "embedding.bin"))
+    return first_blob_path
			||||||
@@ -31,6 +31,7 @@ import tempfile
 import numpy as np
 from ipex_llm.transformers.npu_models.lm_head import SlicedLMHead
 from multiprocessing import Pool
+import transformers


 def generate(
@@ -456,6 +457,8 @@ def convert_llm_for_deploy(model: torch.nn.Module,
         custom_object_save(model, save_directory, config=model.config)

     if model.config.model_type == "qwen2":
+        cos_sin_input = not hasattr(model.model.layers[0].self_attn.rotary_emb, "cos_cached")
+        embedding_post = not hasattr(model.model.layers[0].self_attn.rotary_emb, "cos_cached")
         if group_size == 0:
             if model.config.hidden_size == 1536:
                 # Qwen2-1.5B-Instruct
@@ -476,6 +479,8 @@ def convert_llm_for_deploy(model: torch.nn.Module,
                        "use_prefill_sdp": False,
                        "weight_num": 7,
                        "weight_idx": 8,
+                       "embedding_post": embedding_post,
+                       "cos_sin_input": cos_sin_input,
                        "n_splits_linear": n_splits_linear,
                        "n_splits_down_proj": n_splits_down_proj,
                        "lm_head_low_bit": lm_head_low_bit}
@@ -493,8 +498,8 @@ def convert_llm_for_deploy(model: torch.nn.Module,
                            group_size, layernorm_const, "prefill",
                            keep_ir=keep_ir, compile_blob=compile_blob)
         # save blob of lmhead and bin of embedding
-        convert_lm_head_and_embedding(model, save_directory, weight_dir,
-                                      convert_model=True, group_size=group_size,
+        convert_lm_head_and_embedding(model, save_directory, weight_dir, convert_model=True,
+                                      group_size=group_size, max_prompt_len=max_prompt_len,
                                       keep_ir=keep_ir, compile_blob=compile_blob)
     elif model.config.model_type == "llama":
         embedding_post = False
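For qwen2 the deploy config now records two flags derived from the same cos_cached probe, and max_prompt_len is threaded into convert_lm_head_and_embedding so the prefill post-embedding blob can be sized. A condensed sketch, assuming a loaded qwen2 model bound to model (the surrounding update dict is abbreviated):

rotary_emb = model.model.layers[0].self_attn.rotary_emb
has_cached = hasattr(rotary_emb, "cos_cached")
update_dict = {
    "cos_sin_input": not has_cached,   # decoder blobs take cos/sin as runtime inputs
    "embedding_post": not has_cached,  # a separate post-embedding blob is exported
    # ... other keys as in the hunk above ...
}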
@@ -18,108 +18,8 @@
 import torch
 import numpy as np
 import os
-from .common import update_names_of_IR_and_export_blob, LLMEmbedding, LowBitLLMLMHead, \
-    obtain_weight_from_single_layer
-from intel_npu_acceleration_library.backend.factory import NNFactory
-
-
-class Llama32Embedding(NNFactory):
-    def __init__(
-        self,
-        vocab_size,
-        embedding_dim,
-        embedding_weight,
-        padding_idx,
-        inv_freq,
-        attention_scaling,
-        dtype,  # fp16
-        device: str = "NPU",
-    ):
-        super().__init__(False, device)
-        self.vocab_size = vocab_size
-        self.embedding_dim = embedding_dim
-        self.padding_idx = padding_idx
-        self.attention_scaling = attention_scaling
-        self.dtype = dtype
-
-        # define input
-        weight = self.constant(embedding_weight)
-        input = self.parameter((1, 1), dtype=np.int32)
-        position_ids = self.parameter((1, 1), dtype=np.int64)
-        inv_freq = self.constant(inv_freq)
-
-        # embed_tokens module
-        if padding_idx == -1:
-            padding_idx += vocab_size
-
-        axis_node = self.constant(np.array([0], dtype=np.int64))
-        if padding_idx is not None:
-            masked_embeddings = np.ones(weight.shape, dtype=np.float16)
-            masked_embeddings[padding_idx, :] = 0.0  # mask
-
-            node_mask = self.constant(masked_embeddings)
-            node_masked_w = self.eltwise_mul(weight, node_mask)
-            res = self.gather(node_masked_w, input, axis_node, 0)
-        else:
-            res = self.gather(weight, input, axis_node, 0)
-
-        # rotary_emb module
-        inv_freq = self.reshape(inv_freq, (1, inv_freq.shape[0], 1))
-        position_ids = self.reshape(position_ids, (1, 1, 1))
-        freqs = self.eltwise_mul(self.convert_to_fp32(inv_freq),
-                                 self.convert_to_fp32(position_ids))
-        freqs = self.transpose(freqs, [0, 2, 1])
-        emb = self.concat(freqs, freqs, axis=2)
-        cos = self.cos(emb)
-        sin = self.sin(emb)
-        cos = cos * self.attention_scaling
-        sin = sin * self.attention_scaling
-
-        # define outputs
-        res = self.convert_to_fp16(res)
-        cos = self.convert_to_fp32(cos)
-        sin = self.convert_to_fp32(sin)
-
-        print("start compiling")
-        self.compile()
-
-
-class Llama32PostEmbedding(NNFactory):
-    def __init__(
-        self,
-        inv_freq,
-        attention_scaling,
-        input_len: int = 1,
-        device: str = "NPU",
-    ):
-        super().__init__(False, device)
-        self.attention_scaling = attention_scaling
-
-        # define input
-        position_ids = self.parameter((1, input_len), dtype=np.int64)
-        inv_freq = self.constant(inv_freq)
-
-        # rotary_emb module
-        inv_freq = self.reshape(inv_freq, (1, inv_freq.shape[0], 1))
-        position_ids = self.reshape(position_ids, (1, 1, input_len))
-        freqs = self.eltwise_mul(self.convert_to_fp32(inv_freq),
-                                 self.convert_to_fp32(position_ids))
-        freqs = self.transpose(freqs, [0, 2, 1])
-        emb = self.concat(freqs, freqs, axis=2)
-        cos = self.cos(emb)
-        sin = self.sin(emb)
-        cos = cos * self.attention_scaling
-        sin = sin * self.attention_scaling
-        if input_len > 1:
-            cos = self.unsqueeze(cos, [1])
-            sin = self.unsqueeze(sin, [1])
-
-        # define outputs
-        cos = self.convert_to_fp32(cos)
-        sin = self.convert_to_fp32(sin)
-
-        print("start compiling")
-        self.compile()
+from .common import update_names_of_IR_and_export_blob, LowBitLLMLMHead, \
+    obtain_weight_from_single_layer, obtain_embedding_from_model


 def convert_lm_head_and_embedding(model, n_splits_linear, temp_dir, weight_dir,
@@ -197,62 +97,10 @@ def convert_lm_head_and_embedding(model, n_splits_linear, temp_dir, weight_dir,
         bin_file = os.path.join(weight_dir, f"model_lm_head_input_{1+idx}.bin")
         weight.tofile(bin_file)

-    if hasattr(model.model.layers[0].self_attn.rotary_emb, "cos_cached"):
-        # llama-2-7B & llama-3-8B
-        embedding_layer = model.model.embed_tokens
-        new_embedding = LLMEmbedding(
-            vocab_size=model.config.vocab_size,
-            embedding_dim=model.config.hidden_size,
-            embedding_weight=embedding_layer.weight.to(torch.float16).detach().numpy(),
-            padding_idx=model.config.pad_token_id,
-            dtype=np.float16,
-        )
-        if convert_model:
-            bin_file = os.path.join(weight_dir, f"model_embedding_input_0.bin")
-            embedding_layer.weight.to(torch.float16).detach().numpy().tofile(bin_file)
-            first_blob_path = None
-        else:
-            first_blob_path = update_names_of_IR_and_export_blob(new_embedding, "embedding",
-                                                                 temp_dir, keep_ir=keep_ir,
-                                                                 compile_blob=compile_blob)
-            os.remove(os.path.join(temp_dir, "embedding.bin"))
-    else:
-        # llama-3.2-3B & llama-3.2-1B
-        embedding_layer = model.model.embed_tokens
-        new_embedding = Llama32Embedding(
-            vocab_size=model.config.vocab_size,
-            embedding_dim=model.config.hidden_size,
-            embedding_weight=embedding_layer.weight.to(torch.float16).detach().numpy(),
-            padding_idx=model.config.pad_token_id,
-            inv_freq=model.model.rotary_emb.inv_freq.to(torch.float16),
-            attention_scaling=model.model.rotary_emb.attention_scaling,
-            dtype=np.float16,
-        )
-        if convert_model:
-            bin_file = os.path.join(weight_dir, f"model_embedding_input_0.bin")
-            embedding_layer.weight.to(torch.float16).detach().numpy().tofile(bin_file)
-            first_blob_path = None
-            # save embedding post module
-            inv_freq = model.model.rotary_emb.inv_freq.to(torch.float16)
-            attention_scaling = model.model.rotary_emb.attention_scaling
-            embedding_post = Llama32PostEmbedding(inv_freq=inv_freq,
-                                                  attention_scaling=attention_scaling,
-                                                  input_len=1)
-            update_names_of_IR_and_export_blob(embedding_post, "embedding_post",
-                                               temp_dir, keep_ir=keep_ir, compile_blob=compile_blob)
-            embedding_post_prefill = Llama32PostEmbedding(inv_freq=inv_freq,
-                                                          attention_scaling=attention_scaling,
-                                                          input_len=max_prompt_len)
-            update_names_of_IR_and_export_blob(embedding_post_prefill,
-                                               "embedding_post_prefill",
-                                               temp_dir, keep_ir=keep_ir, compile_blob=compile_blob)
-            os.remove(os.path.join(temp_dir, "embedding_post.bin"))
-            os.remove(os.path.join(temp_dir, "embedding_post_prefill.bin"))
-        else:
-            first_blob_path = update_names_of_IR_and_export_blob(new_embedding, "embedding",
-                                                                 temp_dir, keep_ir=keep_ir,
-                                                                 compile_blob=compile_blob)
-            os.remove(os.path.join(temp_dir, "embedding.bin"))
+    first_blob_path = obtain_embedding_from_model(model, convert_model,
+                                                  temp_dir, weight_dir,
+                                                  max_prompt_len,
+                                                  keep_ir, compile_blob)

     return first_blob_path, last_blob_path

@@ -18,13 +18,14 @@
 import torch
 import numpy as np
 import os
-from .common import update_names_of_IR_and_export_blob, LLMEmbedding, LowBitLLMLMHead, \
-    obtain_weight_from_single_layer, obtain_qkv_bias_from_single_layer
+from .common import update_names_of_IR_and_export_blob, LowBitLLMLMHead, \
+    obtain_weight_from_single_layer, obtain_qkv_bias_from_single_layer, \
+    obtain_embedding_from_model
 from ipex_llm.transformers.npu_models.lm_head import SlicedLMHead


 def convert_lm_head_and_embedding(model, temp_dir, weight_dir,
-                                  convert_model=False, group_size=0,
+                                  convert_model=False, group_size=0, max_prompt_len=1,
                                   keep_ir=False, compile_blob=True):
     num_heads = model.model.layers[0].self_attn.num_heads
     head_dim = model.model.layers[0].self_attn.head_dim
@@ -107,24 +108,10 @@ def convert_lm_head_and_embedding(model, temp_dir, weight_dir,
         bin_file = os.path.join(weight_dir, f"model_lm_head_input_{1+idx}.bin")
         weight.tofile(bin_file)

-    embedding_layer = model.model.embed_tokens
-    new_embedding = LLMEmbedding(
-        vocab_size=model.config.vocab_size,
-        embedding_dim=model.config.hidden_size,
-        embedding_weight=embedding_layer.weight.to(torch.float16).detach().numpy(),
-        padding_idx=model.config.pad_token_id,
-        dtype=np.float16,
-        input_length=1,
-    )
-    if convert_model:
-        bin_file = os.path.join(weight_dir, f"model_embedding_input_0.bin")
-        embedding_layer.weight.to(torch.float16).detach().numpy().tofile(bin_file)
-        first_blob_path = True
-    else:
-        first_blob_path = update_names_of_IR_and_export_blob(new_embedding, f"embedding",
-                                                             temp_dir, keep_ir=keep_ir,
-                                                             compile_blob=compile_blob)
-        os.remove(os.path.join(temp_dir, "embedding.bin"))
+    first_blob_path = obtain_embedding_from_model(model, convert_model,
+                                                  temp_dir, weight_dir,
+                                                  max_prompt_len,
+                                                  keep_ir, compile_blob)
     return first_blob_path, last_blob_path


@@ -145,8 +132,13 @@ def convert_qwen_layer(model, layer_idx, n_splits_linear, n_splits_down_proj,
     mlp_layer = curr_layer.mlp
     weights = obtain_weight_from_single_layer(attn_layer, mlp_layer)
     q_bias, k_bias, v_bias = obtain_qkv_bias_from_single_layer(attn_layer)
-    cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16)
-    cached_sin = curr_layer.self_attn.rotary_emb.sin_cached.to(torch.float16)
+    if hasattr(curr_layer.self_attn.rotary_emb, "cos_cached"):
+        cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16)
+        cached_sin = curr_layer.self_attn.rotary_emb.sin_cached.to(torch.float16)
+    else:
+        # transformers >= 4.45.0
+        cached_cos = None
+        cached_sin = None
     layer_norm_0 = curr_layer.input_layernorm.weight.to(torch.float16)
     layer_norm_1 = curr_layer.post_attention_layernorm.weight.to(torch.float16)

@@ -158,10 +150,12 @@ def convert_qwen_layer(model, layer_idx, n_splits_linear, n_splits_down_proj,
     if mode == "decode":
         input_len = 1
         decoder_name = f"decoder_layer_{layer_idx}"
+        keep_position_ids = True
         npu_dpu_groups = None
     else:
         input_len = kv_len
         decoder_name = "decoder_layer_prefill"
+        keep_position_ids = False
         npu_dpu_groups = 6

     single_decoder = LowBitQwenMultiDecoderlayer(
@@ -185,6 +179,8 @@ def convert_qwen_layer(model, layer_idx, n_splits_linear, n_splits_down_proj,
         n_splits_linear=n_splits_linear,
         n_splits_down_proj=n_splits_down_proj,
         group_size=group_size,
+        cos_len=input_len,
+        keep_position_ids=keep_position_ids,
         asym=asym
     )
     rest_blob_path = update_names_of_IR_and_export_blob(single_decoder,
@@ -196,14 +192,25 @@ def convert_qwen_layer(model, layer_idx, n_splits_linear, n_splits_down_proj,

     # 0, 1, 2 are input_embed/attention_mask/position_id
     if mode == "decode":
-        if layernorm_const:
-            st_idx = 3
-        else:
-            input_lm_bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_3.bin")
-            post_lm_bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_4.bin")
-            layer_norm_0.data.numpy().tofile(input_lm_bin_file)
-            layer_norm_1.data.numpy().tofile(post_lm_bin_file)
-            st_idx = 5
+        if hasattr(curr_layer.self_attn.rotary_emb, "cos_cached"):
+            if layernorm_const:
+                st_idx = 3
+            else:
+                input_lm_bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_3.bin")
+                post_lm_bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_4.bin")
+                layer_norm_0.data.numpy().tofile(input_lm_bin_file)
+                layer_norm_1.data.numpy().tofile(post_lm_bin_file)
+                st_idx = 5
+        else:
+            # transformers >= 4.45.0
+            if layernorm_const:
+                st_idx = 4
+            else:
+                input_lm_bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_4.bin")
+                post_lm_bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_5.bin")
+                layer_norm_0.data.numpy().tofile(input_lm_bin_file)
+                layer_norm_1.data.numpy().tofile(post_lm_bin_file)
+                st_idx = 6
         q_bias_bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_{st_idx}.bin")
         k_bias_bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_{st_idx+1}.bin")
         v_bias_bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_{st_idx+2}.bin")
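The widened st_idx logic mirrors the decoder input layout: with cached cos/sin the decode blob's inputs 0..2 are input_embed/attention_mask/position_id and the per-layer weight bins start at 3 (or 5 when the layernorm weights are also runtime inputs); without cached cos/sin the blob takes input_embed/attention_mask/cos/sin, so every index shifts by one. A small illustrative helper, not part of the converter:

def weight_start_index(has_cached_cos_sin, layernorm_const):
    # decode-mode blob inputs:
    #   [input_embed, attention_mask, position_id]  when cos/sin are cached constants
    #   [input_embed, attention_mask, cos, sin]     when cos/sin are runtime inputs
    base = 3 if has_cached_cos_sin else 4
    # non-constant layernorm weights occupy the next two input slots
    return base if layernorm_const else base + 2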
@@ -261,8 +268,13 @@ def convert_fused_qwen_layer(model, fused_layers, n_splits_linear, n_splits_down
             attn_layer = curr_layer.self_attn
             mlp_layer = curr_layer.mlp
             weights = obtain_weight_from_single_layer(attn_layer, mlp_layer)
-            cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16)
-            cached_sin = curr_layer.self_attn.rotary_emb.sin_cached.to(torch.float16)
+            if hasattr(curr_layer.self_attn.rotary_emb, "cos_cached"):
+                cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16)
+                cached_sin = curr_layer.self_attn.rotary_emb.sin_cached.to(torch.float16)
+            else:
+                # transformers >= 4.45.0
+                cached_cos = None
+                cached_sin = None
             layer_norm_0 = curr_layer.input_layernorm.weight.to(torch.float16)
             layer_norm_1 = curr_layer.post_attention_layernorm.weight.to(torch.float16)
