[NPU] support asym_int4 for baichuan (#12576)
* add npu support for baichuan
* Update baichuan_mp.py
* Update baichuan_mp.py
This commit is contained in:

parent 098eb335b2
commit c410d9cf73

1 changed file with 42 additions and 13 deletions
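Asymmetric int4 ("asym_int4") differs from symmetric int4 in that each quantization group carries an explicit zero point alongside its scale, which is why the patch below threads an asym flag through the decoder-layer builders and packs weights as (weight, scale, zero) 3-tuples instead of (weight, scale) pairs. A minimal sketch of that distinction, using made-up tensors rather than anything from this file:

import torch

# Hypothetical 4-bit codes for one quantization group (stored as uint8).
q = torch.tensor([0, 3, 7, 15], dtype=torch.uint8).float()
scale = torch.tensor(0.05)

# Symmetric int4: the grid is centered around zero, so a scale alone suffices
# (here codes 0..15 are shifted by the fixed midpoint 8).
w_sym = scale * (q - 8)

# Asymmetric int4: the grid is shifted by a learned, per-group zero point,
# so that zero point has to be stored and carried through with the scale.
zero = torch.tensor(3.0)
w_asym = scale * (q - zero)

# This extra per-group tensor is the reason the patch packs (weight, scale)
# for symmetric layers but (weight, scale, zero) when asym=True.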
				
			
baichuan_mp.py

@@ -80,7 +80,8 @@ class LowBitBaichuanMultiDecoderlayer(LLMBaseNNFactory):
         intermediate_size,
         n_splits_linear: int = 1,
         n_splits_down_proj: int = 1,
-        group_size: int = 0
+        group_size: int = 0,
+        asym: bool = False,
     ):
         super().__init__(max_seq_len=max_seq_len,
                          transpose_value=transpose_value,
@@ -89,7 +90,8 @@ class LowBitBaichuanMultiDecoderlayer(LLMBaseNNFactory):
                          device=device,
                          n_splits_linear=n_splits_linear,
                          n_splits_down_proj=n_splits_down_proj,
-                         group_size=group_size)
+                         group_size=group_size,
+                         asym=asym)
         self.max_seq_len = max_seq_len
         self.intermediate_size = intermediate_size
         self.dtype = dtype
@@ -100,6 +102,7 @@ class LowBitBaichuanMultiDecoderlayer(LLMBaseNNFactory):
         self.rms_norm_eps = rms_norm_eps
         self.transpose_value = transpose_value
         self.num_layers = num_layers
+        self.asym = asym
 
         cos = self.constant(self.cached_cos)
         self.cos = self.unsqueeze(cos, axis=0)
@@ -232,7 +235,8 @@ class LowBitBaichuanMultiDecoderlayer(LLMBaseNNFactory):
             wt_dtype=self.dtype,
             n_splits=self.n_splits_linear,
             scale_factor=(self.group_size == 0),
-            is_prefill=(mode == "prefill")
+            is_prefill=(mode == "prefill"),
+            asym=self.asym
         )
 
         proj = self.reshape(proj, [-1, 3, hidden_size])  # b*s, 3, h
@@ -300,7 +304,8 @@ class LowBitBaichuanMultiDecoderlayer(LLMBaseNNFactory):
             attn_output, hidden_size, hidden_size, bias=False, wt_dtype=self.dtype,
             n_splits=self.n_splits_linear,
             scale_factor=(self.group_size == 0),
-            is_prefill=(mode == "prefill")
+            is_prefill=(mode == "prefill"),
+            asym=self.asym
         )
         return attn_output, new_key_states, new_value_states
 
@@ -368,7 +373,8 @@ class FusedBaichuanLowBitMultiDecoderlayer(torch.nn.Module):
         do_print: bool = False,
         n_splits_linear: int = 1,
         n_splits_down_proj: int = 1,
-        group_size: int = 0
+        group_size: int = 0,
+        asym: bool = False,
     ):
         super().__init__()
 
@@ -376,8 +382,10 @@ class FusedBaichuanLowBitMultiDecoderlayer(torch.nn.Module):
 
         op_parameters = []
         for w in parameters:
-            if isinstance(w, tuple):  # from QuantizedLinear
+            if isinstance(w, tuple) and not asym:  # from QuantizedLinear
                 op_parameters.append((w[0].numpy(), w[1].numpy()))
+            elif isinstance(w, tuple) and asym:  # from QuantizedLinear
+                op_parameters.append((w[0].numpy(), w[1].numpy(), w[2].numpy()))
             elif w.dtype in [torch.int8, torch.uint8]:    # QuantizedLinear weight
                 op_parameters.append(w.numpy())
             elif isinstance(w, np.ndarray):     # scale
@@ -430,7 +438,8 @@ class FusedBaichuanLowBitMultiDecoderlayer(torch.nn.Module):
                 dtype=np_dtype,
                 n_splits_linear=n_splits_linear,
                 n_splits_down_proj=n_splits_down_proj,
-                group_size=group_size
+                group_size=group_size,
+                asym=asym,
             )
             self.backend_decoders.append(decoder)
 
@@ -506,7 +515,8 @@ class FusedBaichuanLowBitDecoderlayer(torch.nn.Module):
         transpose_value: bool = False,
         n_splits_linear: int = 1,
         n_splits_down_proj: int = 1,
-        group_size: int = 0
+        group_size: int = 0,
+        asym: bool = False,
     ):
         super().__init__()
         self.op_parameters = parameters
@@ -537,7 +547,8 @@ class FusedBaichuanLowBitDecoderlayer(torch.nn.Module):
             dtype=np_dtype,
             n_splits_linear=n_splits_linear,
             n_splits_down_proj=n_splits_down_proj,
-            group_size=group_size
+            group_size=group_size,
+            asym=asym
         )
         self.layer_norm_0 = layer_norm_0
         self.layer_norm_1 = layer_norm_1
@@ -620,6 +631,7 @@ def run_decode(
     layer_indexs = range(layer_start, layer_end)
    n_splits_linear = len(model.model.layers[0].mlp.gate_proj_dq_list)
    n_splits_down_proj = len(model.model.layers[0].mlp.down_proj_dq_list)
+    asym = getattr(model.config, "asym", False)
     for layer_idx in layer_indexs:
         curr_layer = model.model.layers[layer_idx]
         attn_layer = curr_layer.self_attn
@@ -631,10 +643,17 @@ def run_decode(
                            mlp_layer.down_proj_dq_list]:
             l_weights = []
             scales = []
+            zeros = []
             for l in layer_list:
                 l_weights.append(l.weight)
                 scales.append(l.scale)
-            weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))
+                if l.zero is not None:
+                    zeros.append(l.zero)
+            if len(zeros):
+                weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0),
+                                torch.stack(zeros, axis=0)))
+            else:
+                weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))
 
         cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16)
         cached_sin = curr_layer.self_attn.rotary_emb.sin_cached.to(torch.float16)
@@ -663,7 +682,8 @@ def run_decode(
         do_print=False,
         n_splits_linear=n_splits_linear,
         n_splits_down_proj=n_splits_down_proj,
-        group_size=group_size
+        group_size=group_size,
+        asym=asym,
     )
 
     dist.barrier()
@@ -827,6 +847,7 @@ def run_prefill(
     layer_indexs = range(layer_start, layer_end)
     n_splits_linear = len(model.model.layers[0].mlp.gate_proj_dq_list)
     n_splits_down_proj = len(model.model.layers[0].mlp.down_proj_dq_list)
+    asym = getattr(model.config, "asym", False)
     for layer_idx in layer_indexs:
         curr_layer = model.model.layers[layer_idx]
         attn_layer = curr_layer.self_attn
@@ -838,10 +859,17 @@ def run_prefill(
                            mlp_layer.down_proj_dq_list]:
             l_weights = []
             scales = []
+            zeros = []
             for l in layer_list:
                 l_weights.append(l.weight)
                 scales.append(l.scale)
-            weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))
+                if l.zero is not None:
+                    zeros.append(l.zero)
+            if len(zeros):
+                weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0),
+                                torch.stack(zeros, axis=0)))
+            else:
+                weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))
 
         cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16)
         cached_sin = curr_layer.self_attn.rotary_emb.sin_cached.to(torch.float16)
@@ -864,7 +892,8 @@ def run_prefill(
             transpose_value=transpose_value_cache,
             n_splits_linear=n_splits_linear,
             n_splits_down_proj=n_splits_down_proj,
-            group_size=group_size
+            group_size=group_size,
+            asym=asym
         )
 
         layer_weights.extend(weights)
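The run_decode and run_prefill hunks above gather each projection's per-split weights, scales, and (when present) zero points before handing them to the fused decoder. A condensed, self-contained version of that packing step, with a stand-in class in place of the real dq_list layers (the names FakeSplitLinear and pack_splits are illustrative, not from this file):

from dataclasses import dataclass
from typing import Optional
import torch

@dataclass
class FakeSplitLinear:
    # Stand-in for one split of a dequantized low-bit linear layer.
    weight: torch.Tensor
    scale: torch.Tensor
    zero: Optional[torch.Tensor] = None   # present only for asym_int4

def pack_splits(layer_list):
    """Stack per-split tensors; return a 3-tuple when zero points exist (asym),
    otherwise the usual 2-tuple (symmetric)."""
    l_weights, scales, zeros = [], [], []
    for l in layer_list:
        l_weights.append(l.weight)
        scales.append(l.scale)
        if l.zero is not None:
            zeros.append(l.zero)
    if len(zeros):
        return (torch.stack(l_weights, dim=0), torch.stack(scales, dim=0),
                torch.stack(zeros, dim=0))
    return (torch.stack(l_weights, dim=0), torch.stack(scales, dim=0))

# Example: two splits of a hypothetical 4x8 asym-quantized projection.
splits = [FakeSplitLinear(torch.randint(0, 16, (4, 8), dtype=torch.uint8),
                          torch.rand(4), torch.rand(4)) for _ in range(2)]
packed = pack_splits(splits)
print(len(packed))   # 3 -> (weights, scales, zeros)

The 2-tuple versus 3-tuple shape is what the op_parameters branch in FusedBaichuanLowBitMultiDecoderlayer later distinguishes via isinstance(w, tuple) together with the asym flag.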