Support Qwen2-7b MLP in int4 and transpose_value_cache=True (#11968)
This commit is contained in:
parent 65e281bb29
commit c48817bd43

2 changed files with 49 additions and 11 deletions
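The heart of the change is splitting Qwen2-7b's 18944-wide down_proj into two 9472-wide halves, presumably so each int4 matmul stays within what the NPU backend handles (the old path fell back to int8 and disabled the transposed value cache for this size). A minimal sketch, not part of the commit, of why the column split is lossless; hidden_size=3584 is the assumed qwen2-7b value:

# Sketch only: a matmul over 18944 input columns equals the sum of two matmuls
# over the 9472-column halves of the same weight.
import torch

hidden_size, intermediate_size = 3584, 18944
w = torch.randn(hidden_size, intermediate_size, dtype=torch.float64)  # down_proj.weight
h = torch.randn(2, intermediate_size, dtype=torch.float64)            # intermediate activations
full = h @ w.t()
halves = h[:, :9472] @ w[:, :9472].t() + h[:, 9472:] @ w[:, 9472:].t()
assert torch.allclose(full, halves)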
@@ -65,6 +65,11 @@ def optimize_llm_pre(model: torch.nn.Module, qtype):
         model.llm.config.model_type = "llama"
         model = model.llm
 
+    if model.config.model_type == "qwen2":
+        from ipex_llm.transformers.npu_models.qwen2_mp import split_mlp_down_proj
+        from ipex_llm.transformers.npu_models.qwen2_mp import split_mlp_forward
+        model.apply(split_mlp_down_proj)
+
     # lm_head to cpu optimization
     if cpu_lm_head:
         # disable the optimization by default
@@ -134,8 +139,6 @@ def optimize_llm(
             intra_pp = 2
         if inter_pp is None:
             inter_pp = 4 if model.config.intermediate_size == 18944 else 1
-        if model.config.intermediate_size == 18944:
-            transpose_value_cache = False
 
         from ipex_llm.transformers.npu_models.qwen2_mp import gen_qwen2_fused_model_forward
         from ipex_llm.transformers.npu_models.qwen2_mp import DecodeRunner, PrefillRunner
@@ -42,6 +42,30 @@ from ipex_llm.transformers.npu_models.mp_models_base import LLMBaseNNFactory
 from ipex_llm.transformers.npu_models.common import reshape_lm_head_input
 from transformers.modeling_outputs import CausalLMOutputWithPast
 from torch.nn import CrossEntropyLoss
+from transformers.models.qwen2.modeling_qwen2 import Qwen2MLP
+
+
+def split_mlp_down_proj(module: torch.nn.Module):
+    if isinstance(module, Qwen2MLP) and module.down_proj.in_features == 18944:
+        new_linear_0 = torch.nn.Linear(0, 0, bias=False)
+        new_weight_0 = torch.nn.Parameter(module.down_proj.weight[:, :9472], requires_grad=False)
+        new_linear_0.weight = new_weight_0
+        new_linear_0.in_features = new_weight_0.size(1)
+        new_linear_0.out_features = new_weight_0.size(0)
+        module.down_proj_0 = new_linear_0
+        new_linear_1 = torch.nn.Linear(0, 0, bias=False)
+        new_weight_1 = torch.nn.Parameter(module.down_proj.weight[:, 9472:], requires_grad=False)
+        new_linear_1.weight = new_weight_1
+        new_linear_1.in_features = new_weight_1.size(1)
+        new_linear_1.out_features = new_weight_1.size(0)
+        module.down_proj_1 = new_linear_1
+
+        del module.down_proj
+
+
+def split_mlp_forward(self, x):
+    h = self.act_fn(self.gate_proj(x)) * self.up_proj(x)
+    return self.down_proj_0(h[:, :, :9472]) + self.down_proj_1(h[:, :, 9472:])
 
 
 class LowBitQwenMultiDecoderlayer(LLMBaseNNFactory):
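For reference, the same equivalence at the module level, written against a hypothetical stand-in whose only resemblance to transformers' Qwen2MLP is its attribute names (gate_proj/up_proj/down_proj/act_fn); the real conversion targets the actual class via model.apply(split_mlp_down_proj) plus the patched forward above:

# Hypothetical toy module; the split-and-sum below inlines what
# split_mlp_down_proj + split_mlp_forward compute together.
import torch

class ToyQwen2MLP(torch.nn.Module):
    def __init__(self, hidden_size=64, intermediate_size=18944):
        super().__init__()
        self.gate_proj = torch.nn.Linear(hidden_size, intermediate_size, bias=False)
        self.up_proj = torch.nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = torch.nn.Linear(intermediate_size, hidden_size, bias=False)
        self.act_fn = torch.nn.SiLU()

    def forward(self, x):
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))

mlp = ToyQwen2MLP().double().eval()
x = torch.randn(1, 3, 64, dtype=torch.float64)
with torch.no_grad():
    ref = mlp(x)                                        # unsplit forward
    h = mlp.act_fn(mlp.gate_proj(x)) * mlp.up_proj(x)   # split forward, inlined
    out = torch.nn.functional.linear(h[:, :, :9472], mlp.down_proj.weight[:, :9472]) \
        + torch.nn.functional.linear(h[:, :, 9472:], mlp.down_proj.weight[:, 9472:])
assert torch.allclose(ref, out)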
@@ -201,7 +225,7 @@ class LowBitQwenMultiDecoderlayer(LLMBaseNNFactory):
         self.compile()
         print("end compiling")
 
-    def mlp(self, hidden_states):
+    def mlp(self, hidden_states, seq_len):
         mm1 = self.linear(
             hidden_states, self.intermediate_size, self.hidden_size, bias=False, wt_dtype=self.dtype
         )
@@ -211,9 +235,13 @@ class LowBitQwenMultiDecoderlayer(LLMBaseNNFactory):
         mm1 = self.eltwise_mul(self.swish(mm1), mm2)  # type: ignore[attr-defined]
         if self.intermediate_size == 18944:
             # for qwen2-7b
-            hidden_states = self.linear(
-                mm1, self.hidden_size, self.intermediate_size, bias=False, wt_dtype=np.int8
-            )
+            mm1_0 = self.slice(mm1, begin=[0, 0, 0], end=[1, seq_len, 9472])
+            mm1_1 = self.slice(mm1, begin=[0, 0, 9472], end=[1, seq_len, 18944])
+            hidden_states_0 = self.linear(mm1_0, self.hidden_size, 9472,
+                                          bias=False, wt_dtype=self.dtype)
+            hidden_states_1 = self.linear(mm1_1, self.hidden_size, 9472,
+                                          bias=False, wt_dtype=self.dtype)
+            hidden_states = hidden_states_0 + hidden_states_1
         else:
             hidden_states = self.linear(
                 mm1, self.hidden_size, self.intermediate_size, bias=False, wt_dtype=self.dtype
@@ -257,7 +285,7 @@ class LowBitQwenMultiDecoderlayer(LLMBaseNNFactory):
         hidden_states = self.eltwise_add(residual, attn_output)
         residual = hidden_states
         hidden_states = self.layer_norm(hidden_states, post_attention_layernorm_weight)
-        hidden_states = self.mlp(hidden_states)
+        hidden_states = self.mlp(hidden_states, self.seq_len)
         hidden_states = self.eltwise_add(residual, hidden_states)
         hidden_states = self.convert_to_fp16(hidden_states)
 
@@ -343,9 +371,13 @@ class FusedQwenLowBitMultiDecoderlayer(torch.nn.Module):
             )
             self.backend_decoders.append(decoder)
 
+        offset = 0
         for i in range(intra_stages):
             start, end = self.layer_ranges[i]
-            self.backend_decoders[i].set_weights(self.op_id, op_parameters[start * 7:end * 7])
+            curr_linear_ops = len(self.backend_decoders[i].linear_ops)
+            curr_parameters = self.op_parameters[offset:offset + curr_linear_ops]
+            self.backend_decoders[i].set_weights(self.op_id, curr_parameters)
+            offset = offset + curr_linear_ops
 
     def forward(
         self,
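With down_proj split in two, a qwen2-7b decoder layer now carries eight weight/scale pairs instead of seven, so the old fixed stride op_parameters[start * 7:end * 7] no longer lines up with the layers assigned to each stage; the new loop instead advances an offset by however many linear ops each backend decoder actually reports. A standalone sketch of that partitioning (the helper name and the per-layer count of 8 are illustrative, not from the repo):

# Illustrative sketch: split a flat weight list across stages by per-stage op counts.
from typing import List, Sequence

def partition_by_linear_ops(op_parameters: Sequence, linear_op_counts: List[int]):
    stages, offset = [], 0
    for count in linear_op_counts:            # one entry per intra-pipeline stage
        stages.append(op_parameters[offset:offset + count])
        offset += count
    return stages

# e.g. 4 layers split into two stages of 2 layers, 8 weights per layer:
weights = list(range(4 * 8))
first, second = partition_by_linear_ops(weights, [16, 16])
assert first == list(range(16)) and second == list(range(16, 32))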
@@ -543,7 +575,8 @@ def run_decode(
             (attn_layer.o_proj.weight, attn_layer.o_proj.scale),
             (mlp_layer.gate_proj.weight, mlp_layer.gate_proj.scale),
             (mlp_layer.up_proj.weight, mlp_layer.up_proj.scale),
-            (mlp_layer.down_proj.weight, mlp_layer.down_proj.scale),
+            (mlp_layer.down_proj_0.weight, mlp_layer.down_proj_0.scale),
+            (mlp_layer.down_proj_1.weight, mlp_layer.down_proj_1.scale)
         ]
 
         cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16)
@@ -814,6 +847,8 @@ def run_prefill(
         transpose_value=transpose_value_cache
     )
     convert_forward(model, Qwen2Attention, qwen2_attention_forward)
+    from transformers.models.qwen2.modeling_qwen2 import Qwen2MLP
+    convert_forward(model, Qwen2MLP, split_mlp_forward)
    deocderlayers = model.model.layers
 
     while True:
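run_prefill reuses split_mlp_forward on the eager transformers modules through convert_forward. As a rough illustration of the rebinding pattern only (this is a sketch, not ipex_llm's actual convert_forward implementation):

# Sketch only: rebind a replacement forward onto every instance of a target class.
import types
import torch

def patch_forward(model: torch.nn.Module, target_cls: type, new_forward) -> None:
    for module in model.modules():
        if isinstance(module, target_cls):
            # bind the plain function as an instance method so `self` is the module
            module.forward = types.MethodType(new_forward, module)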
@@ -836,7 +871,6 @@ def run_prefill(
 
             hidden_states = layer_outputs[0]
             next_decoder_cache = layer_outputs[1]
 
         result_queue.put((hidden_states, next_decoder_cache))
 
@@ -1124,10 +1158,11 @@ def generate_qwen2_attention_forward(max_seq_len, transpose_value):
         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
                                                         cos, sin, position_ids)
 
         cache_kwargs = {"max_seq_len": max_seq_len, "transpose": transpose_value, }
 
         if past_key_value is not None:
+            if transpose_value:
+                value_states = value_states.transpose(-1, -2)
             key_states, value_states = past_key_value.update(key_states, value_states,
                                                              self.layer_idx, cache_kwargs)
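The attention patch stores the value cache already transposed when transpose_value is set. The layout change is safe because attn_weights @ V can be recovered from a transposed cache with a pair of transposes; a quick check of that identity (shapes chosen arbitrarily, not tied to the repo's kernels):

# Identity check: (V^T A^T)^T == A V, i.e. a [B, H, D, S]-layout value cache still
# yields the usual attention output after transposing back.
import torch

B, H, Sq, Sk, D = 1, 2, 4, 5, 8
attn_weights = torch.softmax(torch.randn(B, H, Sq, Sk), dim=-1)
value_states = torch.randn(B, H, Sk, D)
value_cached = value_states.transpose(-1, -2)         # layout when transpose_value=True

reference = attn_weights @ value_states
from_cache = (value_cached @ attn_weights.transpose(-1, -2)).transpose(-1, -2)
assert torch.allclose(reference, from_cache, atol=1e-6)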