[NPU] support asym_int4 for llama (#12556)
* add llama-imatrix
* fix bugs in llama.py
* style fix
parent d127a8654c
commit fcb474820d

2 changed files with 124 additions and 32 deletions
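For context: asym_int4 (asymmetric int4) stores an unsigned 4-bit weight together with a per-group zero point in addition to the per-group scale, so dequantization is roughly scale * (q - zero) rather than the symmetric scale * q. A minimal NumPy sketch of the idea, purely illustrative and not taken from this patch (function names and the grouping are assumptions):

import numpy as np

def asym_int4_quantize(w, group_size=64):
    # Illustrative per-group asymmetric int4 quantization:
    # q lies in [0, 15]; the zero point maps the group minimum to 0.
    w = w.reshape(-1, group_size)
    w_min = w.min(axis=1, keepdims=True)
    w_max = w.max(axis=1, keepdims=True)
    scale = np.maximum((w_max - w_min) / 15.0, 1e-8)
    zero = np.round(-w_min / scale)
    q = np.clip(np.round(w / scale + zero), 0, 15).astype(np.uint8)
    return q, scale.astype(np.float32), zero.astype(np.float32)

def asym_int4_dequantize(q, scale, zero):
    # w ~= scale * (q - zero); symmetric int4 is the special case zero == 0,
    # which is why the symmetric paths below only carry (weight, scale) pairs.
    return scale * (q.astype(np.float32) - zero)

This is why the changes below thread an extra zero-point tensor through every place that previously handled only (weight, scale) pairs.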
				
			
@@ -72,6 +72,7 @@ class LowBitLlamaMultiDecoderlayer(LLMBaseNNFactory):
         group_size: int = 0,
         cos_len: int = 1,
         keep_position_ids=True,
+        asym: bool = False,
     ):
         super().__init__(max_seq_len=max_seq_len,
                          transpose_value=transpose_value,
@@ -80,7 +81,8 @@ class LowBitLlamaMultiDecoderlayer(LLMBaseNNFactory):
                          device=device,
                          n_splits_linear=n_splits_linear,
                          n_splits_down_proj=n_splits_down_proj,
-                         group_size=group_size)
+                         group_size=group_size,
+                         asym=asym)
         self.max_seq_len = max_seq_len
         self.intermediate_size = intermediate_size
         self.dtype = dtype
@@ -278,7 +280,8 @@ class FusedLlamaLowBitMultiDecoderlayer(torch.nn.Module):
         do_print: bool = False,
         n_splits_linear: int = 1,
         n_splits_down_proj: int = 1,
-        group_size: int = 0
+        group_size: int = 0,
+        asym: bool = False,
     ):
         super().__init__()
 
@@ -286,8 +289,10 @@ class FusedLlamaLowBitMultiDecoderlayer(torch.nn.Module):
 
         op_parameters = []
         for w in parameters:
-            if isinstance(w, tuple):  # from QuantizedLinear
+            if isinstance(w, tuple) and not asym:  # from QuantizedLinear
                 op_parameters.append((w[0].numpy(), w[1].numpy()))
+            elif isinstance(w, tuple) and asym:  # from QuantizedLinear
+                op_parameters.append((w[0].numpy(), w[1].numpy(), w[2].numpy()))
             elif w.dtype in [torch.int8, torch.uint8]:    # QuantizedLinear weight
                 op_parameters.append(w.numpy())
             elif isinstance(w, np.ndarray):     # scale
@@ -341,7 +346,8 @@ class FusedLlamaLowBitMultiDecoderlayer(torch.nn.Module):
                 dtype=np_dtype,
                 n_splits_linear=n_splits_linear,
                 n_splits_down_proj=n_splits_down_proj,
-                group_size=group_size
+                group_size=group_size,
+                asym=asym,
             )
             self.backend_decoders.append(decoder)
 
@@ -427,6 +433,7 @@ class FusedLlamaLowBitDecoderlayer(torch.nn.Module):
         n_splits_down_proj: int = 1,
         group_size: int = 0,
         cos_len: int = 1,
+        asym: bool = False,
     ):
         super().__init__()
         self.op_parameters = parameters
@@ -460,6 +467,7 @@ class FusedLlamaLowBitDecoderlayer(torch.nn.Module):
             n_splits_down_proj=n_splits_down_proj,
             group_size=group_size,
             cos_len=cos_len,
+            asym=asym,
         )
         self.layer_norm_0 = layer_norm_0
         self.layer_norm_1 = layer_norm_1
@@ -555,6 +563,7 @@ def run_decode(
     layer_indexs = range(layer_start, layer_end)
     n_splits_linear = len(model.model.layers[0].mlp.gate_proj_dq_list)
     n_splits_down_proj = len(model.model.layers[0].mlp.down_proj_dq_list)
+    asym = getattr(model.config, "asym", False)
     for layer_idx in layer_indexs:
         curr_layer = model.model.layers[layer_idx]
         attn_layer = curr_layer.self_attn
@@ -567,10 +576,17 @@ def run_decode(
                            mlp_layer.down_proj_dq_list]:
             l_weights = []
             scales = []
+            zeros = []
             for l in layer_list:
                 l_weights.append(l.weight)
                 scales.append(l.scale)
-            weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))
+                if l.zero is not None:
+                    zeros.append(l.zero)
+            if len(zeros):
+                weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0),
+                                torch.stack(zeros, axis=0)))
+            else:
+                weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))
 
         if hasattr(curr_layer.self_attn.rotary_emb, "cos_cached"):
            cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16)
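The same stacking pattern recurs in run_prefill and in the layer converters further down: each linear split in the *_dq_list contributes a weight, a scale and, for asym_int4, a zero point, stacked along a new leading dimension. A standalone sketch of that pattern (the helper name and the dummy split attributes are assumptions, not repo APIs):

import torch

def pack_split_weights(layer_list):
    # One tuple per fused linear: (weight, scale) for symmetric int4,
    # (weight, scale, zero) when the splits carry a zero point (asym_int4).
    l_weights, scales, zeros = [], [], []
    for l in layer_list:
        l_weights.append(l.weight)
        scales.append(l.scale)
        if getattr(l, "zero", None) is not None:
            zeros.append(l.zero)
    if zeros:
        return (torch.stack(l_weights, dim=0), torch.stack(scales, dim=0),
                torch.stack(zeros, dim=0))
    return torch.stack(l_weights, dim=0), torch.stack(scales, dim=0)

Downstream, FusedLlamaLowBitMultiDecoderlayer tells the two cases apart by tuple length when it builds op_parameters (see the -286,8 +289,10 hunk above).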
@@ -603,7 +619,8 @@ def run_decode(
         do_print=False,
         n_splits_linear=n_splits_linear,
         n_splits_down_proj=n_splits_down_proj,
-        group_size=group_size
+        group_size=group_size,
+        asym=asym,
     )
 
     dist.barrier()
@@ -814,6 +831,7 @@ def run_prefill(
             layer_indexs = range(layer_start, layer_end)
             n_splits_linear = len(model.model.layers[0].mlp.gate_proj_dq_list)
             n_splits_down_proj = len(model.model.layers[0].mlp.down_proj_dq_list)
+            asym = getattr(model.config, "asym", False)
             for layer_idx in layer_indexs:
                 curr_layer = model.model.layers[layer_idx]
                 attn_layer = curr_layer.self_attn
@@ -827,10 +845,18 @@ def run_prefill(
                                    mlp_layer.down_proj_dq_list]:
                     l_weights = []
                     scales = []
+                    zeros = []
                     for l in layer_list:
                         l_weights.append(l.weight)
                         scales.append(l.scale)
-                    weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))
+                        if l.zero is not None:
+                            zeros.append(l.zero)
+                    if len(zeros):
+                        weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0),
+                                        torch.stack(zeros, axis=0)))
+                    else:
+                        weights.append((torch.stack(l_weights, axis=0),
+                                        torch.stack(scales, axis=0)))
 
                 if hasattr(curr_layer.self_attn.rotary_emb, "cos_cached"):
                     cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16)
@@ -859,6 +885,7 @@ def run_prefill(
                     n_splits_down_proj=n_splits_down_proj,
                     group_size=group_size,
                     cos_len=cos_len,
+                    asym=asym,
                 )
 
                 layer_weights.extend(weights)

@@ -130,17 +130,31 @@ def convert_lm_head_and_embedding(model, n_splits_linear, temp_dir, weight_dir,
     vocab_size = model.config.vocab_size
     model_norm = model.model.norm
     lm_head = model.lm_head
+    asym = getattr(model.config, "asym", False)
     if n_splits_linear == 1:
-        weights = [(lm_head.weight, lm_head.scale)]
+        asym = lm_head.qtype == "asym_int4_rtn"
+        if asym:
+            weights = [(lm_head.weight, lm_head.scale, lm_head.zero)]
+        else:
+            weights = [(lm_head.weight, lm_head.scale)]
     else:
         lm_heads = lm_head.lm_heads
+        asym = lm_heads[0].qtype == "asym_int4_rtn"
         lm_head_weights = []
         scales = []
-        for i in range(n_splits_linear):
-            lm_head_weights.append(lm_heads[i].weight)
-            scales.append(lm_heads[i].scale)
-        weights = [(torch.stack(lm_head_weights, axis=0),
-                    torch.stack(scales, axis=0))]
+        zeros = []
+        for l in lm_heads:
+            lm_head_weights.append(l.weight)
+            scales.append(l.scale)
+            if l.zero is not None:
+                zeros.append(l.zero)
+        if len(zeros):
+            weights = [(torch.stack(lm_head_weights, axis=0),
+                        torch.stack(scales, axis=0),
+                        torch.stack(zeros, axis=0))]
+        else:
+            weights = [(torch.stack(lm_head_weights, axis=0),
+                        torch.stack(scales, axis=0))]
     if isinstance(weights[0], tuple):
         np_dtype = np.int8 if weights[0][0].dtype == torch.int8 else np.uint8
     else:  # FP16 Linear
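Note that asym is derived two ways: the decode/prefill workers and the layer converters read it from model.config via getattr, while the lm_head path above re-derives it from the quantization type of the (possibly split) lm_head. A small sketch of that check, with the attribute layout assumed from the hunk above:

def lm_head_is_asym(lm_head, n_splits_linear):
    # "asym_int4_rtn" is the qtype string this patch checks for; a split
    # lm_head wraps its per-split heads in lm_head.lm_heads.
    if n_splits_linear == 1:
        return lm_head.qtype == "asym_int4_rtn"
    return lm_head.lm_heads[0].qtype == "asym_int4_rtn"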
@@ -156,16 +170,23 @@ def convert_lm_head_and_embedding(model, n_splits_linear, temp_dir, weight_dir,
         dtype=np_dtype,
         model_norm_weight=model_norm.weight.to(torch.float16),
         vocab_size=vocab_size,
-        n_splits=n_splits_linear
+        n_splits=n_splits_linear,
+        asym=asym
     )
     last_blob_path = update_names_of_IR_and_export_blob(new_lm_head, "lm_head", temp_dir,
                                                         True, False)
 
     # save weights bins files
     if n_splits_linear == 1:
-        weight_numpy = [
-            lm_head.weight.data.numpy(), lm_head.scale.data.numpy(),
-        ]
+        if not asym:
+            weight_numpy = [
+                lm_head.weight.data.numpy(), lm_head.scale.data.numpy(),
+            ]
+        else:
+            weight_numpy = [
+                lm_head.weight.data.numpy(), lm_head.scale.data.numpy(),
+                lm_head.zero.data.numpy()
+            ]
     else:
         weight_numpy = [v.numpy() for v in weights[0]]
 
@@ -234,6 +255,7 @@ def convert_llama_layer(model, layer_idx, n_splits_linear, n_splits_down_proj,
     head_dim = model.model.layers[0].self_attn.head_dim
     intermediate_size = model.config.intermediate_size
     rms_norm_eps = model.config.rms_norm_eps
+    asym = getattr(model.config, "asym", False)
 
     from ipex_llm.transformers.npu_models.llama_mp import LowBitLlamaMultiDecoderlayer
     curr_layer = model.model.layers[layer_idx]
@@ -247,10 +269,17 @@ def convert_llama_layer(model, layer_idx, n_splits_linear, n_splits_down_proj,
                        mlp_layer.down_proj_dq_list]:
         l_weights = []
         scales = []
+        zeros = []
         for l in layer_list:
             l_weights.append(l.weight)
             scales.append(l.scale)
-        weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))
+            if l.zero is not None:
+                zeros.append(l.zero)
+        if len(zeros):
+            weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0),
+                            torch.stack(zeros, axis=0)))
+        else:
+            weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))
 
     if hasattr(curr_layer.self_attn.rotary_emb, "cos_cached"):
         # llama-2-7B & llama-3-8B
@@ -299,7 +328,8 @@ def convert_llama_layer(model, layer_idx, n_splits_linear, n_splits_down_proj,
         n_splits_down_proj=n_splits_down_proj,
         group_size=group_size,
         cos_len=input_len,
-        keep_position_ids=keep_position_ids
+        keep_position_ids=keep_position_ids,
+        asym=asym
     )
 
     rest_blob_path = update_names_of_IR_and_export_blob(single_decoder,
@@ -329,11 +359,24 @@ def convert_llama_layer(model, layer_idx, n_splits_linear, n_splits_down_proj,
                 layer_norm_0.data.numpy().tofile(input_lm_bin_file)
                 layer_norm_1.data.numpy().tofile(post_lm_bin_file)
                 st_idx = 8
-        for idx, (weight, scale) in enumerate(weights):
-            bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_{st_idx+idx*2}.bin")
-            weight.numpy().tofile(bin_file)
-            bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_{st_idx+idx*2+1}.bin")
-            scale.numpy().tofile(bin_file)
+        if not asym:
+            for idx, (weight, scale) in enumerate(weights):
+                bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_{st_idx+idx*2}.bin")
+                weight.numpy().tofile(bin_file)
+                bin_file = os.path.join(weight_dir,
+                                        f"model_{layer_idx}_input_{st_idx+idx*2+1}.bin")
+                scale.numpy().tofile(bin_file)
+        else:
+            for idx, (weight, scale, zero) in enumerate(weights):
+                bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_{st_idx+idx*3}.bin")
+                weight.numpy().tofile(bin_file)
+                bin_file = os.path.join(weight_dir,
+                                        f"model_{layer_idx}_input_{st_idx+idx*3+1}.bin")
+                scale.numpy().tofile(bin_file)
+                bin_file = os.path.join(weight_dir,
+                                        f"model_{layer_idx}_input_{st_idx+idx*3+2}.bin")
+                zero.numpy().tofile(bin_file)
 
         del single_decoder
			
		||||
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
@@ -347,6 +390,7 @@ def convert_fused_llama_layer(model, fused_layers, n_splits_linear, n_splits_dow
     rms_norm_eps = model.config.rms_norm_eps
     layer_num = len(model.model.layers)
     fused_layer_num = layer_num // fused_layers
+    asym = getattr(model.config, "asym", False)
 
     from ipex_llm.transformers.npu_models.llama_mp import LowBitLlamaMultiDecoderlayer
     for i in range(fused_layers):
@@ -370,10 +414,17 @@ def convert_fused_llama_layer(model, fused_layers, n_splits_linear, n_splits_dow
                                mlp_layer.down_proj_dq_list]:
                 l_weights = []
                 scales = []
+                zeros = []
                 for l in layer_list:
                     l_weights.append(l.weight)
                     scales.append(l.scale)
-                weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))
+                    if l.zero is not None:
+                        zeros.append(l.zero)
+                if len(zeros):
+                    weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0),
+                                    torch.stack(zeros, axis=0)))
+                else:
+                    weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))
 
             if hasattr(curr_layer.self_attn.rotary_emb, "cos_cached"):
                 # llama-2-7B & llama-3-8B
@@ -397,12 +448,25 @@ def convert_fused_llama_layer(model, fused_layers, n_splits_linear, n_splits_dow
             layer_norm_1.data.numpy().tofile(post_lm_bin_file)
             st_idx = 5
             # 6, 7 are past k/v
-            for idx, (weight, scale) in enumerate(weights):
-                bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_{st_idx+idx*2}.bin")
-                weight.numpy().tofile(bin_file)
-                bin_file = os.path.join(weight_dir,
-                                        f"model_{layer_idx}_input_{st_idx+idx*2+1}.bin")
-                scale.numpy().tofile(bin_file)
+            if not asym:
+                for idx, (weight, scale) in enumerate(weights):
+                    bin_file = os.path.join(weight_dir,
+                                            f"model_{layer_idx}_input_{st_idx+idx*2}.bin")
+                    weight.numpy().tofile(bin_file)
+                    bin_file = os.path.join(weight_dir,
+                                            f"model_{layer_idx}_input_{st_idx+idx*2+1}.bin")
+                    scale.numpy().tofile(bin_file)
+            else:
+                for idx, (weight, scale, zero) in enumerate(weights):
+                    bin_file = os.path.join(weight_dir,
+                                            f"model_{layer_idx}_input_{st_idx+idx*3}.bin")
+                    weight.numpy().tofile(bin_file)
+                    bin_file = os.path.join(weight_dir,
+                                            f"model_{layer_idx}_input_{st_idx+idx*3+1}.bin")
+                    scale.numpy().tofile(bin_file)
+                    bin_file = os.path.join(weight_dir,
+                                            f"model_{layer_idx}_input_{st_idx+idx*3+2}.bin")
+                    zero.numpy().tofile(bin_file)
 
         if isinstance(weights[0], tuple):
             np_dtype = np.int8 if weights[0][0].dtype == torch.int8 else np.uint8
@@ -426,7 +490,8 @@ def convert_fused_llama_layer(model, fused_layers, n_splits_linear, n_splits_dow
             dtype=np_dtype,
             n_splits_linear=n_splits_linear,
             n_splits_down_proj=n_splits_down_proj,
-            group_size=group_size
+            group_size=group_size,
+            asym=asym
         )
         update_names_of_IR_and_export_blob(fused_decoder,
                                            f"decoder_layer_{i}",