Groupwise prefill optimization (#12291)
* except lm_head
* remove
* support gw lm_head
* update
* fix
* remove run.bat
* fix style
* support llama3
* slice -> split
* remove debug
* fix style
* add dpu
parent 540eaeb12c
commit 70037ad55f

3 changed files with 105 additions and 147 deletions
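The "slice -> split" item above is the core of the change: during prefill, the graph no longer slices the hidden dimension and issues one linear per split, but goes through the same grouped dq_split_linear op used at decode time (now carrying an is_prefill flag). The refactor is safe because splitting a matmul along its contraction dimension and summing the partial products reproduces the full matmul. A minimal sketch in plain PyTorch (toy shapes only, no NPU or IPEX-LLM code involved) checking that identity:

import torch

torch.manual_seed(0)
seq_len, hidden, out_dim, n_splits = 8, 16, 32, 4
group = hidden // n_splits

x = torch.randn(1, seq_len, hidden)
w = torch.randn(out_dim, hidden)            # one full projection weight

# Old prefill path (conceptually): slice the hidden dim, one linear per slice, sum.
y_sliced = sum(
    x[:, :, i * group:(i + 1) * group] @ w[:, i * group:(i + 1) * group].T
    for i in range(n_splits)
)

# New path (conceptually): per-split weights stacked, one grouped op over all splits.
w_stacked = torch.stack([w[:, i * group:(i + 1) * group] for i in range(n_splits)], dim=0)
x_splits = x.view(1, seq_len, n_splits, group)
y_grouped = torch.einsum('bsng,nog->bso', x_splits, w_stacked)

y_full = x @ w.T
assert torch.allclose(y_sliced, y_full, atol=1e-5)
assert torch.allclose(y_grouped, y_full, atol=1e-5)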
@@ -188,7 +188,10 @@ class LowBitLlamaMultiDecoderlayer(LLMBaseNNFactory):
             new_value_states = self.convert_to_fp16(curr_key_values[i][1])
 
         print("start compiling")
-        self.compile()
+        if mode == "prefill":
+            self.compile(npu_dpu_groups=6)
+        else:
+            self.compile()
 
     def build_decoder(
         self,
@@ -753,19 +756,40 @@ def run_prefill(
 
         weights = []
 
-        for q, k, v in zip(attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list,
-                           attn_layer.v_proj_dq_list):
-            weights.append((q.weight, q.scale))
-            weights.append((k.weight, k.scale))
-            weights.append((v.weight, v.scale))
-        for l in attn_layer.o_proj_dq_list:
-            weights.append((l.weight, l.scale))
-        for g, u in zip(mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list):
-            weights.append((g.weight, g.scale))
-            weights.append((u.weight, u.scale))
-        for l in mlp_layer.down_proj_dq_list:
-            weights.append((l.weight, l.scale))
+        if n_splits_linear == 1:
+            for q, k, v, o, g, u in zip(attn_layer.q_proj_dq_list,
+                                        attn_layer.k_proj_dq_list,
+                                        attn_layer.v_proj_dq_list,
+                                        attn_layer.o_proj_dq_list,
+                                        mlp_layer.gate_proj_dq_list,
+                                        mlp_layer.up_proj_dq_list):
+                weights.append((q.weight, q.scale))
+                weights.append((k.weight, k.scale))
+                weights.append((v.weight, v.scale))
+                weights.append((o.weight, o.scale))
+                weights.append((g.weight, g.scale))
+                weights.append((u.weight, u.scale))
+        else:
+            for layer_list in [attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list,
+                               attn_layer.v_proj_dq_list, attn_layer.o_proj_dq_list,
+                               mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list]:
+                l_weights = []
+                scales = []
+                for l in layer_list:
+                    l_weights.append(l.weight)
+                    scales.append(l.scale)
+                weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))
+
+        if n_splits_down_proj == 1:
+            for l in mlp_layer.down_proj_dq_list:
+                weights.append((l.weight, l.scale))
+        else:
+            l_weights = []
+            scales = []
+            for l in mlp_layer.down_proj_dq_list:
+                l_weights.append(l.weight)
+                scales.append(l.scale)
+            weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))
 
         cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16)
         cached_sin = curr_layer.self_attn.rotary_emb.sin_cached.to(torch.float16)
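With the grouped op in place, run_prefill packs the per-split quantized weights before handing them to the runtime: when a projection is split, the weights and scales of all splits are stacked into a single tensor pair rather than appended one by one. A toy sketch of that packing (DummyDQLinear is a hypothetical stand-in for one split of a dequant linear layer, not the real IPEX-LLM class):

import torch

class DummyDQLinear:
    # Hypothetical stand-in: one split of a quantized linear, holding int8 weight + fp16 scale.
    def __init__(self, out_features, in_features):
        self.weight = torch.randint(-8, 8, (out_features, in_features), dtype=torch.int8)
        self.scale = torch.rand(out_features, dtype=torch.float16)

n_splits_linear, out_f, in_f = 2, 8, 4
q_proj_dq_list = [DummyDQLinear(out_f, in_f) for _ in range(n_splits_linear)]

weights = []
if n_splits_linear == 1:
    q = q_proj_dq_list[0]
    weights.append((q.weight, q.scale))                                   # one pair per projection
else:
    stacked_w = torch.stack([p.weight for p in q_proj_dq_list], dim=0)    # (n_splits, out, in)
    stacked_s = torch.stack([p.scale for p in q_proj_dq_list], dim=0)     # (n_splits, out)
    weights.append((stacked_w, stacked_s))

print(weights[0][0].shape, weights[0][1].shape)  # torch.Size([2, 8, 4]) torch.Size([2, 8])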
@@ -165,60 +165,21 @@ class LLMBaseNNFactory(NNFactory):
             )
         else:
             hidden_states = self.unsqueeze(hidden_states, axis=0)
-            if mode == "prefill":
-                query_states_to_concat = []
-                key_states_to_concat = []
-                value_states_to_concat = []
-                for i in range(self.n_splits_linear):
-                    sub_hidden_states = self.slice(hidden_states,
-                                                   begin=[0, 0, i * groupsize],
-                                                   end=[1, seq_len, (i + 1) * groupsize])
-                    query_states_to_concat.append(
-                        self.linear(
-                            sub_hidden_states,
-                            num_heads * head_dim,
-                            groupsize,
-                            bias=False,
-                            wt_dtype=self.dtype,
-                            scale_factor=(self.group_size == 0)
-                        )
-                    )
-                    key_states_to_concat.append(
-                        self.linear(
-                            sub_hidden_states,
-                            num_key_value_heads * head_dim,
-                            groupsize,
-                            bias=False,
-                            wt_dtype=self.dtype,
-                            scale_factor=(self.group_size == 0)
-                        )
-                    )
-                    value_states_to_concat.append(
-                        self.linear(
-                            sub_hidden_states,
-                            num_key_value_heads * head_dim,
-                            groupsize,
-                            bias=False,
-                            wt_dtype=self.dtype,
-                            scale_factor=(self.group_size == 0)
-                        )
-                    )
-                query_states = sum(query_states_to_concat)
-                key_states = sum(key_states_to_concat)
-                value_states = sum(value_states_to_concat)
-            else:
-                query_states = self.dq_split_linear(hidden_states, num_heads * head_dim,
-                                                    hidden_size, self.n_splits_linear,
-                                                    wt_dtype=self.dtype,
-                                                    scale_factor=(self.group_size == 0))
-                key_states = self.dq_split_linear(hidden_states, num_key_value_heads * head_dim,
-                                                  hidden_size, self.n_splits_linear,
-                                                  wt_dtype=self.dtype,
-                                                  scale_factor=(self.group_size == 0))
-                value_states = self.dq_split_linear(hidden_states, num_key_value_heads * head_dim,
-                                                    hidden_size, self.n_splits_linear,
-                                                    wt_dtype=self.dtype,
-                                                    scale_factor=(self.group_size == 0))
+            query_states = self.dq_split_linear(hidden_states, num_heads * head_dim,
+                                                hidden_size, self.n_splits_linear,
+                                                wt_dtype=self.dtype,
+                                                scale_factor=(self.group_size == 0),
+                                                is_prefill=(mode == "prefill"))
+            key_states = self.dq_split_linear(hidden_states, num_key_value_heads * head_dim,
+                                              hidden_size, self.n_splits_linear,
+                                              wt_dtype=self.dtype,
+                                              scale_factor=(self.group_size == 0),
+                                              is_prefill=(mode == "prefill"))
+            value_states = self.dq_split_linear(hidden_states, num_key_value_heads * head_dim,
+                                                hidden_size, self.n_splits_linear,
+                                                wt_dtype=self.dtype,
+                                                scale_factor=(self.group_size == 0),
+                                                is_prefill=(mode == "prefill"))
 
         if q_bias is not None:
             query_states = query_states + q_bias
@@ -296,23 +257,10 @@ class LLMBaseNNFactory(NNFactory):
                 attn_output, hidden_size, hidden_size, bias=False, wt_dtype=self.dtype
             )
         else:
-            if mode == "prefill":
-                attn_output_to_concat = []
-                for i in range(self.n_splits_linear):
-                    sub_attn_output = self.slice(attn_output,
-                                                 begin=[0, 0, i * groupsize],
-                                                 end=[1, seq_len, (i + 1) * groupsize])
-                    attn_output_to_concat.append(
-                        self.linear(
-                            sub_attn_output, hidden_size, groupsize, bias=False,
-                            wt_dtype=self.dtype, scale_factor=(self.group_size == 0)
-                        )
-                    )
-                attn_output = sum(attn_output_to_concat)
-            else:
-                attn_output = self.dq_split_linear(attn_output, hidden_size, hidden_size,
-                                                   self.n_splits_linear, wt_dtype=self.dtype,
-                                                   scale_factor=(self.group_size == 0))
+            attn_output = self.dq_split_linear(attn_output, hidden_size, hidden_size,
+                                               self.n_splits_linear, wt_dtype=self.dtype,
+                                               scale_factor=(self.group_size == 0),
+                                               is_prefill=(mode == "prefill"))
 
         return attn_output, new_key_states, new_value_states
@@ -488,37 +436,14 @@ class LLMBaseNNFactory(NNFactory):
             mm1 = self.eltwise_mul(self.swish(mm1), mm2)  # type: ignore[attr-defined]
         else:
             invalidInputError(seq_len > 0, "seq_len should be provided if use split linear")
-            if mode == "prefill":
-                gate_up_groupsize = self.hidden_size // self.n_splits_linear
-                mm1_to_concat = []
-                mm2_to_concat = []
-                for i in range(self.n_splits_linear):
-                    sub_hidden_states = self.slice(hidden_states,
-                                                   begin=[0, 0, i * gate_up_groupsize],
-                                                   end=[1, seq_len, (i + 1) * gate_up_groupsize])
-                    mm1_to_concat.append(
-                        self.linear(
-                            sub_hidden_states, self.intermediate_size, gate_up_groupsize,
-                            bias=False,
-                            wt_dtype=self.dtype, scale_factor=(self.group_size == 0)
-                        )
-                    )
-                    mm2_to_concat.append(
-                        self.linear(
-                            sub_hidden_states, self.intermediate_size, gate_up_groupsize,
-                            bias=False,
-                            wt_dtype=self.dtype, scale_factor=(self.group_size == 0)
-                        )
-                    )
-                mm1 = sum(mm1_to_concat)
-                mm2 = sum(mm2_to_concat)
-            else:
-                mm1 = self.dq_split_linear(hidden_states, self.intermediate_size, self.hidden_size,
-                                           self.n_splits_linear, wt_dtype=self.dtype,
-                                           scale_factor=(self.group_size == 0))
-                mm2 = self.dq_split_linear(hidden_states, self.intermediate_size, self.hidden_size,
-                                           self.n_splits_linear, wt_dtype=self.dtype,
-                                           scale_factor=(self.group_size == 0))
+            mm1 = self.dq_split_linear(hidden_states, self.intermediate_size, self.hidden_size,
+                                       self.n_splits_linear, wt_dtype=self.dtype,
+                                       scale_factor=(self.group_size == 0),
+                                       is_prefill=(mode == "prefill"))
+            mm2 = self.dq_split_linear(hidden_states, self.intermediate_size, self.hidden_size,
+                                       self.n_splits_linear, wt_dtype=self.dtype,
+                                       scale_factor=(self.group_size == 0),
+                                       is_prefill=(mode == "prefill"))
             mm1 = self.eltwise_mul(self.swish(mm1), mm2)  # type: ignore[attr-defined]
 
         if self.n_splits_down_proj == 1:
@@ -527,23 +452,10 @@ class LLMBaseNNFactory(NNFactory):
             )
         else:
             invalidInputError(seq_len > 0, "seq_len should be provided if use split linear")
-            if mode == "prefill":
-                down_groupsize = self.intermediate_size // self.n_splits_down_proj
-                hidden_states_to_concat = []
-                for i in range(self.n_splits_down_proj):
-                    sub_mm1 = self.slice(mm1, begin=[0, 0, i * down_groupsize],
-                                         end=[1, seq_len, (i + 1) * down_groupsize])
-                    hidden_states_to_concat.append(
-                        self.linear(
-                            sub_mm1, self.hidden_size, down_groupsize, bias=False,
-                            wt_dtype=self.dtype, scale_factor=(self.group_size == 0)
-                        )
-                    )
-                hidden_states = sum(hidden_states_to_concat)
-            else:
-                hidden_states = self.dq_split_linear(mm1, self.hidden_size, self.intermediate_size,
-                                                     self.n_splits_down_proj, wt_dtype=self.dtype,
-                                                     scale_factor=(self.group_size == 0))
+            hidden_states = self.dq_split_linear(mm1, self.hidden_size, self.intermediate_size,
+                                                 self.n_splits_down_proj, wt_dtype=self.dtype,
+                                                 scale_factor=(self.group_size == 0),
+                                                 is_prefill=(mode == "prefill"))
         return hidden_states
 
     def layer_norm(self, hidden_states, layernorm_weight):
@@ -660,9 +572,11 @@ class LLMBaseNNFactory(NNFactory):
                         n_splits: int,
                         act_dtype: npt.DTypeLike = np.float16,
                         wt_dtype: npt.DTypeLike = np.float16,
-                        scale_factor: bool = False):
+                        scale_factor: bool = False,
+                        is_prefill: bool = False):
         op = super().dq_split_linear(input_node, n_splits, output_channels, input_channels,
-                                     False, act_dtype, wt_dtype, scale_factor)
+                                     False, act_dtype, wt_dtype, scale_factor,
+                                     is_prefill=is_prefill)
         self.linear_ops.append(op)
         return op
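The wrapper change above only threads the new flag through to the parent op, so call sites can pass is_prefill=(mode == "prefill") instead of branching on mode themselves. A minimal sketch of the same pass-through pattern (BaseFactory and LLMFactory are illustrative stand-ins, not the real NNFactory API):

class BaseFactory:
    # Illustrative stand-in for the parent graph factory; the real op builds an NPU graph node.
    def dq_split_linear(self, input_node, n_splits, output_channels, input_channels,
                        bias, act_dtype, wt_dtype, scale_factor, is_prefill=False):
        return {"n_splits": n_splits, "is_prefill": is_prefill}

class LLMFactory(BaseFactory):
    def __init__(self):
        self.linear_ops = []

    def dq_split_linear(self, input_node, output_channels, input_channels, n_splits,
                        act_dtype="float16", wt_dtype="float16",
                        scale_factor=False, is_prefill=False):
        # Forward the flag so the backend can pick a prefill-specific kernel (e.g. DPU groups).
        op = super().dq_split_linear(input_node, n_splits, output_channels, input_channels,
                                     False, act_dtype, wt_dtype, scale_factor,
                                     is_prefill=is_prefill)
        self.linear_ops.append(op)
        return op

mode = "prefill"
factory = LLMFactory()
print(factory.dq_split_linear(None, 4096, 4096, n_splits=2, is_prefill=(mode == "prefill")))
# {'n_splits': 2, 'is_prefill': True}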
@@ -827,20 +827,40 @@ def run_prefill(
         mlp_layer = curr_layer.mlp
 
         weights = []
-        for q, k, v in zip(attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list,
-                           attn_layer.v_proj_dq_list):
-            weights.append((q.weight, q.scale))
-            weights.append((k.weight, k.scale))
-            weights.append((v.weight, v.scale))
-
-        for l in attn_layer.o_proj_dq_list:
-            weights.append((l.weight, l.scale))
-        for g, u in zip(mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list):
-            weights.append((g.weight, g.scale))
-            weights.append((u.weight, u.scale))
-        for l in mlp_layer.down_proj_dq_list:
-            weights.append((l.weight, l.scale))
+        if n_splits_linear == 1:
+            for q, k, v, o, g, u in zip(attn_layer.q_proj_dq_list,
+                                        attn_layer.k_proj_dq_list,
+                                        attn_layer.v_proj_dq_list,
+                                        attn_layer.o_proj_dq_list,
+                                        mlp_layer.gate_proj_dq_list,
+                                        mlp_layer.up_proj_dq_list):
+                weights.append((q.weight, q.scale))
+                weights.append((k.weight, k.scale))
+                weights.append((v.weight, v.scale))
+                weights.append((o.weight, o.scale))
+                weights.append((g.weight, g.scale))
+                weights.append((u.weight, u.scale))
+        else:
+            for layer_list in [attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list,
+                               attn_layer.v_proj_dq_list, attn_layer.o_proj_dq_list,
+                               mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list]:
+                l_weights = []
+                scales = []
+                for l in layer_list:
+                    l_weights.append(l.weight)
+                    scales.append(l.scale)
+                weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))
+
+        if n_splits_down_proj == 1:
+            for l in mlp_layer.down_proj_dq_list:
+                weights.append((l.weight, l.scale))
+        else:
+            l_weights = []
+            scales = []
+            for l in mlp_layer.down_proj_dq_list:
+                l_weights.append(l.weight)
+                scales.append(l.scale)
+            weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))
 
         cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16)
         cached_sin = curr_layer.self_attn.rotary_emb.sin_cached.to(torch.float16)