Support qwen2-7b with fused decoderlayer optimization on NPU (#11912)
commit 71f03dcc39
parent 63ac5f64bb
3 changed files with 157 additions and 42 deletions
@@ -81,6 +81,7 @@ The example below shows how to run the **_optimized model implementations_** on
 - [Llama2-7B](./llama.py)
 - [Llama3-8B](./llama.py)
 - [Qwen2-1.5B](./qwen2.py)
+- [Qwen2-7B](./qwen2.py)
 - [MiniCPM-1B](./minicpm.py)
 - [MiniCPM-2B](./minicpm.py)
 - [Baichuan2-7B](./baichuan2.py)
@@ -95,6 +96,9 @@ python llama.py --repo-id-or-model-path meta-llama/Meta-Llama-3-8B-Instruct
 # to run Qwen2-1.5B-Instruct
 python qwen2.py
 
+# to run Qwen2-7B-Instruct
+python qwen2.py --repo-id-or-model-path Qwen/Qwen2-7B-Instruct --inter-pp 4
+
 # to run MiniCPM-1B-sft-bf16
 python minicpm.py
 
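For Qwen2-7B the example passes `--inter-pp 4`, matching the default that the `optimize_llm` change below now selects when `intermediate_size == 18944`. The sketch that follows shows roughly how a script like qwen2.py might forward that option when loading the model through ipex_llm's NPU AutoModelForCausalLM; the exact keyword names used here are assumptions, and qwen2.py itself is the authoritative version.

# Sketch only -- keyword names are assumptions; see qwen2.py for the real script.
import torch
from ipex_llm.transformers.npu_model import AutoModelForCausalLM

model_path = "Qwen/Qwen2-7B-Instruct"
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    torch_dtype=torch.float16,
    optimize_model=True,          # enables the fused NPU decoder-layer path
    load_in_low_bit="sym_int4",
    inter_pp=4,                   # value supplied on the command line via --inter-pp
)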
@@ -107,12 +107,14 @@ def optimize_llm(
         from transformers.models.llama.modeling_llama import LlamaForCausalLM
         from ipex_llm.transformers.npu_models.llama_mp import llama2_casullm_forward
         convert_forward(model, LlamaForCausalLM, llama2_casullm_forward)
-    elif model.config.model_type == "qwen2" and model.config.intermediate_size == 8960:
-        # for qwen2-1.5B
+    elif model.config.model_type == "qwen2" and model.config.num_hidden_layers == 28:
+        # for qwen2-1.5B and qwen2-7B
         if intra_pp is None:
             intra_pp = 2
         if inter_pp is None:
-            inter_pp = 1
+            inter_pp = 4 if model.config.intermediate_size == 18944 else 1
+        if model.config.intermediate_size == 18944:
+            transpose_value_cache = False
 
         from ipex_llm.transformers.npu_models.qwen2_mp import gen_qwen2_fused_model_forward
         from ipex_llm.transformers.npu_models.qwen2_mp import DecodeRunner, PrefillRunner
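The dispatch above now keys off num_hidden_layers == 28, which both Qwen2-1.5B and Qwen2-7B report, and then uses intermediate_size (8960 vs 18944) to pick per-size defaults such as inter_pp and transpose_value_cache. A small sketch to confirm those fields locally, using the public Qwen2 Instruct checkpoints referenced in the example:

# Print the config fields the dispatch above relies on.
from transformers import AutoConfig

for repo_id in ("Qwen/Qwen2-1.5B-Instruct", "Qwen/Qwen2-7B-Instruct"):
    cfg = AutoConfig.from_pretrained(repo_id)
    # Expected per the checks above: model_type "qwen2", 28 hidden layers,
    # intermediate_size 8960 (1.5B) or 18944 (7B).
    print(repo_id, cfg.model_type, cfg.num_hidden_layers, cfg.intermediate_size)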
@@ -199,6 +199,25 @@ class LowBitQwenMultiDecoderlayer(LLMBaseNNFactory):
 
         self.compile()
 
+    def mlp(self, hidden_states):
+        mm1 = self.linear(
+            hidden_states, self.intermediate_size, self.hidden_size, bias=False, wt_dtype=self.dtype
+        )
+        mm2 = self.linear(
+            hidden_states, self.intermediate_size, self.hidden_size, bias=False, wt_dtype=self.dtype
+        )  # type: ignore[attr-defined]
+        mm1 = self.eltwise_mul(self.swish(mm1), mm2)  # type: ignore[attr-defined]
+        if self.intermediate_size == 18944:
+            # for qwen2-7b
+            hidden_states = self.linear(
+                mm1, self.hidden_size, self.intermediate_size, bias=False, wt_dtype=np.int8
+            )
+        else:
+            hidden_states = self.linear(
+                mm1, self.hidden_size, self.intermediate_size, bias=False, wt_dtype=self.dtype
+            )
+        return hidden_states
+
     def build_decoder(
         self,
         hidden_states,
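The new mlp method builds Qwen2's standard SwiGLU feed-forward block inside the NPU graph; for the 7B model (intermediate_size 18944) the down-projection weights are kept in int8 rather than the layer's default weight dtype. For reference, the same computation in eager PyTorch looks roughly like the sketch below (module names follow the Hugging Face Qwen2 MLP; the 3584/18944 sizes are assumed to match the published Qwen2-7B config):

# Reference-only sketch of the SwiGLU MLP the fused layer implements;
# this is plain PyTorch, not the NPU graph-factory code above.
import torch.nn as nn

class Qwen2MLPSketch(nn.Module):
    def __init__(self, hidden_size=3584, intermediate_size=18944):
        super().__init__()
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)  # "mm1"
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)    # "mm2"
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)  # final linear
        self.act = nn.SiLU()  # the factory's swish()

    def forward(self, x):
        # eltwise_mul(swish(mm1), mm2), then project back to hidden_size
        return self.down_proj(self.act(self.gate_proj(x)) * self.up_proj(x))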
@@ -734,6 +753,8 @@ def run_prefill(
     input_layer_norm_weights = []
     post_attn_layernorm_weights = []
     layer_indexs = range(layer_start, layer_end)
+    if model.config.intermediate_size == 8960:
+        # for qwen2-1.5b
     for layer_idx in layer_indexs:
         curr_layer = model.model.layers[layer_idx]
         attn_layer = curr_layer.self_attn
@@ -782,6 +803,17 @@ def run_prefill(
     print("finish creating all decode layers in prefill")
     result_queue.put("loading finish")
 
+    if model.config.intermediate_size == 18944:
+        # for qwen2-7b
+        from transformers.models.qwen2.modeling_qwen2 import Qwen2Attention
+        from ipex_llm.transformers.npu_models.convert_mp import convert_forward
+        qwen2_attention_forward = generate_qwen2_attention_forward(
+            max_seq_len=max_output_len,
+            transpose_value=transpose_value_cache
+        )
+        convert_forward(model, Qwen2Attention, qwen2_attention_forward)
+        deocderlayers = model.model.layers
+
     while True:
 
         result = input_queue.get()
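On the Qwen2-7B prefill path the decoder layers are kept as the stock Hugging Face modules, and only Qwen2Attention's forward is replaced through convert_forward. The sketch below shows the general shape of such a forward-patching helper; it is an illustration of the idea, not ipex_llm's actual convert_forward implementation.

# Illustrative forward-patching helper (not ipex_llm's convert_forward itself):
# rebind a replacement forward onto every instance of a target module class.
from types import MethodType
import torch.nn as nn

def patch_forward(model: nn.Module, target_cls: type, new_forward) -> None:
    for module in model.modules():
        if isinstance(module, target_cls):
            module.forward = MethodType(new_forward, module)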
@@ -1053,3 +1085,80 @@ def qwen2_casullm_forward(
         hidden_states=outputs.hidden_states,
         attentions=outputs.attentions,
     )
+
+
+from transformers.models.qwen2.modeling_qwen2 import apply_rotary_pos_emb, repeat_kv
+import math
+
+
+def generate_qwen2_attention_forward(max_seq_len, transpose_value):
+    def qwen2_attention_forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Cache] = None,
+        output_attentions: bool = False,
+        use_cache: bool = False,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        bsz, q_len, _ = hidden_states.size()
+
+        query_states = self.q_proj(hidden_states)
+        key_states = self.k_proj(hidden_states)
+        value_states = self.v_proj(hidden_states)
+        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+        key_states = key_states.view(bsz, q_len, self.num_key_value_heads,
+                                     self.head_dim).transpose(1, 2)
+        value_states = value_states.view(bsz, q_len, self.num_key_value_heads,
+                                         self.head_dim).transpose(1, 2)
+
+        kv_seq_len = key_states.shape[-2]
+        if past_key_value is not None:
+            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+
+        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
+                                                        cos, sin, position_ids)
+
+        cache_kwargs = {"max_seq_len": max_seq_len, "transpose": transpose_value, }
+
+        if past_key_value is not None:
+            key_states, value_states = past_key_value.update(key_states, value_states,
+                                                             self.layer_idx, cache_kwargs)
+
+        key_states = repeat_kv(key_states, self.num_key_value_groups)
+        value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+        attn_weights = None
+        if query_states.size(2) == key_states.size(2):
+            # first token
+            from intel_npu_acceleration_library.functional import scaled_dot_product_attention
+            attn_output = scaled_dot_product_attention(
+                query_states,
+                key_states,
+                value_states,
+                attn_mask=attention_mask,
+                is_causal=q_len > 1 and bsz == 1,
+            )
+        else:
+            attn_weights = torch.matmul(query_states,
+                                        key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+            if attention_mask is not None:
+                attn_weights = attn_weights + attention_mask
+            # upcast attention to fp32
+            attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1,
+                                                       dtype=torch.float32).to(query_states.dtype)
+            attn_weights = torch.nn.functional.dropout(attn_weights, p=self.attention_dropout,
+                                                       training=self.training)
+            attn_output = torch.matmul(attn_weights, value_states)
+
+        attn_output = attn_output.transpose(1, 2).contiguous()
+        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+
+        attn_output = self.o_proj(attn_output)
+
+        if not output_attentions:
+            attn_weights = None
+        return attn_output, attn_weights, past_key_value
+
+    return qwen2_attention_forward
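In this replacement attention forward, the prefill step (query and key lengths equal) is routed through intel_npu_acceleration_library's scaled_dot_product_attention, while subsequent decode steps fall back to the eager matmul/softmax path. Because Qwen2 uses grouped-query attention, repeat_kv first duplicates the key/value heads so they line up with the query heads; for reference, a minimal equivalent of that helper is sketched below (the real one is imported from transformers above):

# Minimal reference version of repeat_kv: expand grouped KV heads so each
# query head has a matching key/value head.
import torch

def repeat_kv_sketch(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    # (batch, num_key_value_heads, seq_len, head_dim)
    #   -> (batch, num_key_value_heads * n_rep, seq_len, head_dim)
    if n_rep == 1:
        return hidden_states
    bsz, num_kv_heads, seq_len, head_dim = hidden_states.shape
    expanded = hidden_states[:, :, None, :, :].expand(
        bsz, num_kv_heads, n_rep, seq_len, head_dim)
    return expanded.reshape(bsz, num_kv_heads * n_rep, seq_len, head_dim)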