[NPU] Support llama groupwise (#12260)
* support llama gw
* support llama gw lm_head
* fix style
* remove unused code
This commit is contained in:
parent 48fc63887d
commit b5e663854b
5 changed files with 143 additions and 74 deletions
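The "groupwise" (gw) in the title refers to weight quantization in which every `group_size` input channels share one scale, rather than one scale per output row; the `quantization_group_size` / `group_size` values threaded through the hunks below carry that group size into the NPU pipeline. As a rough, self-contained sketch of the idea only (illustrative code, not this patch's or ipex-llm's actual kernel):

import torch

def quantize_groupwise_sym_int4(weight: torch.Tensor, group_size: int = 64):
    # Symmetric int4-style groupwise quantization: one scale per group of
    # `group_size` input channels. Shapes and ranges are illustrative only.
    out_features, in_features = weight.shape
    assert in_features % group_size == 0
    grouped = weight.reshape(out_features, in_features // group_size, group_size)
    scale = grouped.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 7.0
    q = torch.clamp(torch.round(grouped / scale), -8, 7).to(torch.int8)
    return q, scale.squeeze(-1)   # int weights plus per-group scales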
@@ -414,7 +414,7 @@ class _BaseAutoModelClass:
             optimize_llm(model)
             with torch.no_grad():
                 cls.load_convert(qtype, model, quant_device, modules_to_not_convert,
-                                 *model_args, **kwargs)
+                                 quantization_group_size, *model_args, **kwargs)
                 create_npu_kernels(model)

         if is_sharded:
@@ -59,3 +59,23 @@ def split_linear(module, module_name, n_splits=2):
         new_linear.weight = torch.nn.Parameter(weight.contiguous(), requires_grad=False)
         linear_list.add_module(f"{module_name}_dq_{idx}", new_linear)
     return linear_list
+
+
+def split_linears(module: torch.nn.Module, n_splits_hidden_size=2, n_splits_down_proj=2):
+    from transformers.models.qwen2.modeling_qwen2 import Qwen2MLP, Qwen2Attention
+    from transformers.models.llama.modeling_llama import LlamaMLP, LlamaAttention
+    attn_module_names = ["q_proj", "k_proj", "v_proj", "o_proj"]
+    mlp_module_names = ["down_proj", "up_proj", "gate_proj"]
+    if isinstance(module, (Qwen2Attention, LlamaAttention)):
+        for name in attn_module_names:
+            setattr(module, f"{name}_dq_list", split_linear(getattr(module, name), name,
+                                                            n_splits=n_splits_hidden_size))
+            delattr(module, name)
+    elif isinstance(module, (Qwen2MLP, LlamaMLP)):
+        for name in mlp_module_names:
+            n_splits_mlp = n_splits_hidden_size
+            if name == 'down_proj':
+                n_splits_mlp = n_splits_down_proj
+            setattr(module, f"{name}_dq_list", split_linear(getattr(module, name), name,
+                                                            n_splits=n_splits_mlp))
+            delattr(module, name)
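The new split_linears reuses split_linear (defined just above this hunk) to replace each attention and MLP projection with a `*_dq_list` of smaller linears. A minimal standalone sketch of the underlying idea, assuming the split runs along the input dimension so partial outputs simply sum back to the original result (hypothetical helper, not the repo's exact split_linear):

import torch

def split_linear_sketch(linear: torch.nn.Linear, n_splits: int) -> torch.nn.ModuleList:
    # Slice the weight along in_features; each sub-linear consumes one input slice.
    chunks = torch.tensor_split(linear.weight, n_splits, dim=1)
    subs = torch.nn.ModuleList()
    for w in chunks:
        sub = torch.nn.Linear(w.shape[1], w.shape[0], bias=False)
        sub.weight = torch.nn.Parameter(w.contiguous(), requires_grad=False)
        subs.append(sub)
    return subs

lin = torch.nn.Linear(8, 4, bias=False)
subs = split_linear_sketch(lin, n_splits=2)
x = torch.randn(1, 8)
y = sum(s(p) for s, p in zip(subs, torch.tensor_split(x, 2, dim=-1)))
assert torch.allclose(y, lin(x), atol=1e-6)   # partial matmuls sum to the full one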
@@ -87,8 +87,8 @@ def optimize_llm_pre(model: torch.nn.Module, qtype, mixed_precision,
         model.llm.config.model_type = "llama"
         model = model.llm

-    if model.config.model_type == "qwen2":
-        from ipex_llm.transformers.npu_models.qwen2_mp import split_linears
+    if model.config.model_type in ["qwen2", "llama"]:
+        from ipex_llm.transformers.npu_models.common import split_linears

         if quantization_group_size == 0:
             n_splits_linear = 1
@@ -108,15 +108,19 @@ def optimize_llm_pre(model: torch.nn.Module, qtype, mixed_precision,
         model.apply(lambda m: split_linears(m, n_splits_hidden_size=n_splits_linear,
                                             n_splits_down_proj=n_splits_down_proj))

-    # for Qwen2-7B-Insturct, divide lm_head into 14 parts
-    if model.config.hidden_size == 3584 and model.config.vocab_size == 152064 and \
-            not cpu_lm_head:
-        # Do not split lm_head and use sym_int8 instead when mixed_precison is True
-        if quantization_group_size != 0:
-            split_num = model.config.hidden_size // quantization_group_size
-            new_lm_head = SlicedLMHead(model.lm_head.weight, split_num=split_num,
-                                       bias=model.lm_head.bias, use_split=True)
-        else:
-            # Do not split lm_head and use sym_int8 instead when mixed_precison is True
-            is_split = (not mixed_precision) and qtype == "sym_int4_rtn"
-            split_num = 14 if is_split else 1
+    if quantization_group_size != 0:
+        split_num = model.config.hidden_size // quantization_group_size
+        new_lm_head = SlicedLMHead(model.lm_head.weight, split_num=split_num,
+                                   bias=model.lm_head.bias, use_split=True)
+        del model.lm_head
+        model.lm_head = new_lm_head
+
+    if model.config.model_type == "qwen2":
+        # for Qwen2-7B-Insturct, divide lm_head into 14 parts
+        if model.config.hidden_size == 3584 and model.config.vocab_size == 152064 and \
+                not cpu_lm_head:
+            # Do not split lm_head and use sym_int8 instead when mixed_precison is True
+            if quantization_group_size == 0:
+                # Do not split lm_head and use sym_int8 instead when mixed_precison is True
+                is_split = (not mixed_precision) and qtype == "sym_int4_rtn"
+                split_num = 14 if is_split else 1
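With groupwise quantization active, the number of lm_head slices now follows from the group size instead of the fixed 14-way split kept for Qwen2-7B-Instruct. A small worked example of that arithmetic with hypothetical config values:

# Hypothetical numbers, only to illustrate the arithmetic used above.
hidden_size = 4096                # e.g. a Llama-2-7B-style hidden size
quantization_group_size = 128     # chosen group size
split_num = hidden_size // quantization_group_size
print(split_num)                  # 32 -> SlicedLMHead splits the lm_head into 32 slices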
@@ -163,7 +167,7 @@ def optimize_llm(
         if intra_pp is None:
             intra_pp = 2
         if inter_pp is None:
-            inter_pp = 2
+            inter_pp = 2 if group_size == 0 else 8

         from ipex_llm.transformers.npu_models.llama_mp import gen_llama_fused_model_forward
         from ipex_llm.transformers.npu_models.llama_mp import DecodeRunner, PrefillRunner
@@ -226,11 +230,6 @@ def optimize_llm(
         from transformers.models.qwen2.modeling_qwen2 import Qwen2ForCausalLM
         from ipex_llm.transformers.npu_models.qwen2_mp import qwen2_casullm_forward
         convert_forward(model, Qwen2ForCausalLM, qwen2_casullm_forward)
-
-        # for Qwen2-7B-Insturct, divide lm_head into 14 parts
-        if model.config.hidden_size == 3584 and model.config.vocab_size == 152064 and \
-                isinstance(model.lm_head, SlicedLMHead):
-            model.lm_head.get_fused_lm_head()
     elif model.config.model_type == "minicpm":
         # for minicpm-1b
         if intra_pp is None:
@@ -299,3 +298,6 @@ def optimize_llm(
         modeling_module_name = model.__class__.__module__
         module = importlib.import_module(modeling_module_name)
         convert_forward(model, module.BaichuanModel, baichuan_model_forward)
+
+    if isinstance(model.lm_head, SlicedLMHead):
+        model.lm_head.get_fused_lm_head()
@@ -67,12 +67,18 @@ class LowBitLlamaMultiDecoderlayer(LLMBaseNNFactory):
         device: str = "NPU",
         rms_norm_eps,
         intermediate_size,
+        n_splits_linear: int = 1,
+        n_splits_down_proj: int = 1,
+        group_size: int = 0
     ):
         super().__init__(max_seq_len=max_seq_len,
                          transpose_value=transpose_value,
                          dtype=dtype,
                          profile=profile,
-                         device=device)
+                         device=device,
+                         n_splits_linear=n_splits_linear,
+                         n_splits_down_proj=n_splits_down_proj,
+                         group_size=group_size)
         self.max_seq_len = max_seq_len
         self.intermediate_size = intermediate_size
         self.dtype = dtype
@@ -215,7 +221,7 @@ class LowBitLlamaMultiDecoderlayer(LLMBaseNNFactory):
         hidden_states = self.eltwise_add(residual, attn_output)
         residual = hidden_states
         hidden_states = self.layer_norm(hidden_states, post_attention_layernorm_weight)
-        hidden_states = self.mlp(hidden_states)
+        hidden_states = self.mlp(hidden_states, self.seq_len, self.mode)
         hidden_states = self.eltwise_add(residual, hidden_states)
         hidden_states = self.convert_to_fp16(hidden_states)
@@ -241,6 +247,9 @@ class FusedLlamaLowBitMultiDecoderlayer(torch.nn.Module):
         max_seq_len: int = 1024,
         transpose_value: bool = False,
         do_print: bool = False,
+        n_splits_linear: int = 1,
+        n_splits_down_proj: int = 1,
+        group_size: int = 0
     ):
         super().__init__()
@@ -250,6 +259,10 @@ class FusedLlamaLowBitMultiDecoderlayer(torch.nn.Module):
         for w in parameters:
             if isinstance(w, tuple):  # from QuantizedLinear
                 op_parameters.append((w[0].numpy(), w[1].numpy()))
+            elif w.dtype in [torch.int8, torch.uint8]:  # QuantizedLinear weight
+                op_parameters.append(w.numpy())
+            elif isinstance(w, np.ndarray):  # scale
+                op_parameters.append(w)
             else:
                 op_parameters.append(w.to(torch.float16).numpy())
         self.op_parameters = op_parameters
@@ -258,6 +271,10 @@ class FusedLlamaLowBitMultiDecoderlayer(torch.nn.Module):
         self.transpose_value = transpose_value
         if isinstance(parameters[0], tuple):
             np_dtype = np.int8 if parameters[0][0].dtype == torch.int8 else np.uint8
+        elif parameters[0].dtype == torch.int8:
+            np_dtype = np.int8
+        elif parameters[0].dtype == torch.uint8:
+            np_dtype = np.uint8
         else:  # FP16 Linear
             np_dtype = np.float16
@@ -292,6 +309,9 @@ class FusedLlamaLowBitMultiDecoderlayer(torch.nn.Module):
                 mode="decode",
                 transpose_value=self.transpose_value,
                 dtype=np_dtype,
+                n_splits_linear=n_splits_linear,
+                n_splits_down_proj=n_splits_down_proj,
+                group_size=group_size
             )
             self.backend_decoders.append(decoder)
@@ -367,6 +387,9 @@ class FusedLlamaLowBitDecoderlayer(torch.nn.Module):
         intermediate_size,
         max_seq_len: int = 128,
         transpose_value: bool = False,
+        n_splits_linear: int = 1,
+        n_splits_down_proj: int = 1,
+        group_size: int = 0
     ):
         super().__init__()
         self.op_parameters = parameters
@@ -395,6 +418,9 @@ class FusedLlamaLowBitDecoderlayer(torch.nn.Module):
             mode="prefill",
             transpose_value=self.transpose_value,
             dtype=np_dtype,
+            n_splits_linear=n_splits_linear,
+            n_splits_down_proj=n_splits_down_proj,
+            group_size=group_size
         )
         self.layer_norm_0 = layer_norm_0
         self.layer_norm_1 = layer_norm_1
@@ -474,24 +500,53 @@ def run_decode(
     head_dim = model.model.layers[layer_start].self_attn.head_dim
     rms_norm_eps = model.config.rms_norm_eps
     intermediate_size = model.config.intermediate_size
+    group_size = getattr(model.config, "group_size", 0)
     layer_weights = []
     input_layer_norm_weights = []
     post_attn_layernorm_weights = []
     layer_indexs = range(layer_start, layer_end)
+    n_splits_linear = len(model.model.layers[0].mlp.gate_proj_dq_list)
+    n_splits_down_proj = len(model.model.layers[0].mlp.down_proj_dq_list)
     for layer_idx in layer_indexs:
         curr_layer = model.model.layers[layer_idx]
         attn_layer = curr_layer.self_attn
         mlp_layer = curr_layer.mlp

-        weights = [
-            (attn_layer.q_proj.weight, attn_layer.q_proj.scale),
-            (attn_layer.k_proj.weight, attn_layer.k_proj.scale),
-            (attn_layer.v_proj.weight, attn_layer.v_proj.scale),
-            (attn_layer.o_proj.weight, attn_layer.o_proj.scale),
-            (mlp_layer.gate_proj.weight, mlp_layer.gate_proj.scale),
-            (mlp_layer.up_proj.weight, mlp_layer.up_proj.scale),
-            (mlp_layer.down_proj.weight, mlp_layer.down_proj.scale),
-        ]
+        weights = []
+        if n_splits_linear == 1:
+            for q, k, v, o, g, u in zip(attn_layer.q_proj_dq_list,
+                                        attn_layer.k_proj_dq_list,
+                                        attn_layer.v_proj_dq_list,
+                                        attn_layer.o_proj_dq_list,
+                                        mlp_layer.gate_proj_dq_list,
+                                        mlp_layer.up_proj_dq_list):
+                weights.append((q.weight, q.scale))
+                weights.append((k.weight, k.scale))
+                weights.append((v.weight, v.scale))
+                weights.append((o.weight, o.scale))
+                weights.append((g.weight, g.scale))
+                weights.append((u.weight, u.scale))
+        else:
+            for layer_list in [attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list,
+                               attn_layer.v_proj_dq_list, attn_layer.o_proj_dq_list,
+                               mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list]:
+                l_weights = []
+                scales = []
+                for l in layer_list:
+                    l_weights.append(l.weight)
+                    scales.append(l.scale)
+                weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))
+
+        if n_splits_down_proj == 1:
+            for l in mlp_layer.down_proj_dq_list:
+                weights.append((l.weight, l.scale))
+        else:
+            l_weights = []
+            scales = []
+            for l in mlp_layer.down_proj_dq_list:
+                l_weights.append(l.weight)
+                scales.append(l.scale)
+            weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))

         cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16)
         cached_sin = curr_layer.self_attn.rotary_emb.sin_cached.to(torch.float16)
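When a projection has been split into several quantized sub-linears, run_decode now stacks their weights and scales along a new leading axis so each projection still hands a single tensor pair to the backend. A shape-level sketch of that stacking with made-up sizes (the real scale layout in ipex-llm may differ):

import torch

n_splits, out_features, in_per_split = 2, 16, 32   # hypothetical sizes
l_weights = [torch.randint(-8, 8, (out_features, in_per_split), dtype=torch.int8)
             for _ in range(n_splits)]
scales = [torch.rand(out_features).to(torch.float16) for _ in range(n_splits)]

stacked_w = torch.stack(l_weights, dim=0)   # [n_splits, out_features, in_per_split]
stacked_s = torch.stack(scales, dim=0)      # [n_splits, out_features]
print(stacked_w.shape, stacked_s.shape)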
@@ -518,6 +573,9 @@ def run_decode(
             max_seq_len=max_seq_len,
             transpose_value=transpose_value_cache,
             do_print=False,
+            n_splits_linear=n_splits_linear,
+            n_splits_down_proj=n_splits_down_proj,
+            group_size=group_size
         )

     dist.barrier()
@@ -591,11 +649,15 @@ class DecodeRunner:

         self.forward_signal = torch.tensor(0, dtype=torch.int)

+        n_layers_per_rank = num_layers // (world_size - 1)
+        if num_layers % (world_size - 1) > 0:
+            n_layers_per_rank += 1
+
         for rank in range(1, world_size):
             input_q = mp.Queue()
             output_q = mp.Queue()
-            start_layer = (rank - 1) * (num_layers // (world_size - 1))
-            end_layer = (rank) * (num_layers // (world_size - 1))
+            start_layer = (rank - 1) * n_layers_per_rank
+            end_layer = (rank) * n_layers_per_rank
             if rank == world_size - 1:
                 end_layer = num_layers
             p = mp.Process(
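DecodeRunner now assigns decoder layers to the decode ranks with a ceiling division, so any remainder is spread one extra layer at a time instead of all landing on the last rank. A worked example of the same arithmetic with hypothetical sizes:

num_layers, world_size = 30, 5            # hypothetical: 4 decode ranks
n_layers_per_rank = num_layers // (world_size - 1)
if num_layers % (world_size - 1) > 0:
    n_layers_per_rank += 1                # ceil(30 / 4) = 8
for rank in range(1, world_size):
    start_layer = (rank - 1) * n_layers_per_rank
    end_layer = rank * n_layers_per_rank
    if rank == world_size - 1:
        end_layer = num_layers
    print(rank, start_layer, end_layer)   # ranges (0, 8), (8, 16), (16, 24), (24, 30)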
@@ -676,25 +738,34 @@ def run_prefill(
     head_dim = model.model.layers[layer_start].self_attn.head_dim
     rms_norm_eps = model.config.rms_norm_eps
     intermediate_size = model.config.intermediate_size
+    group_size = getattr(model.config, "group_size", 0)
     deocderlayers = []
     layer_weights = []
     input_layer_norm_weights = []
     post_attn_layernorm_weights = []
     layer_indexs = range(layer_start, layer_end)
+    n_splits_linear = len(model.model.layers[0].mlp.gate_proj_dq_list)
+    n_splits_down_proj = len(model.model.layers[0].mlp.down_proj_dq_list)
     for layer_idx in layer_indexs:
         curr_layer = model.model.layers[layer_idx]
         attn_layer = curr_layer.self_attn
         mlp_layer = curr_layer.mlp

-        weights = [
-            (attn_layer.q_proj.weight, attn_layer.q_proj.scale),
-            (attn_layer.k_proj.weight, attn_layer.k_proj.scale),
-            (attn_layer.v_proj.weight, attn_layer.v_proj.scale),
-            (attn_layer.o_proj.weight, attn_layer.o_proj.scale),
-            (mlp_layer.gate_proj.weight, mlp_layer.gate_proj.scale),
-            (mlp_layer.up_proj.weight, mlp_layer.up_proj.scale),
-            (mlp_layer.down_proj.weight, mlp_layer.down_proj.scale),
-        ]
+        weights = []
+
+        for q, k, v in zip(attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list,
+                           attn_layer.v_proj_dq_list):
+            weights.append((q.weight, q.scale))
+            weights.append((k.weight, k.scale))
+            weights.append((v.weight, v.scale))
+
+        for l in attn_layer.o_proj_dq_list:
+            weights.append((l.weight, l.scale))
+        for g, u in zip(mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list):
+            weights.append((g.weight, g.scale))
+            weights.append((u.weight, u.scale))
+        for l in mlp_layer.down_proj_dq_list:
+            weights.append((l.weight, l.scale))

         cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16)
         cached_sin = curr_layer.self_attn.rotary_emb.sin_cached.to(torch.float16)
@@ -715,6 +786,9 @@ def run_prefill(
             intermediate_size=intermediate_size,
             max_seq_len=max_output_len,
             transpose_value=transpose_value_cache,
+            n_splits_linear=n_splits_linear,
+            n_splits_down_proj=n_splits_down_proj,
+            group_size=group_size
         )

         layer_weights.extend(weights)
@@ -42,27 +42,8 @@ from ipex_llm.transformers.npu_models.mp_models_base import LLMBaseNNFactory
 from ipex_llm.transformers.npu_models.common import reshape_lm_head_input
 from transformers.modeling_outputs import CausalLMOutputWithPast
 from torch.nn import CrossEntropyLoss
-from transformers.models.qwen2.modeling_qwen2 import Qwen2MLP, Qwen2Attention
+from transformers.models.qwen2.modeling_qwen2 import Qwen2MLP
 from ipex_llm.utils.common.log4Error import invalidInputError
-from ipex_llm.transformers.npu_models.common import split_linear
-
-
-def split_linears(module: torch.nn.Module, n_splits_hidden_size=2, n_splits_down_proj=2):
-    attn_module_names = ["q_proj", "k_proj", "v_proj", "o_proj"]
-    mlp_module_names = ["down_proj", "up_proj", "gate_proj"]
-    if isinstance(module, Qwen2Attention):
-        for name in attn_module_names:
-            setattr(module, f"{name}_dq_list", split_linear(getattr(module, name), name,
-                                                            n_splits=n_splits_hidden_size))
-            delattr(module, name)
-    elif isinstance(module, Qwen2MLP):
-        for name in mlp_module_names:
-            n_splits_mlp = n_splits_hidden_size
-            if name == 'down_proj':
-                n_splits_mlp = n_splits_down_proj
-            setattr(module, f"{name}_dq_list", split_linear(getattr(module, name), name,
-                                                            n_splits=n_splits_mlp))
-            delattr(module, name)


 def split_mlp_down_proj(module: torch.nn.Module):
@@ -594,30 +575,22 @@ def run_decode(

         weights = []
         if n_splits_linear == 1:
-            for q, k, v in zip(attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list,
-                               attn_layer.v_proj_dq_list):
+            for q, k, v, o, g, u in zip(attn_layer.q_proj_dq_list,
+                                        attn_layer.k_proj_dq_list,
+                                        attn_layer.v_proj_dq_list,
+                                        attn_layer.o_proj_dq_list,
+                                        mlp_layer.gate_proj_dq_list,
+                                        mlp_layer.up_proj_dq_list):
                 weights.append((q.weight, q.scale))
                 weights.append((k.weight, k.scale))
                 weights.append((v.weight, v.scale))
-
-            for l in attn_layer.o_proj_dq_list:
-                weights.append((l.weight, l.scale))
-        else:
-            for layer_list in [attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list,
-                               attn_layer.v_proj_dq_list, attn_layer.o_proj_dq_list]:
-                l_weights = []
-                scales = []
-                for l in layer_list:
-                    l_weights.append(l.weight)
-                    scales.append(l.scale)
-                weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))
-
-        if n_splits_linear == 1:
-            for g, u in zip(mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list):
+                weights.append((o.weight, o.scale))
                 weights.append((g.weight, g.scale))
                 weights.append((u.weight, u.scale))
         else:
-            for layer_list in [mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list]:
+            for layer_list in [attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list,
+                               attn_layer.v_proj_dq_list, attn_layer.o_proj_dq_list,
+                               mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list]:
                 l_weights = []
                 scales = []
                 for l in layer_list: