Support Qwen2-7b MLP in int4 and transpose_value_cache=True (#11968)

This commit is contained in:
Yang Wang 2024-09-01 23:37:44 -07:00 committed by GitHub
parent 65e281bb29
commit c48817bd43
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 49 additions and 11 deletions

View file

@ -65,6 +65,11 @@ def optimize_llm_pre(model: torch.nn.Module, qtype):
model.llm.config.model_type = "llama"
model = model.llm
if model.config.model_type == "qwen2":
from ipex_llm.transformers.npu_models.qwen2_mp import split_mlp_down_proj
from ipex_llm.transformers.npu_models.qwen2_mp import split_mlp_forward
model.apply(split_mlp_down_proj)
# lm_head to cpu optimization
if cpu_lm_head:
# disable the optimization by default
@ -134,8 +139,6 @@ def optimize_llm(
intra_pp = 2
if inter_pp is None:
inter_pp = 4 if model.config.intermediate_size == 18944 else 1
if model.config.intermediate_size == 18944:
transpose_value_cache = False
from ipex_llm.transformers.npu_models.qwen2_mp import gen_qwen2_fused_model_forward
from ipex_llm.transformers.npu_models.qwen2_mp import DecodeRunner, PrefillRunner

View file

@ -42,6 +42,30 @@ from ipex_llm.transformers.npu_models.mp_models_base import LLMBaseNNFactory
from ipex_llm.transformers.npu_models.common import reshape_lm_head_input
from transformers.modeling_outputs import CausalLMOutputWithPast
from torch.nn import CrossEntropyLoss
from transformers.models.qwen2.modeling_qwen2 import Qwen2MLP
def split_mlp_down_proj(module: torch.nn.Module):
if isinstance(module, Qwen2MLP) and module.down_proj.in_features == 18944:
new_linear_0 = torch.nn.Linear(0, 0, bias=False)
new_weight_0 = torch.nn.Parameter(module.down_proj.weight[:, :9472], requires_grad=False)
new_linear_0.weight = new_weight_0
new_linear_0.in_features = new_weight_0.size(1)
new_linear_0.out_features = new_weight_0.size(0)
module.down_proj_0 = new_linear_0
new_linear_1 = torch.nn.Linear(0, 0, bias=False)
new_weight_1 = torch.nn.Parameter(module.down_proj.weight[:, 9472:], requires_grad=False)
new_linear_1.weight = new_weight_1
new_linear_1.in_features = new_weight_1.size(1)
new_linear_1.out_features = new_weight_1.size(0)
module.down_proj_1 = new_linear_1
del module.down_proj
def split_mlp_forward(self, x):
h = self.act_fn(self.gate_proj(x)) * self.up_proj(x)
return self.down_proj_0(h[:, :, :9472]) + self.down_proj_1(h[:, :, 9472:])
class LowBitQwenMultiDecoderlayer(LLMBaseNNFactory):
@ -201,7 +225,7 @@ class LowBitQwenMultiDecoderlayer(LLMBaseNNFactory):
self.compile()
print("end compiling")
def mlp(self, hidden_states):
def mlp(self, hidden_states, seq_len):
mm1 = self.linear(
hidden_states, self.intermediate_size, self.hidden_size, bias=False, wt_dtype=self.dtype
)
@ -211,9 +235,13 @@ class LowBitQwenMultiDecoderlayer(LLMBaseNNFactory):
mm1 = self.eltwise_mul(self.swish(mm1), mm2) # type: ignore[attr-defined]
if self.intermediate_size == 18944:
# for qwen2-7b
hidden_states = self.linear(
mm1, self.hidden_size, self.intermediate_size, bias=False, wt_dtype=np.int8
)
mm1_0 = self.slice(mm1, begin=[0, 0, 0], end=[1, seq_len, 9472])
mm1_1 = self.slice(mm1, begin=[0, 0, 9472], end=[1, seq_len, 18944])
hidden_states_0 = self.linear(mm1_0, self.hidden_size, 9472,
bias=False, wt_dtype=self.dtype)
hidden_states_1 = self.linear(mm1_1, self.hidden_size, 9472,
bias=False, wt_dtype=self.dtype)
hidden_states = hidden_states_0 + hidden_states_1
else:
hidden_states = self.linear(
mm1, self.hidden_size, self.intermediate_size, bias=False, wt_dtype=self.dtype
@ -257,7 +285,7 @@ class LowBitQwenMultiDecoderlayer(LLMBaseNNFactory):
hidden_states = self.eltwise_add(residual, attn_output)
residual = hidden_states
hidden_states = self.layer_norm(hidden_states, post_attention_layernorm_weight)
hidden_states = self.mlp(hidden_states)
hidden_states = self.mlp(hidden_states, self.seq_len)
hidden_states = self.eltwise_add(residual, hidden_states)
hidden_states = self.convert_to_fp16(hidden_states)
@ -343,9 +371,13 @@ class FusedQwenLowBitMultiDecoderlayer(torch.nn.Module):
)
self.backend_decoders.append(decoder)
offset = 0
for i in range(intra_stages):
start, end = self.layer_ranges[i]
self.backend_decoders[i].set_weights(self.op_id, op_parameters[start * 7:end * 7])
curr_linear_ops = len(self.backend_decoders[i].linear_ops)
curr_parameters = self.op_parameters[offset:offset + curr_linear_ops]
self.backend_decoders[i].set_weights(self.op_id, curr_parameters)
offset = offset + curr_linear_ops
def forward(
self,
@ -543,7 +575,8 @@ def run_decode(
(attn_layer.o_proj.weight, attn_layer.o_proj.scale),
(mlp_layer.gate_proj.weight, mlp_layer.gate_proj.scale),
(mlp_layer.up_proj.weight, mlp_layer.up_proj.scale),
(mlp_layer.down_proj.weight, mlp_layer.down_proj.scale),
(mlp_layer.down_proj_0.weight, mlp_layer.down_proj_0.scale),
(mlp_layer.down_proj_1.weight, mlp_layer.down_proj_1.scale)
]
cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16)
@ -814,6 +847,8 @@ def run_prefill(
transpose_value=transpose_value_cache
)
convert_forward(model, Qwen2Attention, qwen2_attention_forward)
from transformers.models.qwen2.modeling_qwen2 import Qwen2MLP
convert_forward(model, Qwen2MLP, split_mlp_forward)
deocderlayers = model.model.layers
while True:
@ -836,7 +871,6 @@ def run_prefill(
hidden_states = layer_outputs[0]
next_decoder_cache = layer_outputs[1]
result_queue.put((hidden_states, next_decoder_cache))
@ -1124,10 +1158,11 @@ def generate_qwen2_attention_forward(max_seq_len, transpose_value):
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
cos, sin, position_ids)
cache_kwargs = {"max_seq_len": max_seq_len, "transpose": transpose_value, }
if past_key_value is not None:
if transpose_value:
value_states = value_states.transpose(-1, -2)
key_states, value_states = past_key_value.update(key_states, value_states,
self.layer_idx, cache_kwargs)