[NPU L0] Add layernorm weight as const / input setting (#12322)
parent a01371f90b
commit 5ee6f97d6f

6 changed files with 80 additions and 38 deletions
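The change threads a new `layernorm_const` flag from `convert_llm` down into each model's layer converter and on to the native `InitLLMPipeline` call. When the flag is on, the decoder layernorm weights are handed to the compiled decoder as constants; when it is off, they are dumped to per-layer `.bin` files and fed as runtime inputs instead. A minimal sketch of the pattern, with a hypothetical helper standing in for the `LowBit*MultiDecoderlayer` construction in each converter:

```python
import os

import torch


def layernorm_weight_args(layer_norm_0: torch.Tensor, layer_norm_1: torch.Tensor,
                          layernorm_const: bool) -> dict:
    # Hypothetical helper mirroring the pattern each convert_*_layer now uses:
    # const -> the layernorm weights are baked into the compiled decoder blob;
    # otherwise None is passed here and the weights are written to .bin files
    # later, so the runtime can feed them as extra layer inputs.
    return {
        "input_layernorm_weights": [layer_norm_0] if layernorm_const else None,
        "post_attn_layernorm_weights": [layer_norm_1] if layernorm_const else None,
    }


# The flag itself is read from the environment in convert_llm (default "1").
layernorm_const = os.environ.get("IPEX_LLM_LAYERNORM_CONST", "1") == "1"
print(layernorm_weight_args(torch.ones(8), torch.ones(8), layernorm_const))
```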
```diff
@@ -70,7 +70,8 @@ def convert_lm_head_and_embedding(model, n_splits_linear, temp_dir, weight_dir):
 
 
 def convert_baichuan_layer(model, layer_idx, n_splits_linear, n_splits_down_proj,
-                           temp_dir, weight_dir, transpose_value_cache, kv_len, group_size):
+                           temp_dir, weight_dir, transpose_value_cache, kv_len, group_size,
+                           layernorm_const):
     num_heads = model.model.layers[0].self_attn.num_heads
     head_dim = model.model.layers[0].self_attn.head_dim
     intermediate_size = model.config.intermediate_size
@@ -106,8 +107,8 @@ def convert_baichuan_layer(model, layer_idx, n_splits_linear, n_splits_down_proj
 
     single_decoder = LowBitBaichuanMultiDecoderlayer(
         [1, 1, num_heads * head_dim],
-        input_layernorm_weights=[layer_norm_0],
-        post_attn_layernorm_weights=[layer_norm_1],
+        input_layernorm_weights=[layer_norm_0] if layernorm_const else None,
+        post_attn_layernorm_weights=[layer_norm_1] if layernorm_const else None,
         cached_cos=cached_cos,
         cached_sin=cached_sin,
         num_heads=num_heads,
@@ -123,9 +124,17 @@ def convert_baichuan_layer(model, layer_idx, n_splits_linear, n_splits_down_proj
                                                         f"decoder_layer_{layer_idx}",
                                                         temp_dir)
 
+    if layernorm_const:
+        st_idx = 5
+    else:
+        input_lm_bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_3.bin")
+        post_lm_bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_4.bin")
+        layer_norm_0.data.numpy().tofile(input_lm_bin_file)
+        layer_norm_1.data.numpy().tofile(post_lm_bin_file)
+        st_idx = 7
     for idx, (weight, scale) in enumerate(weights):
-        bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_{5+idx*2}.bin")
+        bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_{st_idx+idx*2}.bin")
         weight.numpy().tofile(bin_file)
-        bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_{5+idx*2+1}.bin")
+        bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_{st_idx+idx*2+1}.bin")
         scale.numpy().tofile(bin_file)
     del single_decoder
```
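The last hunk above is the bookkeeping behind the new flag: when the layernorms are not constants they are written as per-layer inputs 3 and 4, which pushes the start of the quantized (weight, scale) pairs from index 5 to 7. The Llama and MiniCPM converters further down apply the identical scheme. A small sketch of the resulting file layout, using a hypothetical helper name:

```python
import os


def baichuan_style_weight_files(weight_dir: str, layer_idx: int,
                                num_pairs: int, layernorm_const: bool) -> list:
    # Hypothetical helper reproducing the index bookkeeping from the hunk:
    # the layernorm weights occupy inputs 3 and 4 only when they are runtime
    # inputs, and the (weight, scale) pairs then start at 7 instead of 5.
    files = []
    st_idx = 5
    if not layernorm_const:
        files.append(os.path.join(weight_dir, f"model_{layer_idx}_input_3.bin"))
        files.append(os.path.join(weight_dir, f"model_{layer_idx}_input_4.bin"))
        st_idx = 7
    for idx in range(num_pairs):
        files.append(os.path.join(weight_dir, f"model_{layer_idx}_input_{st_idx + idx * 2}.bin"))
        files.append(os.path.join(weight_dir, f"model_{layer_idx}_input_{st_idx + idx * 2 + 1}.bin"))
    return files


print(baichuan_style_weight_files("model_weights", 0, num_pairs=2, layernorm_const=False))
```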
```diff
@@ -189,6 +189,8 @@ def convert_llm(model: torch.nn.Module,
                 max_prompt_len: int,
                 transpose_value_cache: bool,
                 group_size: int):
+    # whether to set layernorm weight as const
+    layernorm_const = os.environ.get("IPEX_LLM_LAYERNORM_CONST", "1") == "1"
     if group_size == 0:
         n_splits_linear = 1
         n_splits_down_proj = 2 if model.config.intermediate_size == 18944 else 1
@@ -207,7 +209,8 @@ def convert_llm(model: torch.nn.Module,
             param_list = []
             for layer_idx in range(0, layer_num):
                 param_list.append((model, layer_idx, n_splits_linear, n_splits_down_proj,
-                                   temp_dir, weight_dir, transpose_value_cache, kv_len, group_size))
+                                   temp_dir, weight_dir, transpose_value_cache, kv_len, group_size,
+                                   layernorm_const))
             with Pool() as pool:
                 result = pool.starmap(convert_llama_layer, param_list)
 
@@ -230,7 +233,7 @@ def convert_llm(model: torch.nn.Module,
                 res = InitLLMPipeline("llama", kv_len, model.num_head, model.head_dim, layer_num,
                                       model.vocab_size, weight_dir, "model",
                                       first_blob_path, last_blob_path,
-                                      os.path.join(temp_dir, "decoder_layer"))
+                                      os.path.join(temp_dir, "decoder_layer"), layernorm_const)
             except:
                 invalidInputError(False,
                                   "False to InitLLMPipeline.")
@@ -246,7 +249,8 @@ def convert_llm(model: torch.nn.Module,
             param_list = []
             for layer_idx in range(0, layer_num):
                 param_list.append((model, layer_idx, n_splits_linear, n_splits_down_proj,
-                                  temp_dir, weight_dir, transpose_value_cache, kv_len, group_size))
+                                  temp_dir, weight_dir, transpose_value_cache, kv_len, group_size,
+                                  layernorm_const))
             with Pool() as pool:
                 result = pool.starmap(convert_baichuan_layer, param_list)
 
@@ -270,7 +274,7 @@ def convert_llm(model: torch.nn.Module,
                 res = InitLLMPipeline("baichuan", kv_len, model.num_head, model.head_dim, layer_num,
                                       model.vocab_size, weight_dir, "model",
                                       first_blob_path, last_blob_path,
-                                      os.path.join(temp_dir, "decoder_layer"))
+                                      os.path.join(temp_dir, "decoder_layer"), layernorm_const)
             except:
                 invalidInputError(False,
                                   "False to InitLLMPipeline.")
@@ -286,7 +290,8 @@ def convert_llm(model: torch.nn.Module,
             param_list = []
             for layer_idx in range(0, layer_num):
                 param_list.append((model, layer_idx, n_splits_linear, n_splits_down_proj,
-                                   temp_dir, weight_dir, transpose_value_cache, kv_len, group_size))
+                                   temp_dir, weight_dir, transpose_value_cache, kv_len, group_size,
+                                   layernorm_const))
             with Pool() as pool:
                 result = pool.starmap(convert_minicpm_layer, param_list)
 
@@ -309,11 +314,12 @@ def convert_llm(model: torch.nn.Module,
                 res = InitLLMPipeline("minicpm", kv_len, model.num_head, model.head_dim, layer_num,
                                       model.vocab_size, weight_dir, "model",
                                       first_blob_path, last_blob_path,
-                                      os.path.join(temp_dir, "decoder_layer"))
+                                      os.path.join(temp_dir, "decoder_layer"), layernorm_const)
             except:
                 invalidInputError(False,
                                   "False to InitLLMPipeline.")
     elif model.config.model_type == "qwen2":
+        layernorm_const = os.environ.get("IPEX_LLM_LAYERNORM_CONST", "0") == "1"
         with tempfile.TemporaryDirectory() as temp_dir:
             weight_dir = os.path.join(temp_dir, "model_weights")
             os.mkdir(weight_dir)
@@ -325,7 +331,8 @@ def convert_llm(model: torch.nn.Module,
             param_list = []
             for layer_idx in range(0, layer_num):
                 param_list.append((model, layer_idx, n_splits_linear, n_splits_down_proj,
-                                  temp_dir, weight_dir, transpose_value_cache, kv_len, group_size))
+                                  temp_dir, weight_dir, transpose_value_cache, kv_len, group_size,
+                                  layernorm_const))
             with Pool() as pool:
                 result = pool.starmap(convert_qwen_layer, param_list)
 
@@ -349,7 +356,7 @@ def convert_llm(model: torch.nn.Module,
                 res = InitLLMPipeline("qwen", kv_len, model.num_head, model.head_dim, layer_num,
                                       model.vocab_size, weight_dir, "model",
                                       first_blob_path, last_blob_path,
-                                      os.path.join(temp_dir, "decoder_layer"))
+                                      os.path.join(temp_dir, "decoder_layer"), layernorm_const)
             except:
                 invalidInputError(False,
                                   "False to InitLLMPipeline.")
```
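`convert_llm` reads the switch from the environment rather than taking a new argument, and the default differs by model type: Llama, Baichuan and MiniCPM default to const (`"1"`), while the `qwen2` branch re-reads the variable with a default of `"0"`, so Qwen keeps the layernorms as runtime inputs unless the user opts in. A short sketch of the two reads:

```python
import os

# Mirrors the two environment reads in convert_llm above: most model types
# treat the layernorm weights as constants unless the user opts out, while
# the qwen2 branch keeps them as runtime inputs unless the user opts in.
default_branch_const = os.environ.get("IPEX_LLM_LAYERNORM_CONST", "1") == "1"
qwen2_branch_const = os.environ.get("IPEX_LLM_LAYERNORM_CONST", "0") == "1"
print(default_branch_const, qwen2_branch_const)
```

Exporting `IPEX_LLM_LAYERNORM_CONST=1` before conversion therefore turns the const path on for Qwen2 as well, and `=0` turns it off for the other model types.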
```diff
@@ -85,7 +85,8 @@ def convert_lm_head_and_embedding(model, n_splits_linear, temp_dir, weight_dir):
 
 
 def convert_llama_layer(model, layer_idx, n_splits_linear, n_splits_down_proj,
-                        temp_dir, weight_dir, transpose_value_cache, kv_len, group_size):
+                        temp_dir, weight_dir, transpose_value_cache, kv_len, group_size,
+                        layernorm_const):
     num_heads = model.model.layers[0].self_attn.num_heads
     num_key_value_heads = model.model.layers[0].self_attn.num_key_value_heads
     head_dim = model.model.layers[0].self_attn.head_dim
@@ -146,8 +147,8 @@ def convert_llama_layer(model, layer_idx, n_splits_linear, n_splits_down_proj,
 
     single_decoder = LowBitLlamaMultiDecoderlayer(
         [1, 1, num_heads * head_dim],
-        input_layernorm_weights=[layer_norm_0],
-        post_attn_layernorm_weights=[layer_norm_1],
+        input_layernorm_weights=[layer_norm_0] if layernorm_const else None,
+        post_attn_layernorm_weights=[layer_norm_1] if layernorm_const else None,
         cached_cos=cached_cos,
         cached_sin=cached_sin,
         num_heads=num_heads,
@@ -167,9 +168,17 @@ def convert_llama_layer(model, layer_idx, n_splits_linear, n_splits_down_proj,
                                                         f"decoder_layer_{layer_idx}",
                                                         temp_dir)
 
+    if layernorm_const:
+        st_idx = 5
+    else:
+        input_lm_bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_3.bin")
+        post_lm_bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_4.bin")
+        layer_norm_0.data.numpy().tofile(input_lm_bin_file)
+        layer_norm_1.data.numpy().tofile(post_lm_bin_file)
+        st_idx = 7
     for idx, (weight, scale) in enumerate(weights):
-        bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_{5+idx*2}.bin")
+        bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_{st_idx+idx*2}.bin")
         weight.numpy().tofile(bin_file)
-        bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_{5+idx*2+1}.bin")
+        bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_{st_idx+idx*2+1}.bin")
         scale.numpy().tofile(bin_file)
     del single_decoder
```
```diff
@@ -197,7 +197,8 @@ def convert_lm_head_and_embedding(model, n_splits_linear, temp_dir, weight_dir):
 
 
 def convert_minicpm_layer(model, layer_idx, n_splits_linear, n_splits_down_proj,
-                          temp_dir, weight_dir, transpose_value_cache, kv_len, group_size):
+                          temp_dir, weight_dir, transpose_value_cache, kv_len, group_size,
+                          layernorm_const):
     num_heads = model.model.layers[0].self_attn.num_heads
     num_key_value_heads = model.model.layers[0].self_attn.num_key_value_heads
     head_dim = model.model.layers[0].self_attn.head_dim
@@ -238,8 +239,8 @@ def convert_minicpm_layer(model, layer_idx, n_splits_linear, n_splits_down_proj,
 
     single_decoder = LowBitMinicpmMultiDecoderlayer(
         [1, 1, num_heads * head_dim],
-        input_layernorm_weights=[layer_norm_0],
-        post_attn_layernorm_weights=[layer_norm_1],
+        input_layernorm_weights=[layer_norm_0] if layernorm_const else None,
+        post_attn_layernorm_weights=[layer_norm_1] if layernorm_const else None,
         cached_cos=cached_cos,
         cached_sin=cached_sin,
         num_heads=num_heads,
@@ -258,9 +259,17 @@ def convert_minicpm_layer(model, layer_idx, n_splits_linear, n_splits_down_proj,
                                                         f"decoder_layer_{layer_idx}",
                                                         temp_dir)
 
+    if layernorm_const:
+        st_idx = 5
+    else:
+        input_lm_bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_3.bin")
+        post_lm_bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_4.bin")
+        layer_norm_0.data.numpy().tofile(input_lm_bin_file)
+        layer_norm_1.data.numpy().tofile(post_lm_bin_file)
+        st_idx = 7
     for idx, (weight, scale) in enumerate(weights):
-        bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_{5+idx*2}.bin")
+        bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_{st_idx+idx*2}.bin")
         weight.numpy().tofile(bin_file)
-        bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_{5+idx*2+1}.bin")
+        bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_{st_idx+idx*2+1}.bin")
         scale.numpy().tofile(bin_file)
     del single_decoder
```
```diff
@@ -43,7 +43,8 @@ _, _lib_path = get_shared_lib_info("pipeline")
 # Load the library
 _lib = ctypes.cdll.LoadLibrary(_lib_path)
 
-_lib.InitLLMPipeline.argtypes = [ctypes.c_char_p] + [ctypes.c_int] * 5 + [ctypes.c_char_p] * 5
+_lib.InitLLMPipeline.argtypes = [ctypes.c_char_p] + [ctypes.c_int] * 5 + \
+    [ctypes.c_char_p] * 5 + [ctypes.c_bool]
 _lib.InitLLMPipeline.restype = ctypes.c_int
 
 _lib.generate_serve.argtypes = [ctypes.c_int] * 5 + [ctypes.c_bool] + [ctypes.c_int]
@@ -52,11 +53,13 @@ _lib.generate_serve.restype = ctypes.c_int
 
 def InitLLMPipeline(model_type: str, kv_len: int, num_head: int, head_dim: int, num_layers: int,
                     vocab_size: int, model_weight_dir: str, model_name: str,
-                    first_blob_name: str, last_blob_name: str, rest_blob_name: str):
+                    first_blob_name: str, last_blob_name: str, rest_blob_name: str,
+                    layernorm_const: bool):
     return _lib.InitLLMPipeline(model_type.encode('utf-8'), kv_len, num_head, head_dim, num_layers,
                                 vocab_size, model_weight_dir.encode('utf-8'),
                                 model_name.encode('utf-8'), first_blob_name.encode('utf-8'),
-                                last_blob_name.encode('utf-8'), rest_blob_name.encode('utf-8'))
+                                last_blob_name.encode('utf-8'), rest_blob_name.encode('utf-8'),
+                                layernorm_const)
 
 
 def generate_serve(kv_len: int, num_head: int, head_dim: int, num_layers: int,
```
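Because the native `InitLLMPipeline` now takes a trailing boolean, both the `argtypes` declaration and the Python wrapper grow by one parameter; a Python `bool` marshals directly to `ctypes.c_bool`, so `layernorm_const` can be passed straight through. A tiny self-contained check of the new argument list (no library load involved):

```python
import ctypes

# The argument list after this change: one char* (model type), five ints,
# five char* names/paths, and the trailing bool carrying layernorm_const.
argtypes = [ctypes.c_char_p] + [ctypes.c_int] * 5 + [ctypes.c_char_p] * 5 + [ctypes.c_bool]
print(len(argtypes))        # 12 parameters in total
print(ctypes.c_bool(True))  # Python bools convert directly to c_bool(True)
```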
```diff
@@ -86,7 +86,8 @@ def convert_lm_head_and_embedding(model, n_splits_linear, temp_dir, weight_dir):
 
 
 def convert_qwen_layer(model, layer_idx, n_splits_linear, n_splits_down_proj,
-                       temp_dir, weight_dir, transpose_value_cache, kv_len, group_size):
+                       temp_dir, weight_dir, transpose_value_cache, kv_len, group_size,
+                       layernorm_const):
     num_heads = model.model.layers[0].self_attn.num_heads
     num_key_value_heads = model.model.layers[0].self_attn.num_key_value_heads
     head_dim = model.model.layers[0].self_attn.head_dim
@@ -149,8 +150,8 @@ def convert_qwen_layer(model, layer_idx, n_splits_linear, n_splits_down_proj,
 
     single_decoder = LowBitQwenMultiDecoderlayer(
         [1, 1, num_heads * head_dim],
-        input_layernorm_weights=None,
-        post_attn_layernorm_weights=None,
+        input_layernorm_weights=[layer_norm_0] if layernorm_const else None,
+        post_attn_layernorm_weights=[layer_norm_1] if layernorm_const else None,
         q_biases=None,
         k_biases=None,
         v_biases=None,
@@ -174,21 +175,25 @@ def convert_qwen_layer(model, layer_idx, n_splits_linear, n_splits_down_proj,
                                                         temp_dir)
 
     # 0, 1, 2 are input_embed/attention_mask/position_id
-    input_lm_bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_3.bin")
-    post_lm_bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_4.bin")
-    layer_norm_0.data.numpy().tofile(input_lm_bin_file)
-    layer_norm_1.data.numpy().tofile(post_lm_bin_file)
-    q_bias_bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_5.bin")
-    k_bias_bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_6.bin")
-    v_bias_bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_7.bin")
+    if layernorm_const:
+        st_idx = 3
+    else:
+        input_lm_bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_3.bin")
+        post_lm_bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_4.bin")
+        layer_norm_0.data.numpy().tofile(input_lm_bin_file)
+        layer_norm_1.data.numpy().tofile(post_lm_bin_file)
+        st_idx = 5
+    q_bias_bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_{st_idx}.bin")
+    k_bias_bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_{st_idx+1}.bin")
+    v_bias_bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_{st_idx+2}.bin")
     q_bias.data.numpy().tofile(q_bias_bin_file)
     k_bias.data.numpy().tofile(k_bias_bin_file)
     v_bias.data.numpy().tofile(v_bias_bin_file)
     # 6, 7 are past k/v
     for idx, (weight, scale) in enumerate(weights):
-        bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_{10+idx*2}.bin")
+        bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_{st_idx+5+idx*2}.bin")
         weight.numpy().tofile(bin_file)
-        bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_{10+idx*2+1}.bin")
+        bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_{st_idx+5+idx*2+1}.bin")
         scale.numpy().tofile(bin_file)
 
     del single_decoder
```
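The Qwen layer also carries q/k/v bias inputs, so its indices shift differently: with const layernorm the biases sit at inputs 3-5 (`st_idx = 3`); otherwise the layernorms take 3-4 and the biases move to 5-7 (`st_idx = 5`). The quantized (weight, scale) pairs start at `st_idx + 5` in both cases, which appears to leave the two slots in between for the past key/value entries (the retained `# 6, 7 are past k/v` comment matches the const case). A small sketch of the implied layout, under that reading of the hunk:

```python
def qwen_input_layout(layernorm_const: bool) -> dict:
    # Hypothetical helper spelling out the runtime-input indices implied by the
    # Qwen hunk: optional layernorms at 3-4, then the q/k/v biases, then the
    # quantized (weight, scale) pairs from st_idx + 5 onward.
    st_idx = 3 if layernorm_const else 5
    layout = {
        "q_bias": st_idx,
        "k_bias": st_idx + 1,
        "v_bias": st_idx + 2,
        "first_weight": st_idx + 5,
        "first_scale": st_idx + 6,
    }
    if not layernorm_const:
        layout["input_layernorm"] = 3
        layout["post_attn_layernorm"] = 4
    return layout


print(qwen_input_layout(True))
print(qwen_input_layout(False))
```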