[NPU] support asym_int4 for llama (#12556)

* add llama-imatrix

* fix bugs in llama.py

* style fix
Zijie Li 2024-12-17 01:01:17 -05:00 committed by GitHub
parent d127a8654c
commit fcb474820d
2 changed files with 124 additions and 32 deletions
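Background for the change: with symmetric int4 (sym_int4) each quantized linear carries a (weight, scale) pair, while asymmetric int4 (asym_int4) adds a per-group zero point, so the NPU path has to pass around and serialize (weight, scale, zero) triples instead. The following is a minimal NumPy sketch of the dequantization this implies, for illustration only — it is not the NPU kernel, and the exact int4 packing and group layout in ipex-llm may differ:

import numpy as np

def dequant_asym_int4(q, scale, zero):
    """Illustrative asymmetric int4 dequantization: w ~= (q - zero) * scale."""
    return (q.astype(np.float32) - zero.astype(np.float32)) * scale.astype(np.float32)

# toy tensors: 2 output channels x 4 weights, per-channel scale/zero (hypothetical layout)
q = np.array([[0, 5, 10, 15],
              [3, 7, 11, 15]], dtype=np.uint8)        # unpacked 4-bit codes in [0, 15]
scale = np.array([[0.10], [0.20]], dtype=np.float32)
zero = np.array([[8.0], [8.0]], dtype=np.float32)
print(dequant_asym_int4(q, scale, zero))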

ipex_llm/transformers/npu_models/llama_mp.py

@@ -72,6 +72,7 @@ class LowBitLlamaMultiDecoderlayer(LLMBaseNNFactory):
         group_size: int = 0,
         cos_len: int = 1,
         keep_position_ids=True,
+        asym: bool = False,
     ):
         super().__init__(max_seq_len=max_seq_len,
                          transpose_value=transpose_value,
@@ -80,7 +81,8 @@ class LowBitLlamaMultiDecoderlayer(LLMBaseNNFactory):
                          device=device,
                          n_splits_linear=n_splits_linear,
                          n_splits_down_proj=n_splits_down_proj,
-                         group_size=group_size)
+                         group_size=group_size,
+                         asym=asym)
         self.max_seq_len = max_seq_len
         self.intermediate_size = intermediate_size
         self.dtype = dtype
@@ -278,7 +280,8 @@ class FusedLlamaLowBitMultiDecoderlayer(torch.nn.Module):
         do_print: bool = False,
         n_splits_linear: int = 1,
         n_splits_down_proj: int = 1,
-        group_size: int = 0
+        group_size: int = 0,
+        asym: bool = False,
     ):
         super().__init__()
@@ -286,8 +289,10 @@ class FusedLlamaLowBitMultiDecoderlayer(torch.nn.Module):
         op_parameters = []
         for w in parameters:
-            if isinstance(w, tuple):  # from QuantizedLinear
+            if isinstance(w, tuple) and not asym:  # from QuantizedLinear
                 op_parameters.append((w[0].numpy(), w[1].numpy()))
+            elif isinstance(w, tuple) and asym:  # from QuantizedLinear
+                op_parameters.append((w[0].numpy(), w[1].numpy(), w[2].numpy()))
             elif w.dtype in [torch.int8, torch.uint8]:  # QuantizedLinear weight
                 op_parameters.append(w.numpy())
             elif isinstance(w, np.ndarray):  # scale
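As a standalone illustration of the dispatch above (hedged sketch with toy tensors, not the real QuantizedLinear outputs or the full type handling in this class): tuples coming from QuantizedLinear keep their arity, so sym produces (weight, scale) numpy pairs and asym produces (weight, scale, zero) triples.

import numpy as np
import torch

def to_op_parameters(parameters, asym):
    # Keep 2-tuples for sym and 3-tuples for asym; convert bare tensors individually.
    op_parameters = []
    for w in parameters:
        if isinstance(w, tuple) and not asym:
            op_parameters.append((w[0].numpy(), w[1].numpy()))
        elif isinstance(w, tuple) and asym:
            op_parameters.append((w[0].numpy(), w[1].numpy(), w[2].numpy()))
        elif isinstance(w, np.ndarray):
            op_parameters.append(w)
        else:
            op_parameters.append(w.numpy())
    return op_parameters

# toy input: one asym triple (weight, scale, zero)
params = [(torch.zeros(4, 8, dtype=torch.uint8), torch.ones(4), torch.full((4,), 8.0))]
print([len(p) for p in to_op_parameters(params, asym=True)])  # -> [3]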
@@ -341,7 +346,8 @@ class FusedLlamaLowBitMultiDecoderlayer(torch.nn.Module):
                 dtype=np_dtype,
                 n_splits_linear=n_splits_linear,
                 n_splits_down_proj=n_splits_down_proj,
-                group_size=group_size
+                group_size=group_size,
+                asym=asym,
             )
             self.backend_decoders.append(decoder)
@@ -427,6 +433,7 @@ class FusedLlamaLowBitDecoderlayer(torch.nn.Module):
         n_splits_down_proj: int = 1,
         group_size: int = 0,
         cos_len: int = 1,
+        asym: bool = False,
     ):
         super().__init__()
         self.op_parameters = parameters
@@ -460,6 +467,7 @@ class FusedLlamaLowBitDecoderlayer(torch.nn.Module):
             n_splits_down_proj=n_splits_down_proj,
             group_size=group_size,
             cos_len=cos_len,
+            asym=asym,
         )
         self.layer_norm_0 = layer_norm_0
         self.layer_norm_1 = layer_norm_1
@@ -555,6 +563,7 @@ def run_decode(
     layer_indexs = range(layer_start, layer_end)
     n_splits_linear = len(model.model.layers[0].mlp.gate_proj_dq_list)
     n_splits_down_proj = len(model.model.layers[0].mlp.down_proj_dq_list)
+    asym = getattr(model.config, "asym", False)
     for layer_idx in layer_indexs:
         curr_layer = model.model.layers[layer_idx]
         attn_layer = curr_layer.self_attn
@@ -567,9 +576,16 @@ def run_decode(
                             mlp_layer.down_proj_dq_list]:
             l_weights = []
             scales = []
+            zeros = []
             for l in layer_list:
                 l_weights.append(l.weight)
                 scales.append(l.scale)
-            weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))
+                if l.zero is not None:
+                    zeros.append(l.zero)
+            if len(zeros):
+                weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0),
+                                torch.stack(zeros, axis=0)))
+            else:
+                weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))

         if hasattr(curr_layer.self_attn.rotary_emb, "cos_cached"):
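For reference, a small runnable sketch (fake shapes, hypothetical group layout) of what the stacking above produces: each of the n_splits QuantizedLinear shards contributes a weight, a scale and, in the asym case, a zero tensor, and torch.stack adds a leading n_splits dimension so the NPU decoder receives one packed tensor per kind.

import torch

n_splits, out_features, packed_in = 2, 8, 16     # hypothetical shapes, not the real layout
l_weights = [torch.randint(0, 255, (out_features, packed_in), dtype=torch.uint8)
             for _ in range(n_splits)]            # stand-ins for packed int4 weight shards
scales = [torch.rand(out_features) for _ in range(n_splits)]
zeros = [torch.full((out_features,), 8.0) for _ in range(n_splits)]

packed = (torch.stack(l_weights, dim=0),          # (n_splits, out_features, packed_in)
          torch.stack(scales, dim=0),             # (n_splits, out_features)
          torch.stack(zeros, dim=0))              # (n_splits, out_features), asym only
print([t.shape for t in packed])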
@@ -603,7 +619,8 @@ def run_decode(
                 do_print=False,
                 n_splits_linear=n_splits_linear,
                 n_splits_down_proj=n_splits_down_proj,
-                group_size=group_size
+                group_size=group_size,
+                asym=asym,
             )

     dist.barrier()
@@ -814,6 +831,7 @@ def run_prefill(
     layer_indexs = range(layer_start, layer_end)
     n_splits_linear = len(model.model.layers[0].mlp.gate_proj_dq_list)
     n_splits_down_proj = len(model.model.layers[0].mlp.down_proj_dq_list)
+    asym = getattr(model.config, "asym", False)
     for layer_idx in layer_indexs:
         curr_layer = model.model.layers[layer_idx]
         attn_layer = curr_layer.self_attn
@@ -827,10 +845,18 @@ def run_prefill(
                             mlp_layer.down_proj_dq_list]:
             l_weights = []
             scales = []
+            zeros = []
             for l in layer_list:
                 l_weights.append(l.weight)
                 scales.append(l.scale)
-            weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))
+                if l.zero is not None:
+                    zeros.append(l.zero)
+            if len(zeros):
+                weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0),
+                                torch.stack(zeros, axis=0)))
+            else:
+                weights.append((torch.stack(l_weights, axis=0),
+                                torch.stack(scales, axis=0)))

         if hasattr(curr_layer.self_attn.rotary_emb, "cos_cached"):
             cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16)
@@ -859,6 +885,7 @@ def run_prefill(
             n_splits_down_proj=n_splits_down_proj,
             group_size=group_size,
             cos_len=cos_len,
+            asym=asym,
         )
         layer_weights.extend(weights)

ipex_llm/transformers/npu_pipeline_model/llama.py

@@ -130,15 +130,29 @@ def convert_lm_head_and_embedding(model, n_splits_linear, temp_dir, weight_dir,
     vocab_size = model.config.vocab_size
     model_norm = model.model.norm
     lm_head = model.lm_head
+    asym = getattr(model.config, "asym", False)
     if n_splits_linear == 1:
-        weights = [(lm_head.weight, lm_head.scale)]
+        asym = lm_head.qtype == "asym_int4_rtn"
+        if asym:
+            weights = [(lm_head.weight, lm_head.scale, lm_head.zero)]
+        else:
+            weights = [(lm_head.weight, lm_head.scale)]
     else:
         lm_heads = lm_head.lm_heads
+        asym = lm_heads[0].qtype == "asym_int4_rtn"
         lm_head_weights = []
         scales = []
-        for i in range(n_splits_linear):
-            lm_head_weights.append(lm_heads[i].weight)
-            scales.append(lm_heads[i].scale)
-        weights = [(torch.stack(lm_head_weights, axis=0),
-                    torch.stack(scales, axis=0))]
+        zeros = []
+        for l in lm_heads:
+            lm_head_weights.append(l.weight)
+            scales.append(l.scale)
+            if l.zero is not None:
+                zeros.append(l.zero)
+        if len(zeros):
+            weights = [(torch.stack(lm_head_weights, axis=0),
+                        torch.stack(scales, axis=0),
+                        torch.stack(zeros, axis=0))]
+        else:
+            weights = [(torch.stack(lm_head_weights, axis=0),
+                        torch.stack(scales, axis=0))]
     if isinstance(weights[0], tuple):
@@ -156,16 +170,23 @@ def convert_lm_head_and_embedding(model, n_splits_linear, temp_dir, weight_dir,
         dtype=np_dtype,
         model_norm_weight=model_norm.weight.to(torch.float16),
         vocab_size=vocab_size,
-        n_splits=n_splits_linear
+        n_splits=n_splits_linear,
+        asym=asym
     )
     last_blob_path = update_names_of_IR_and_export_blob(new_lm_head, "lm_head", temp_dir,
                                                         True, False)
     # save weights bins files
     if n_splits_linear == 1:
-        weight_numpy = [
-            lm_head.weight.data.numpy(), lm_head.scale.data.numpy(),
-        ]
+        if not asym:
+            weight_numpy = [
+                lm_head.weight.data.numpy(), lm_head.scale.data.numpy(),
+            ]
+        else:
+            weight_numpy = [
+                lm_head.weight.data.numpy(), lm_head.scale.data.numpy(),
+                lm_head.zero.data.numpy()
+            ]
     else:
         weight_numpy = [v.numpy() for v in weights[0]]
@@ -234,6 +255,7 @@ def convert_llama_layer(model, layer_idx, n_splits_linear, n_splits_down_proj,
    head_dim = model.model.layers[0].self_attn.head_dim
    intermediate_size = model.config.intermediate_size
    rms_norm_eps = model.config.rms_norm_eps
+   asym = getattr(model.config, "asym", False)
    from ipex_llm.transformers.npu_models.llama_mp import LowBitLlamaMultiDecoderlayer
    curr_layer = model.model.layers[layer_idx]
@@ -247,9 +269,16 @@ def convert_llama_layer(model, layer_idx, n_splits_linear, n_splits_down_proj,
                        mlp_layer.down_proj_dq_list]:
        l_weights = []
        scales = []
+       zeros = []
        for l in layer_list:
            l_weights.append(l.weight)
            scales.append(l.scale)
-       weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))
+           if l.zero is not None:
+               zeros.append(l.zero)
+       if len(zeros):
+           weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0),
+                           torch.stack(zeros, axis=0)))
+       else:
+           weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))

    if hasattr(curr_layer.self_attn.rotary_emb, "cos_cached"):
@@ -299,7 +328,8 @@ def convert_llama_layer(model, layer_idx, n_splits_linear, n_splits_down_proj,
        n_splits_down_proj=n_splits_down_proj,
        group_size=group_size,
        cos_len=input_len,
-       keep_position_ids=keep_position_ids
+       keep_position_ids=keep_position_ids,
+       asym=asym
    )
    rest_blob_path = update_names_of_IR_and_export_blob(single_decoder,
@@ -329,11 +359,24 @@ def convert_llama_layer(model, layer_idx, n_splits_linear, n_splits_down_proj,
        layer_norm_0.data.numpy().tofile(input_lm_bin_file)
        layer_norm_1.data.numpy().tofile(post_lm_bin_file)
        st_idx = 8
-       for idx, (weight, scale) in enumerate(weights):
-           bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_{st_idx+idx*2}.bin")
-           weight.numpy().tofile(bin_file)
-           bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_{st_idx+idx*2+1}.bin")
-           scale.numpy().tofile(bin_file)
+       if not asym:
+           for idx, (weight, scale) in enumerate(weights):
+               bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_{st_idx+idx*2}.bin")
+               weight.numpy().tofile(bin_file)
+               bin_file = os.path.join(weight_dir,
+                                       f"model_{layer_idx}_input_{st_idx+idx*2+1}.bin")
+               scale.numpy().tofile(bin_file)
+       else:
+           for idx, (weight, scale, zero) in enumerate(weights):
+               bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_{st_idx+idx*3}.bin")
+               weight.numpy().tofile(bin_file)
+               bin_file = os.path.join(weight_dir,
+                                       f"model_{layer_idx}_input_{st_idx+idx*3+1}.bin")
+               scale.numpy().tofile(bin_file)
+               bin_file = os.path.join(weight_dir,
+                                       f"model_{layer_idx}_input_{st_idx+idx*3+2}.bin")
+               zero.numpy().tofile(bin_file)

        del single_decoder
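The net effect of the export change above is on the weight .bin naming: each tensor group now contributes three files (weight, scale, zero) instead of two, so the input index advances by 3 per group after the fixed inputs. A small sketch of the resulting names, with a hypothetical layer_idx and group count and st_idx = 8 as in the decode-layer case above:

import os

weight_dir = "/tmp/npu_weights"     # hypothetical output directory
layer_idx, st_idx, num_groups = 0, 8, 2

for asym in (False, True):
    stride = 3 if asym else 2
    names = []
    for idx in range(num_groups):
        names.append(f"model_{layer_idx}_input_{st_idx + idx * stride}.bin")          # weight
        names.append(f"model_{layer_idx}_input_{st_idx + idx * stride + 1}.bin")      # scale
        if asym:
            names.append(f"model_{layer_idx}_input_{st_idx + idx * stride + 2}.bin")  # zero
    print("asym" if asym else "sym", [os.path.join(weight_dir, n) for n in names])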
@@ -347,6 +390,7 @@ def convert_fused_llama_layer(model, fused_layers, n_splits_linear, n_splits_dow
    rms_norm_eps = model.config.rms_norm_eps
    layer_num = len(model.model.layers)
    fused_layer_num = layer_num // fused_layers
+   asym = getattr(model.config, "asym", False)

    from ipex_llm.transformers.npu_models.llama_mp import LowBitLlamaMultiDecoderlayer
    for i in range(fused_layers):
@@ -370,9 +414,16 @@ def convert_fused_llama_layer(model, fused_layers, n_splits_linear, n_splits_dow
                                mlp_layer.down_proj_dq_list]:
                l_weights = []
                scales = []
+               zeros = []
                for l in layer_list:
                    l_weights.append(l.weight)
                    scales.append(l.scale)
-               weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))
+                   if l.zero is not None:
+                       zeros.append(l.zero)
+               if len(zeros):
+                   weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0),
+                                   torch.stack(zeros, axis=0)))
+               else:
+                   weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))

            if hasattr(curr_layer.self_attn.rotary_emb, "cos_cached"):
@@ -397,12 +448,25 @@ def convert_fused_llama_layer(model, fused_layers, n_splits_linear, n_splits_dow
            layer_norm_1.data.numpy().tofile(post_lm_bin_file)
            st_idx = 5
            # 6, 7 are past k/v
-           for idx, (weight, scale) in enumerate(weights):
-               bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_{st_idx+idx*2}.bin")
-               weight.numpy().tofile(bin_file)
-               bin_file = os.path.join(weight_dir,
-                                       f"model_{layer_idx}_input_{st_idx+idx*2+1}.bin")
-               scale.numpy().tofile(bin_file)
+           if not asym:
+               for idx, (weight, scale) in enumerate(weights):
+                   bin_file = os.path.join(weight_dir,
+                                           f"model_{layer_idx}_input_{st_idx+idx*2}.bin")
+                   weight.numpy().tofile(bin_file)
+                   bin_file = os.path.join(weight_dir,
+                                           f"model_{layer_idx}_input_{st_idx+idx*2+1}.bin")
+                   scale.numpy().tofile(bin_file)
+           else:
+               for idx, (weight, scale, zero) in enumerate(weights):
+                   bin_file = os.path.join(weight_dir,
+                                           f"model_{layer_idx}_input_{st_idx+idx*3}.bin")
+                   weight.numpy().tofile(bin_file)
+                   bin_file = os.path.join(weight_dir,
+                                           f"model_{layer_idx}_input_{st_idx+idx*3+1}.bin")
+                   scale.numpy().tofile(bin_file)
+                   bin_file = os.path.join(weight_dir,
+                                           f"model_{layer_idx}_input_{st_idx+idx*3+2}.bin")
+                   zero.numpy().tofile(bin_file)

            if isinstance(weights[0], tuple):
                np_dtype = np.int8 if weights[0][0].dtype == torch.int8 else np.uint8
@@ -426,7 +490,8 @@ def convert_fused_llama_layer(model, fused_layers, n_splits_linear, n_splits_dow
                dtype=np_dtype,
                n_splits_linear=n_splits_linear,
                n_splits_down_proj=n_splits_down_proj,
-               group_size=group_size
+               group_size=group_size,
+               asym=asym
            )
            update_names_of_IR_and_export_blob(fused_decoder,
                                               f"decoder_layer_{i}",