Use merge_qkv to replace fused_qkv for llama2 (#11727)
* update 4.38
* support new versions
* update
* fix style
* fix style
* update rope
* temp test sdpa
* fix style
* fix cpu ut
parent d2abc9711b
commit 00a5574c8a

4 changed files with 180 additions and 93 deletions
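Background for the diff below: merge_qkv folds the separate q_proj/k_proj/v_proj linears of each LlamaAttention module into a single qkv_proj (via merge_qkv_base from ipex_llm.transformers.models.common), and the attention forwards then split the fused output back into per-head q/k/v. The snippet below is only an illustrative sketch of that idea on plain nn.Linear layers; merge_qkv_sketch and its details are assumptions made for the example, not the actual ipex_llm implementation, which also has to handle low-bit linears.

import torch
import torch.nn as nn


def merge_qkv_sketch(attn: nn.Module) -> None:
    # Replace attn.q_proj / attn.k_proj / attn.v_proj with a single fused attn.qkv_proj
    # whose weight is the row-wise concatenation of the three original weights.
    if not all(hasattr(attn, p) for p in ("q_proj", "k_proj", "v_proj")):
        return  # already merged, or not an attention module with separate projections
    q, k, v = attn.q_proj, attn.k_proj, attn.v_proj
    fused = nn.Linear(q.in_features,
                      q.out_features + k.out_features + v.out_features,
                      bias=q.bias is not None)
    with torch.no_grad():
        fused.weight.copy_(torch.cat([q.weight, k.weight, v.weight], dim=0))
        if q.bias is not None:
            fused.bias.copy_(torch.cat([q.bias, k.bias, v.bias], dim=0))
    attn.qkv_proj = fused
    del attn.q_proj, attn.k_proj, attn.v_proj

_optimize_pre applies the real merge_qkv to every submodule through model.apply(merge_qkv); merge_qkv_base only touches modules that are instances of LlamaAttention, so other layers pass through unchanged.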
@@ -744,6 +744,9 @@ def _optimize_pre(model, qtype=None):
     if model.config.model_type == "gemma2":
         from ipex_llm.transformers.models.gemma2 import merge_qkv
         model.apply(merge_qkv)
+    if model.config.model_type == "llama":
+        from ipex_llm.transformers.models.llama import merge_qkv
+        model.apply(merge_qkv)
 
     return model
 
@@ -989,6 +992,10 @@ def _optimize_post(model, lightweight_bmm=False):
                     model,
                     transformers.models.llama.modeling_llama.LlamaAttention,
                     llama_attention_forward_4_41)
+                convert_forward(
+                    model,
+                    transformers.models.llama.modeling_llama.LlamaSdpaAttention,
+                    llama_attention_forward_4_41)
             else:
                 from ipex_llm.transformers.models.llama import llama_model_forward_4_38
                 convert_forward(
@@ -999,6 +1006,10 @@ def _optimize_post(model, lightweight_bmm=False):
                     model,
                     transformers.models.llama.modeling_llama.LlamaAttention,
                     llama_attention_forward_4_38)
+                convert_forward(
+                    model,
+                    transformers.models.llama.modeling_llama.LlamaSdpaAttention,
+                    llama_attention_forward_4_38)
             else:
                 from ipex_llm.transformers.models.llama import llama_model_forward_4_36
                 convert_forward(
@@ -1009,6 +1020,10 @@ def _optimize_post(model, lightweight_bmm=False):
                     model,
                     transformers.models.llama.modeling_llama.LlamaAttention,
                     llama_attention_forward_4_38)
+                convert_forward(
+                    model,
+                    transformers.models.llama.modeling_llama.LlamaSdpaAttention,
+                    llama_attention_forward_4_38)
             else:
                 # transformers version between 4.31.0 - 4.35.2
                 convert_forward(
@@ -49,12 +49,13 @@ from ipex_llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
 from ipex_llm.transformers.models.utils import use_flash_attention, use_sdp, use_sdp_fp8, \
     use_sdp_causal
 from ipex_llm.transformers.models.utils import mlp_fusion_check, fp16_fusion_check
-from ipex_llm.transformers.models.utils import use_decoding_fast_path
+from ipex_llm.transformers.models.utils import use_decoding_fast_path, get_q_proj_or_qkv_proj
 from transformers.modeling_outputs import BaseModelOutputWithPast
-from transformers.models.llama.modeling_llama import LlamaModel
+from transformers.models.llama.modeling_llama import LlamaModel, LlamaAttention
 from ipex_llm.transformers.low_bit_linear import SYM_INT4, FP8E5, IQ2_XXS, FP4
 from ipex_llm.ggml.quantize import ggml_tensor_qtype
 from ipex_llm.utils.common import invalidInputError
+from ipex_llm.transformers.models.common import merge_qkv_base, fuse_mlp_base
 
 try:
     from transformers.cache_utils import Cache, DynamicCache
@@ -66,6 +67,10 @@ from transformers import logging
 logger = logging.get_logger(__name__)
 
 
+def merge_qkv(module: torch.nn.Module):
+    return merge_qkv_base(module, LlamaAttention)
+
+
 def llama_decoding_fast_path_qtype_check(proj):
     # IQ2_XXS only can be used in Llama-like model
     qtype = getattr(proj, "qtype", None)
@@ -406,6 +411,9 @@ def fuse_qkv_weight_xetla(q_proj, k_proj, v_proj, qtype):
 
 
 def should_use_xetla_mm_qkv(self, device):
+    if not hasattr(self, "q_proj"):
+        # TODO: how to support xetla_mm_qkv for merged_qkv
+        return False
     full_attn = self.q_proj.out_len == self.k_proj.out_len == self.v_proj.out_len
     supported_qtype = self.q_proj.qtype == SYM_INT4 and full_attn
     supported_qtype = supported_qtype or self.q_proj.qtype == FP8E5
@@ -428,7 +436,8 @@ def llama_attention_forward_4_31(
     cache_position: Optional[torch.LongTensor] = None,
     **kwargs,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-    if use_quantize_kv_cache(self.q_proj, hidden_states, self.num_key_value_groups):
+    if use_quantize_kv_cache(get_q_proj_or_qkv_proj(self), hidden_states,
+                             self.num_key_value_groups):
         forward_function = llama_attention_forward_4_31_quantized
     else:
         forward_function = llama_attention_forward_4_31_original
@@ -466,7 +475,7 @@ def llama_attention_forward_4_31_quantized(
     use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
     enough_kv_room = is_enough_kv_cache_room_4_31(past_key_value, seq_len=q_len)
     no_tp = not self.config.pretraining_tp > 1
-    decoding_fast_path = use_decoding_fast_path(self.q_proj,
+    decoding_fast_path = use_decoding_fast_path(getattr(self, "q_proj", None),
                                                 use_fuse_rope,
                                                 enough_kv_room,
                                                 bsz * q_len,
@@ -500,9 +509,16 @@ def llama_attention_forward_4_31_quantized(
                                                                self.head_dim,
                                                                self.rotary_emb.base,)
     else:
-        query_states = self.q_proj(hidden_states)
-        key_states = self.k_proj(hidden_states)
-        value_states = self.v_proj(hidden_states)
+        if hasattr(self, "q_proj"):
+            query_states = self.q_proj(hidden_states)
+            key_states = self.k_proj(hidden_states)
+            value_states = self.v_proj(hidden_states)
+        else:
+            qkv = self.qkv_proj(hidden_states)
+            qkv = qkv.view(bsz, q_len, self.num_heads + 2 * self.num_key_value_heads, self.head_dim)
+            query_states, key_states, value_states = qkv.split([self.num_heads,
+                                                                self.num_key_value_heads,
+                                                                self.num_key_value_heads], dim=2)
 
         query_states = query_states.view(bsz, q_len,
                                          self.num_heads, self.head_dim).transpose(1, 2)
@@ -516,12 +532,9 @@ def llama_attention_forward_4_31_quantized(
         kv_seq_len += past_key_value[0].shape[-2]
 
     if use_fuse_rope:
-        rope_theta = self.rotary_emb.base
-        query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
-                                                                     key_states,
-                                                                     position_ids,
-                                                                     "llama",
-                                                                     rope_theta=rope_theta)
+        import xe_addons
+        xe_addons.rotary_half_inplaced(self.rotary_emb.inv_freq, position_ids,
+                                       query_states, key_states)
     else:
         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
@@ -604,7 +617,7 @@ def llama_attention_forward_4_31_original(
     use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
     enough_kv_room = is_enough_kv_cache_room_4_31(past_key_value, seq_len=q_len)
     no_tp = not self.config.pretraining_tp > 1
-    decoding_fast_path = use_decoding_fast_path(self.q_proj,
+    decoding_fast_path = use_decoding_fast_path(getattr(self, "q_proj", None),
                                                 use_fuse_rope,
                                                 enough_kv_room,
                                                 bsz * q_len,
@@ -654,7 +667,7 @@ def llama_attention_forward_4_31_original(
                             for i in range(self.config.pretraining_tp)]
             value_states = torch.cat(value_states, dim=-1)
         else:
-            if fp16_fusion_check(self.q_proj, hidden_states, self.training) and \
+            if fp16_fusion_check(getattr(self, "q_proj", None), hidden_states, self.training) and \
                     hidden_size == 4096 and self.q_proj.out_features == self.k_proj.out_features:
                 # only use mm_qkv_out on pvc for llama-7b
                 if not hasattr(self, "qkv_proj_weight"):
@@ -692,9 +705,19 @@ def llama_attention_forward_4_31_original(
                 key_states = qkv_states[:, :, q_out_len:q_out_len + k_out_len]
                 value_states = qkv_states[:, :, q_out_len + k_out_len:]
             else:
-                query_states = self.q_proj(hidden_states)
-                key_states = self.k_proj(hidden_states)
-                value_states = self.v_proj(hidden_states)
+                if hasattr(self, "q_proj"):
+                    query_states = self.q_proj(hidden_states)
+                    key_states = self.k_proj(hidden_states)
+                    value_states = self.v_proj(hidden_states)
+                else:
+                    qkv = self.qkv_proj(hidden_states)
+                    qkv = qkv.view(bsz, q_len, self.num_heads + 2 * self.num_key_value_heads,
+                                   self.head_dim)
+                    query_states, key_states, value_states = \
+                        qkv.split([self.num_heads,
+                                   self.num_key_value_heads,
+                                   self.num_key_value_heads],
+                                  dim=2)
 
         query_states = query_states.view(bsz, q_len,
                                          self.num_heads, self.head_dim).transpose(1, 2)
@@ -708,12 +731,9 @@ def llama_attention_forward_4_31_original(
         kv_seq_len += past_key_value[0].shape[-2]
 
     if use_fuse_rope:
-        rope_theta = self.rotary_emb.base
-        query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
-                                                                     key_states,
-                                                                     position_ids,
-                                                                     "llama",
-                                                                     rope_theta=rope_theta)
+        import xe_addons
+        xe_addons.rotary_half_inplaced(self.rotary_emb.inv_freq, position_ids,
+                                       query_states, key_states)
     else:
         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
@@ -839,7 +859,7 @@ def llama_attention_selective_batching_forward_4_31(
     use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
     enough_kv_room = past_key_value is not None and is_enough_kv_cache_room_4_31(past_key_value[0])
     no_tp = not self.config.pretraining_tp > 1
-    decoding_fast_path = use_decoding_fast_path(self.q_proj,
+    decoding_fast_path = use_decoding_fast_path(getattr(self, "q_proj", None),
                                                 use_fuse_rope,
                                                 enough_kv_room,
                                                 bsz * q_len,
@@ -886,9 +906,18 @@ def llama_attention_selective_batching_forward_4_31(
         if self.config.pretraining_tp > 1:
             invalidInputError(False, f"vLLM: config.pretraining_tp > 1 not supported yet")
         else:
-            query_states = self.q_proj(hidden_states)
-            key_states = self.k_proj(hidden_states)
-            value_states = self.v_proj(hidden_states)
+            if hasattr(self, "q_proj"):
+                query_states = self.q_proj(hidden_states)
+                key_states = self.k_proj(hidden_states)
+                value_states = self.v_proj(hidden_states)
+            else:
+                qkv = self.qkv_proj(hidden_states)
+                qkv = qkv.view(bsz, q_len, self.num_heads + 2 * self.num_key_value_heads,
+                               self.head_dim)
+                query_states, key_states, value_states = qkv.split([self.num_heads,
+                                                                    self.num_key_value_heads,
+                                                                    self.num_key_value_heads],
+                                                                   dim=2)
 
         query_states = query_states.view(bsz, q_len,
                                          self.num_heads, self.head_dim).transpose(1, 2)
@@ -902,12 +931,9 @@ def llama_attention_selective_batching_forward_4_31(
         kv_seq_len += max(kv_pair[0].shape[-2] for kv_pair in past_key_value)
 
     if use_fuse_rope:
-        rope_theta = self.rotary_emb.base
-        query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
-                                                                     key_states,
-                                                                     position_ids,
-                                                                     "llama",
-                                                                     rope_theta=rope_theta)
+        import xe_addons
+        xe_addons.rotary_half_inplaced(self.rotary_emb.inv_freq, position_ids,
+                                       query_states, key_states)
    else:
         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
@@ -1030,7 +1056,8 @@ def llama_attention_forward_4_41(
     cache_position: Optional[torch.LongTensor] = None,
     **kwargs
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[List[torch.FloatTensor]]]:
-    if use_quantize_kv_cache(self.q_proj, hidden_states, self.num_key_value_groups):
+    if use_quantize_kv_cache(get_q_proj_or_qkv_proj(self), hidden_states,
+                             self.num_key_value_groups):
         forward_function = llama_attention_forward_4_41_quantized
     else:
         forward_function = llama_attention_forward_4_41_original
@@ -1069,7 +1096,7 @@ def llama_attention_forward_4_41_quantized(
     use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
     enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx, seq_len=q_len)
     no_tp = not self.config.pretraining_tp > 1
-    decoding_fast_path = use_decoding_fast_path(self.q_proj,
+    decoding_fast_path = use_decoding_fast_path(getattr(self, "q_proj", None),
                                                 use_fuse_rope,
                                                 enough_kv_room,
                                                 bsz * q_len,
@@ -1098,9 +1125,16 @@ def llama_attention_forward_4_41_quantized(
                                                                self.head_dim,
                                                                self.rotary_emb.base,)
     else:
-        query_states = self.q_proj(hidden_states)
-        key_states = self.k_proj(hidden_states)
-        value_states = self.v_proj(hidden_states)
+        if hasattr(self, "q_proj"):
+            query_states = self.q_proj(hidden_states)
+            key_states = self.k_proj(hidden_states)
+            value_states = self.v_proj(hidden_states)
+        else:
+            qkv = self.qkv_proj(hidden_states)
+            qkv = qkv.view(bsz, q_len, self.num_heads + 2 * self.num_key_value_heads, self.head_dim)
+            query_states, key_states, value_states = qkv.split([self.num_heads,
+                                                                self.num_key_value_heads,
+                                                                self.num_key_value_heads], dim=2)
 
         query_states = query_states.view(bsz, q_len,
                                          self.num_heads, self.head_dim).transpose(1, 2)
@@ -1122,12 +1156,9 @@ def llama_attention_forward_4_41_quantized(
         )
         kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
     if use_fuse_rope:
-        rope_theta = self.rotary_emb.base
-        query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
-                                                                     key_states,
-                                                                     position_ids,
-                                                                     "llama",
-                                                                     rope_theta=rope_theta)
+        import xe_addons
+        xe_addons.rotary_half_inplaced(self.rotary_emb.inv_freq, position_ids,
+                                       query_states, key_states)
     else:
         if cache_position is not None:
             # for transformers 4.38.0
@@ -1301,7 +1332,7 @@ def llama_attention_forward_4_41_original(
     use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
     enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx, seq_len=q_len)
     no_tp = not self.config.pretraining_tp > 1
-    decoding_fast_path = use_decoding_fast_path(self.q_proj,
+    decoding_fast_path = use_decoding_fast_path(getattr(self, "q_proj", None),
                                                 use_fuse_rope,
                                                 enough_kv_room,
                                                 bsz * q_len,
@@ -1360,7 +1391,7 @@ def llama_attention_forward_4_41_original(
                             for i in range(self.config.pretraining_tp)]
             value_states = torch.cat(value_states, dim=-1)
         else:
-            if fp16_fusion_check(self.q_proj, hidden_states, self.training) and \
+            if fp16_fusion_check(getattr(self, "q_proj", None), hidden_states, self.training) and \
                     hidden_size == 4096 and self.q_proj.out_features == self.k_proj.out_features:
                 # only use mm_qkv_out on pvc for llama-7b
                 if not hasattr(self, "qkv_proj_weight"):
@@ -1399,9 +1430,20 @@ def llama_attention_forward_4_41_original(
                 key_states = qkv_states[:, :, q_out_len:q_out_len + k_out_len]
                 value_states = qkv_states[:, :, q_out_len + k_out_len:]
             else:
-                query_states = self.q_proj(hidden_states)
-                key_states = self.k_proj(hidden_states)
-                value_states = self.v_proj(hidden_states)
+                if hasattr(self, "q_proj"):
+                    query_states = self.q_proj(hidden_states)
+                    key_states = self.k_proj(hidden_states)
+                    value_states = self.v_proj(hidden_states)
+                else:
+                    qkv = self.qkv_proj(hidden_states)
+                    qkv = qkv.view(bsz, q_len,
+                                   self.num_heads + 2 * self.num_key_value_heads,
+                                   self.head_dim)
+                    query_states, key_states, value_states = \
+                        qkv.split([self.num_heads,
+                                   self.num_key_value_heads,
+                                   self.num_key_value_heads],
+                                  dim=2)
 
         query_states = query_states.view(bsz, q_len,
                                          self.num_heads, self.head_dim).transpose(1, 2)
@@ -1421,12 +1463,9 @@ def llama_attention_forward_4_41_original(
         kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
 
     if use_fuse_rope:
-        rope_theta = self.rotary_emb.base
-        query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
-                                                                     key_states,
-                                                                     position_ids,
-                                                                     "llama",
-                                                                     rope_theta=rope_theta)
+        import xe_addons
+        xe_addons.rotary_half_inplaced(self.rotary_emb.inv_freq, position_ids,
+                                       query_states, key_states)
     else:
         if cache_position is not None:
             # for transformers 4.38.0
@@ -1582,7 +1621,8 @@ def llama_attention_forward_4_38(
     cache_position: Optional[torch.LongTensor] = None,
     **kwargs
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[List[torch.FloatTensor]]]:
-    if use_quantize_kv_cache(self.q_proj, hidden_states, self.num_key_value_groups):
+    if use_quantize_kv_cache(get_q_proj_or_qkv_proj(self), hidden_states,
+                             self.num_key_value_groups):
         forward_function = llama_attention_forward_4_38_quantized
     else:
         forward_function = llama_attention_forward_4_38_original
@@ -1621,7 +1661,7 @@ def llama_attention_forward_4_38_quantized(
     use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
     enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx, seq_len=q_len)
     no_tp = not self.config.pretraining_tp > 1
-    decoding_fast_path = use_decoding_fast_path(self.q_proj,
+    decoding_fast_path = use_decoding_fast_path(getattr(self, "q_proj", None),
                                                 use_fuse_rope,
                                                 enough_kv_room,
                                                 bsz * q_len,
@@ -1650,9 +1690,16 @@ def llama_attention_forward_4_38_quantized(
                                                                self.head_dim,
                                                                self.rotary_emb.base,)
     else:
-        query_states = self.q_proj(hidden_states)
-        key_states = self.k_proj(hidden_states)
-        value_states = self.v_proj(hidden_states)
+        if hasattr(self, "q_proj"):
+            query_states = self.q_proj(hidden_states)
+            key_states = self.k_proj(hidden_states)
+            value_states = self.v_proj(hidden_states)
+        else:
+            qkv = self.qkv_proj(hidden_states)
+            qkv = qkv.view(bsz, q_len, self.num_heads + 2 * self.num_key_value_heads, self.head_dim)
+            query_states, key_states, value_states = qkv.split([self.num_heads,
+                                                                self.num_key_value_heads,
+                                                                self.num_key_value_heads], dim=2)
 
         query_states = query_states.view(bsz, q_len,
                                          self.num_heads, self.head_dim).transpose(1, 2)
@@ -1674,12 +1721,9 @@ def llama_attention_forward_4_38_quantized(
         )
         kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
     if use_fuse_rope:
-        rope_theta = self.rotary_emb.base
-        query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
-                                                                     key_states,
-                                                                     position_ids,
-                                                                     "llama",
-                                                                     rope_theta=rope_theta)
+        import xe_addons
+        xe_addons.rotary_half_inplaced(self.rotary_emb.inv_freq, position_ids,
+                                       query_states, key_states)
     else:
         if cache_position is not None:
             # for transformers 4.38.0
@@ -1853,7 +1897,7 @@ def llama_attention_forward_4_38_original(
     use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
     enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx, seq_len=q_len)
     no_tp = not self.config.pretraining_tp > 1
-    decoding_fast_path = use_decoding_fast_path(self.q_proj,
+    decoding_fast_path = use_decoding_fast_path(getattr(self, "q_proj", None),
                                                 use_fuse_rope,
                                                 enough_kv_room,
                                                 bsz * q_len,
@@ -1911,7 +1955,7 @@ def llama_attention_forward_4_38_original(
                             for i in range(self.config.pretraining_tp)]
             value_states = torch.cat(value_states, dim=-1)
         else:
-            if fp16_fusion_check(self.q_proj, hidden_states, self.training) and \
+            if fp16_fusion_check(getattr(self, "q_proj", None), hidden_states, self.training) and \
                     hidden_size == 4096 and self.q_proj.out_features == self.k_proj.out_features:
                 # only use mm_qkv_out on pvc for llama-7b
                 if not hasattr(self, "qkv_proj_weight"):
@@ -1950,9 +1994,20 @@ def llama_attention_forward_4_38_original(
                 key_states = qkv_states[:, :, q_out_len:q_out_len + k_out_len]
                 value_states = qkv_states[:, :, q_out_len + k_out_len:]
             else:
-                query_states = self.q_proj(hidden_states)
-                key_states = self.k_proj(hidden_states)
-                value_states = self.v_proj(hidden_states)
+                if hasattr(self, "q_proj"):
+                    query_states = self.q_proj(hidden_states)
+                    key_states = self.k_proj(hidden_states)
+                    value_states = self.v_proj(hidden_states)
+                else:
+                    qkv = self.qkv_proj(hidden_states)
+                    qkv = qkv.view(bsz, q_len,
+                                   self.num_heads + 2 * self.num_key_value_heads,
+                                   self.head_dim)
+                    query_states, key_states, value_states = \
+                        qkv.split([self.num_heads,
+                                   self.num_key_value_heads,
+                                   self.num_key_value_heads],
+                                  dim=2)
 
         query_states = query_states.view(bsz, q_len,
                                          self.num_heads, self.head_dim).transpose(1, 2)
@@ -1972,12 +2027,9 @@ def llama_attention_forward_4_38_original(
         kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
 
     if use_fuse_rope:
-        rope_theta = self.rotary_emb.base
-        query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
-                                                                     key_states,
-                                                                     position_ids,
-                                                                     "llama",
-                                                                     rope_theta=rope_theta)
+        import xe_addons
+        xe_addons.rotary_half_inplaced(self.rotary_emb.inv_freq, position_ids,
+                                       query_states, key_states)
     else:
         if cache_position is not None:
             # for transformers 4.38.0
@@ -2413,9 +2465,16 @@ def llama_attention_fast_forward(
         value_states = torch.cat(value_states, dim=-1)
 
     else:
-        query_states = self.q_proj(hidden_states)
-        key_states = self.k_proj(hidden_states)
-        value_states = self.v_proj(hidden_states)
+        if hasattr(self, "q_proj"):
+            query_states = self.q_proj(hidden_states)
+            key_states = self.k_proj(hidden_states)
+            value_states = self.v_proj(hidden_states)
+        else:
+            qkv = self.qkv_proj(hidden_states)
+            qkv = qkv.view(bsz, q_len, self.num_heads + 2 * self.num_key_value_heads, self.head_dim)
+            query_states, key_states, value_states = qkv.split([self.num_heads,
+                                                                self.num_key_value_heads,
+                                                                self.num_key_value_heads], dim=2)
 
     query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
     key_states = key_states.view(bsz, q_len, self.num_key_value_heads,
@@ -377,6 +377,8 @@ def use_decoding_fast_path(proj,
                            enough_kv_room,
                            bs,
                            qtype_check=decoding_fast_path_qtype_check):
+    if proj is None:
+        return False
     device = get_xpu_device_type(proj.weight)
     if not qtype_check(proj):
         return False
@@ -419,6 +421,8 @@ def use_fused_layer_norm(x: torch.Tensor, training: bool):
 
 def fp16_fusion_check(proj, x, training):
     # only use fp16 fusion on PVC inference
+    if proj is None:
+        return False
     if not hasattr(proj, "qtype"):
         return False
     if proj.qtype != ggml_tensor_qtype["fp16"]:
@@ -491,3 +495,11 @@ def should_use_compresskv(x: torch.Tensor, prompt_len: int):
         )
     else:
         return x.device.type == 'xpu' and use_compress_kv == "1"
+
+
+def get_q_proj_or_qkv_proj(self):
+    if hasattr(self, "q_proj"):
+        proj = self.q_proj
+    elif hasattr(self, "qkv_proj"):
+        proj = self.qkv_proj
+    return proj
@@ -152,20 +152,21 @@ def test_optimize_model(Model, Tokenizer, model_path, prompt):
     tokenizer = Tokenizer.from_pretrained(model_path, trust_remote_code=True)
     input_ids = tokenizer.encode(prompt, return_tensors="pt")
 
-    model = Model.from_pretrained(model_path,
-                                  load_in_4bit=True,
-                                  optimize_model=False,
-                                  trust_remote_code=True)
-    logits_base_model = (model(input_ids)).logits
+    with torch.inference_mode():
+        model = Model.from_pretrained(model_path,
+                                      load_in_4bit=True,
+                                      optimize_model=False,
+                                      trust_remote_code=True)
+        logits_base_model = (model(input_ids)).logits
 
-    model = Model.from_pretrained(model_path,
-                                  load_in_4bit=True,
-                                  optimize_model=True,
-                                  trust_remote_code=True)
-    logits_optimized_model = (model(input_ids)).logits
-    diff = abs(logits_base_model - logits_optimized_model).flatten()
+        model = Model.from_pretrained(model_path,
+                                      load_in_4bit=True,
+                                      optimize_model=True,
+                                      trust_remote_code=True)
+        logits_optimized_model = (model(input_ids)).logits
+        diff = abs(logits_base_model - logits_optimized_model).flatten()
 
-    assert any(diff) is False
+        assert any(diff) is False
 
 
 if __name__ == '__main__':
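For reference, the split used throughout the new qkv_proj code paths above can also be checked in isolation. The standalone snippet below is illustrative only: TinyAttn and the dimensions are made up for the example and it is not part of this commit. It verifies that concatenating the three projection weights and splitting the fused output reproduces the separate q/k/v projections.

import torch
import torch.nn as nn

torch.manual_seed(0)
hidden, n_heads, n_kv_heads, head_dim = 64, 4, 2, 16


class TinyAttn(nn.Module):
    def __init__(self):
        super().__init__()
        self.q_proj = nn.Linear(hidden, n_heads * head_dim, bias=False)
        self.k_proj = nn.Linear(hidden, n_kv_heads * head_dim, bias=False)
        self.v_proj = nn.Linear(hidden, n_kv_heads * head_dim, bias=False)


attn = TinyAttn()
x = torch.randn(1, 3, hidden)
q_ref, k_ref, v_ref = attn.q_proj(x), attn.k_proj(x), attn.v_proj(x)

# fuse the three weights row-wise, as a merged qkv_proj would hold them
qkv_proj = nn.Linear(hidden, (n_heads + 2 * n_kv_heads) * head_dim, bias=False)
with torch.no_grad():
    qkv_proj.weight.copy_(torch.cat([attn.q_proj.weight,
                                     attn.k_proj.weight,
                                     attn.v_proj.weight], dim=0))

# same view + split pattern as the attention forwards in this commit
qkv = qkv_proj(x).view(1, 3, n_heads + 2 * n_kv_heads, head_dim)
q, k, v = qkv.split([n_heads, n_kv_heads, n_kv_heads], dim=2)

assert torch.allclose(q.reshape(1, 3, -1), q_ref, atol=1e-5)
assert torch.allclose(k.reshape(1, 3, -1), k_ref, atol=1e-5)
assert torch.allclose(v.reshape(1, 3, -1), v_ref, atol=1e-5)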