[NPU] support asym_int4 for baichuan (#12576)

* add npu support for baichuan
* Update baichuan_mp.py
* Update baichuan_mp.py

parent 098eb335b2
commit c410d9cf73

1 changed file (baichuan_mp.py) with 42 additions and 13 deletions
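For orientation before the diff: asym_int4 differs from sym_int4 in that each quantization group stores a zero point next to its scale, so a weight is reconstructed as (q - zero) * scale instead of q * scale. The toy sketch below only illustrates that idea; the helper names are invented and the per-row grouping is a simplification, not the ipex-llm NPU kernel.

import torch

def quantize_asym_int4(w: torch.Tensor):
    # Toy per-row asymmetric int4 quantization (illustrative only).
    w_min = w.min(dim=-1, keepdim=True).values
    w_max = w.max(dim=-1, keepdim=True).values
    scale = (w_max - w_min).clamp(min=1e-8) / 15.0      # unsigned 4-bit range 0..15
    zero = torch.round(-w_min / scale)                  # zero point maps w_min to 0
    q = torch.clamp(torch.round(w / scale + zero), 0, 15).to(torch.uint8)
    return q, scale, zero

def dequantize_asym_int4(q, scale, zero):
    return (q.to(torch.float32) - zero) * scale

w = torch.randn(4, 8)
q, scale, zero = quantize_asym_int4(w)
print((dequantize_asym_int4(q, scale, zero) - w).abs().max())   # small reconstruction error

The zero tensor is the extra piece of state that this commit threads through the Baichuan NPU decoder classes via the new asym flag.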
@@ -80,7 +80,8 @@ class LowBitBaichuanMultiDecoderlayer(LLMBaseNNFactory):
         intermediate_size,
         n_splits_linear: int = 1,
         n_splits_down_proj: int = 1,
-        group_size: int = 0
+        group_size: int = 0,
+        asym: bool = False,
     ):
         super().__init__(max_seq_len=max_seq_len,
                          transpose_value=transpose_value,
@@ -89,7 +90,8 @@ class LowBitBaichuanMultiDecoderlayer(LLMBaseNNFactory):
                          device=device,
                          n_splits_linear=n_splits_linear,
                          n_splits_down_proj=n_splits_down_proj,
-                         group_size=group_size)
+                         group_size=group_size,
+                         asym=asym)
         self.max_seq_len = max_seq_len
         self.intermediate_size = intermediate_size
         self.dtype = dtype
@@ -100,6 +102,7 @@ class LowBitBaichuanMultiDecoderlayer(LLMBaseNNFactory):
         self.rms_norm_eps = rms_norm_eps
         self.transpose_value = transpose_value
         self.num_layers = num_layers
+        self.asym = asym

         cos = self.constant(self.cached_cos)
         self.cos = self.unsqueeze(cos, axis=0)
@@ -232,7 +235,8 @@ class LowBitBaichuanMultiDecoderlayer(LLMBaseNNFactory):
             wt_dtype=self.dtype,
             n_splits=self.n_splits_linear,
             scale_factor=(self.group_size == 0),
-            is_prefill=(mode == "prefill")
+            is_prefill=(mode == "prefill"),
+            asym=self.asym
         )

         proj = self.reshape(proj, [-1, 3, hidden_size])  # b*s, 3, h
@@ -300,7 +304,8 @@ class LowBitBaichuanMultiDecoderlayer(LLMBaseNNFactory):
             attn_output, hidden_size, hidden_size, bias=False, wt_dtype=self.dtype,
             n_splits=self.n_splits_linear,
             scale_factor=(self.group_size == 0),
-            is_prefill=(mode == "prefill")
+            is_prefill=(mode == "prefill"),
+            asym=self.asym
         )
         return attn_output, new_key_states, new_value_states
@@ -368,7 +373,8 @@ class FusedBaichuanLowBitMultiDecoderlayer(torch.nn.Module):
         do_print: bool = False,
         n_splits_linear: int = 1,
         n_splits_down_proj: int = 1,
-        group_size: int = 0
+        group_size: int = 0,
+        asym: bool = False,
     ):
         super().__init__()
@@ -376,8 +382,10 @@ class FusedBaichuanLowBitMultiDecoderlayer(torch.nn.Module):

         op_parameters = []
         for w in parameters:
-            if isinstance(w, tuple):  # from QuantizedLinear
+            if isinstance(w, tuple) and not asym:  # from QuantizedLinear
                 op_parameters.append((w[0].numpy(), w[1].numpy()))
+            elif isinstance(w, tuple) and asym:  # from QuantizedLinear
+                op_parameters.append((w[0].numpy(), w[1].numpy(), w[2].numpy()))
             elif w.dtype in [torch.int8, torch.uint8]:  # QuantizedLinear weight
                 op_parameters.append(w.numpy())
             elif isinstance(w, np.ndarray):  # scale
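The hunk above is the core of the change: symmetric QuantizedLinear parameters arrive as (weight, scale) tuples, while asymmetric ones carry a third zero-point tensor that must also reach the NPU backend. A standalone sketch of the same branching, with a hypothetical pack_op_parameters helper name, could look like this (the trailing dtype cases of the real constructor are omitted):

import numpy as np
import torch

def pack_op_parameters(parameters, asym: bool):
    # Mirrors the constructor logic above: convert torch parameters into the
    # numpy arrays handed to the NPU backend, keeping the zero point when asym.
    op_parameters = []
    for w in parameters:
        if isinstance(w, tuple) and not asym:       # from QuantizedLinear, symmetric
            op_parameters.append((w[0].numpy(), w[1].numpy()))
        elif isinstance(w, tuple) and asym:         # from QuantizedLinear, asymmetric
            op_parameters.append((w[0].numpy(), w[1].numpy(), w[2].numpy()))
        elif w.dtype in [torch.int8, torch.uint8]:  # bare quantized weight
            op_parameters.append(w.numpy())
        elif isinstance(w, np.ndarray):             # scale already in numpy form
            op_parameters.append(w)
        # (remaining cases of the original constructor are not shown here)
    return op_parameters

The zero tensor only reaches the backend when asym is set, which is why the flag has to be threaded through every constructor in this file.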
@@ -430,7 +438,8 @@ class FusedBaichuanLowBitMultiDecoderlayer(torch.nn.Module):
                 dtype=np_dtype,
                 n_splits_linear=n_splits_linear,
                 n_splits_down_proj=n_splits_down_proj,
-                group_size=group_size
+                group_size=group_size,
+                asym=asym,
             )
             self.backend_decoders.append(decoder)
@@ -506,7 +515,8 @@ class FusedBaichuanLowBitDecoderlayer(torch.nn.Module):
         transpose_value: bool = False,
         n_splits_linear: int = 1,
         n_splits_down_proj: int = 1,
-        group_size: int = 0
+        group_size: int = 0,
+        asym: bool = False,
     ):
         super().__init__()
         self.op_parameters = parameters
@@ -537,7 +547,8 @@ class FusedBaichuanLowBitDecoderlayer(torch.nn.Module):
             dtype=np_dtype,
             n_splits_linear=n_splits_linear,
             n_splits_down_proj=n_splits_down_proj,
-            group_size=group_size
+            group_size=group_size,
+            asym=asym
         )
         self.layer_norm_0 = layer_norm_0
         self.layer_norm_1 = layer_norm_1
@@ -620,6 +631,7 @@ def run_decode(
     layer_indexs = range(layer_start, layer_end)
     n_splits_linear = len(model.model.layers[0].mlp.gate_proj_dq_list)
    n_splits_down_proj = len(model.model.layers[0].mlp.down_proj_dq_list)
+    asym = getattr(model.config, "asym", False)
     for layer_idx in layer_indexs:
         curr_layer = model.model.layers[layer_idx]
         attn_layer = curr_layer.self_attn
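Reading the flag with getattr(model.config, "asym", False) keeps older, symmetrically quantized checkpoints working: their configs simply have no asym attribute and fall back to False. A minimal illustration with stand-in config objects (not real ipex-llm configs):

from types import SimpleNamespace

sym_config = SimpleNamespace()            # pre-asym checkpoint: no "asym" field
asym_config = SimpleNamespace(asym=True)  # asym_int4 checkpoint

print(getattr(sym_config, "asym", False))   # False -> (weight, scale) path
print(getattr(asym_config, "asym", False))  # True  -> (weight, scale, zero) path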
@@ -631,9 +643,16 @@ def run_decode(
                            mlp_layer.down_proj_dq_list]:
             l_weights = []
             scales = []
+            zeros = []
             for l in layer_list:
                 l_weights.append(l.weight)
                 scales.append(l.scale)
-            weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))
+                if l.zero is not None:
+                    zeros.append(l.zero)
+            if len(zeros):
+                weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0),
+                                torch.stack(zeros, axis=0)))
+            else:
+                weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))

         cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16)
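For every projection, the weights and scales of the split sub-layers are stacked along a new leading axis, and zero points are stacked as well when the layers expose them; the backend then receives either a 2-tuple or a 3-tuple per projection, matching the constructor branch shown earlier. A hedged standalone sketch of that collection step (the helper name and the layer objects are hypothetical, but the stacking mirrors the diff):

import torch

def stack_split_layers(layer_list):
    # Each element is expected to expose .weight, .scale and, for asym_int4,
    # a non-None .zero; symmetric layers leave .zero as None.
    l_weights, scales, zeros = [], [], []
    for l in layer_list:
        l_weights.append(l.weight)
        scales.append(l.scale)
        if l.zero is not None:
            zeros.append(l.zero)
    if len(zeros):
        return (torch.stack(l_weights, axis=0),
                torch.stack(scales, axis=0),
                torch.stack(zeros, axis=0))
    return torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)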
@@ -663,7 +682,8 @@ def run_decode(
         do_print=False,
         n_splits_linear=n_splits_linear,
         n_splits_down_proj=n_splits_down_proj,
-        group_size=group_size
+        group_size=group_size,
+        asym=asym,
     )

     dist.barrier()
@@ -827,6 +847,7 @@ def run_prefill(
     layer_indexs = range(layer_start, layer_end)
     n_splits_linear = len(model.model.layers[0].mlp.gate_proj_dq_list)
     n_splits_down_proj = len(model.model.layers[0].mlp.down_proj_dq_list)
+    asym = getattr(model.config, "asym", False)
     for layer_idx in layer_indexs:
         curr_layer = model.model.layers[layer_idx]
         attn_layer = curr_layer.self_attn
@@ -838,9 +859,16 @@ def run_prefill(
                            mlp_layer.down_proj_dq_list]:
             l_weights = []
             scales = []
+            zeros = []
             for l in layer_list:
                 l_weights.append(l.weight)
                 scales.append(l.scale)
-            weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))
+                if l.zero is not None:
+                    zeros.append(l.zero)
+            if len(zeros):
+                weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0),
+                                torch.stack(zeros, axis=0)))
+            else:
+                weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))

         cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16)
@@ -864,7 +892,8 @@ def run_prefill(
             transpose_value=transpose_value_cache,
             n_splits_linear=n_splits_linear,
             n_splits_down_proj=n_splits_down_proj,
-            group_size=group_size
+            group_size=group_size,
+            asym=asym
         )

         layer_weights.extend(weights)
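run_prefill mirrors run_decode: it reads asym from the config, stacks the optional zero points, and forwards asym into the single-layer FusedBaichuanLowBitDecoderlayer. From the user's side the change is exercised simply by loading Baichuan with an asymmetric low-bit type; the call below is a hedged sketch based on the usual ipex-llm NPU loading pattern, and the exact arguments, in particular load_in_low_bit="asym_int4", are assumptions rather than something this patch documents.

# Hedged usage sketch: assumes the ipex-llm NPU front end accepts
# load_in_low_bit="asym_int4" for Baichuan after this change.
from ipex_llm.transformers.npu_model import AutoModelForCausalLM
from transformers import AutoTokenizer

model_path = "baichuan-inc/Baichuan2-7B-Chat"   # example checkpoint, not mandated by the patch
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    load_in_low_bit="asym_int4",   # asymmetric int4: weights carry scale + zero point
    optimize_model=True,
    trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)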