Revert prefill logic of qwen2-7b (#11992)
This commit is contained in:
parent 659d15defc
commit 01099f08ee
1 changed file with 44 additions and 123 deletions
@@ -801,13 +801,13 @@ def run_prefill(
     input_layer_norm_weights = []
     post_attn_layernorm_weights = []
     layer_indexs = range(layer_start, layer_end)
-    if model.config.intermediate_size == 8960:
-        # for qwen2-1.5b
     for layer_idx in layer_indexs:
         curr_layer = model.model.layers[layer_idx]
         attn_layer = curr_layer.self_attn
         mlp_layer = curr_layer.mlp

+        if model.config.intermediate_size == 8960:
+            # for qwen2-1.5b
             weights = [
                 (attn_layer.q_proj.weight, attn_layer.q_proj.scale),
                 (attn_layer.k_proj.weight, attn_layer.k_proj.scale),
@@ -817,6 +817,18 @@ def run_prefill(
                 (mlp_layer.up_proj.weight, mlp_layer.up_proj.scale),
                 (mlp_layer.down_proj.weight, mlp_layer.down_proj.scale),
             ]
+        elif model.config.intermediate_size == 18944:
+            # for qwen2-7b
+            weights = [
+                (attn_layer.q_proj.weight, attn_layer.q_proj.scale),
+                (attn_layer.k_proj.weight, attn_layer.k_proj.scale),
+                (attn_layer.v_proj.weight, attn_layer.v_proj.scale),
+                (attn_layer.o_proj.weight, attn_layer.o_proj.scale),
+                (mlp_layer.gate_proj.weight, mlp_layer.gate_proj.scale),
+                (mlp_layer.up_proj.weight, mlp_layer.up_proj.scale),
+                (mlp_layer.down_proj_0.weight, mlp_layer.down_proj_0.scale),
+                (mlp_layer.down_proj_1.weight, mlp_layer.down_proj_1.scale)
+            ]

         cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16)
         cached_sin = curr_layer.self_attn.rotary_emb.sin_cached.to(torch.float16)
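Note: the restored branch above builds the qwen2-7b weight list directly in run_prefill, with down_proj split into down_proj_0 and down_proj_1 (the 18944 intermediate size distinguishes qwen2-7b from the 8960 of qwen2-1.5b). The sketch below is not from this commit; it only illustrates one way such a split can work, assuming down_proj is halved along the intermediate dimension so that two partial matmuls sum to the original output (the helper name split_down_proj_weight and the halving strategy are assumptions).

    import torch

    def split_down_proj_weight(down_proj_weight: torch.Tensor):
        # down_proj_weight: (hidden_size, intermediate_size), e.g. (3584, 18944) for qwen2-7b
        half = down_proj_weight.shape[1] // 2
        return down_proj_weight[:, :half], down_proj_weight[:, half:]

    # toy-sized check that the two partial matmuls reproduce the full down_proj
    w = torch.randn(16, 64)        # stands in for a (3584, 18944) weight
    x = torch.randn(1, 64)         # MLP activation
    w0, w1 = split_down_proj_weight(w)
    full = x @ w.T
    split = x[:, :32] @ w0.T + x[:, 32:] @ w1.T
    print(torch.allclose(full, split, atol=1e-5))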
@@ -851,19 +863,6 @@ def run_prefill(
     print("finish creating all decode layers in prefill")
     result_queue.put("loading finish")

-    if model.config.intermediate_size == 18944:
-        # for qwen2-7b
-        from transformers.models.qwen2.modeling_qwen2 import Qwen2Attention
-        from ipex_llm.transformers.npu_models.convert_mp import convert_forward
-        qwen2_attention_forward = generate_qwen2_attention_forward(
-            max_seq_len=max_output_len,
-            transpose_value=transpose_value_cache
-        )
-        convert_forward(model, Qwen2Attention, qwen2_attention_forward)
-        from transformers.models.qwen2.modeling_qwen2 import Qwen2MLP
-        convert_forward(model, Qwen2MLP, split_mlp_forward)
-        deocderlayers = model.model.layers
-
     while True:

         result = input_queue.get()
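Note: the block removed above patched the stock Hugging Face modules at prefill time: convert_forward rebinds the forward of every Qwen2Attention (and Qwen2MLP) instance to a replacement produced by generate_qwen2_attention_forward, which captures max_seq_len and transpose_value in a closure. A minimal sketch of that rebinding pattern (an illustration of the general technique, not ipex_llm's actual convert_forward) is:

    import types
    import torch.nn as nn

    def patch_forward(model: nn.Module, target_cls: type, new_forward) -> None:
        # Rebind `forward` on every module of the target class, so later calls
        # go through the replacement function instead of the original method.
        for module in model.modules():
            if isinstance(module, target_cls):
                module.forward = types.MethodType(new_forward, module)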
@@ -1136,81 +1135,3 @@ def qwen2_casullm_forward(
         hidden_states=outputs.hidden_states,
         attentions=outputs.attentions,
     )
-
-
-from transformers.models.qwen2.modeling_qwen2 import apply_rotary_pos_emb, repeat_kv
-import math
-
-
-def generate_qwen2_attention_forward(max_seq_len, transpose_value):
-    def qwen2_attention_forward(
-        self,
-        hidden_states: torch.Tensor,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_value: Optional[Cache] = None,
-        output_attentions: bool = False,
-        use_cache: bool = False,
-        **kwargs,
-    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-        bsz, q_len, _ = hidden_states.size()
-
-        query_states = self.q_proj(hidden_states)
-        key_states = self.k_proj(hidden_states)
-        value_states = self.v_proj(hidden_states)
-        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-        key_states = key_states.view(bsz, q_len, self.num_key_value_heads,
-                                     self.head_dim).transpose(1, 2)
-        value_states = value_states.view(bsz, q_len, self.num_key_value_heads,
-                                         self.head_dim).transpose(1, 2)
-
-        kv_seq_len = key_states.shape[-2]
-        if past_key_value is not None:
-            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
-
-        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
-        query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
-                                                        cos, sin, position_ids)
-        cache_kwargs = {"max_seq_len": max_seq_len, "transpose": transpose_value, }
-
-        if past_key_value is not None:
-            if transpose_value:
-                value_states = value_states.transpose(-1, -2)
-            key_states, value_states = past_key_value.update(key_states, value_states,
-                                                             self.layer_idx, cache_kwargs)
-
-        key_states = repeat_kv(key_states, self.num_key_value_groups)
-        value_states = repeat_kv(value_states, self.num_key_value_groups)
-
-        attn_weights = None
-        if query_states.size(2) == key_states.size(2):
-            # first token
-            from intel_npu_acceleration_library.functional import scaled_dot_product_attention
-            attn_output = scaled_dot_product_attention(
-                query_states,
-                key_states,
-                value_states,
-                attn_mask=attention_mask,
-                is_causal=q_len > 1 and bsz == 1,
-            )
-        else:
-            attn_weights = torch.matmul(query_states,
-                                        key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
-            if attention_mask is not None:
-                attn_weights = attn_weights + attention_mask
-            # upcast attention to fp32
-            attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1,
-                                                       dtype=torch.float32).to(query_states.dtype)
-            attn_weights = torch.nn.functional.dropout(attn_weights, p=self.attention_dropout,
-                                                       training=self.training)
-            attn_output = torch.matmul(attn_weights, value_states)
-
-        attn_output = attn_output.transpose(1, 2).contiguous()
-        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
-
-        attn_output = self.o_proj(attn_output)
-
-        if not output_attentions:
-            attn_weights = None
-        return attn_output, attn_weights, past_key_value
-    return qwen2_attention_forward
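Note: the deleted generate_qwen2_attention_forward above dispatched between two paths: a fused scaled_dot_product_attention kernel from intel_npu_acceleration_library when the query and key lengths match (prefill / first token), and an explicit matmul + softmax when a single decode token attends to a longer cache. A rough stand-alone sketch of that dispatch, using plain PyTorch in place of the NPU kernel, is:

    import math
    import torch
    import torch.nn.functional as F

    def attention(query, key, value, attn_mask=None):
        # query/key/value: (bsz, num_heads, seq, head_dim), already repeat_kv-expanded
        if query.size(2) == key.size(2):
            # first token / prefill: use the fused kernel
            return F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask)
        # decode: q_len == 1, key/value hold the full cache
        scores = query @ key.transpose(2, 3) / math.sqrt(query.size(-1))
        if attn_mask is not None:
            scores = scores + attn_mask
        probs = F.softmax(scores, dim=-1, dtype=torch.float32).to(query.dtype)
        return probs @ value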