Use merge_qkv to replace fused_qkv for llama2 (#11727)

* update 4.38

* support new versions

* update

* fix style

* fix style

* update rope

* temp test sdpa

* fix style

* fix cpu ut
Author: Ruonan Wang, 2024-08-07 13:04:01 +03:00 (committed by GitHub)
parent d2abc9711b
commit 00a5574c8a
4 changed files with 180 additions and 93 deletions


@@ -744,6 +744,9 @@ def _optimize_pre(model, qtype=None):
     if model.config.model_type == "gemma2":
         from ipex_llm.transformers.models.gemma2 import merge_qkv
         model.apply(merge_qkv)
+    if model.config.model_type == "llama":
+        from ipex_llm.transformers.models.llama import merge_qkv
+        model.apply(merge_qkv)
     return model
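
For context on what the new llama branch does: merge_qkv rewrites each attention block so that the three separate q/k/v projections become a single fused qkv_proj before the weights are quantized. The snippet below is a minimal sketch of that idea, assuming plain nn.Linear layers; it is illustrative only and is not the actual merge_qkv_base implementation from ipex_llm.transformers.models.common.

```python
import torch
import torch.nn as nn

def merge_qkv_sketch(module: nn.Module, attention_cls) -> None:
    # model.apply() calls this on every submodule; only rewrite attention blocks.
    if isinstance(module, attention_cls):
        q, k, v = module.q_proj, module.k_proj, module.v_proj
        qkv = nn.Linear(q.in_features,
                        q.out_features + k.out_features + v.out_features,
                        bias=q.bias is not None)
        # Stack the three weight matrices row-wise so a single matmul
        # yields [Q | K | V] in one output tensor.
        qkv.weight.data = torch.cat([q.weight.data, k.weight.data, v.weight.data], dim=0)
        if q.bias is not None:
            qkv.bias.data = torch.cat([q.bias.data, k.bias.data, v.bias.data], dim=0)
        module.qkv_proj = qkv
        # Remove the originals so downstream code can detect the merged
        # layout via hasattr(self, "q_proj"), as the patched forwards below do.
        del module.q_proj, module.k_proj, module.v_proj
```

Because model.apply() visits every submodule, the transform only has to check isinstance against the attention class it is given.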
@@ -989,6 +992,10 @@ def _optimize_post(model, lightweight_bmm=False):
                 model,
                 transformers.models.llama.modeling_llama.LlamaAttention,
                 llama_attention_forward_4_41)
+            convert_forward(
+                model,
+                transformers.models.llama.modeling_llama.LlamaSdpaAttention,
+                llama_attention_forward_4_41)
         else:
             from ipex_llm.transformers.models.llama import llama_model_forward_4_38
             convert_forward(
@@ -999,6 +1006,10 @@ def _optimize_post(model, lightweight_bmm=False):
                 model,
                 transformers.models.llama.modeling_llama.LlamaAttention,
                 llama_attention_forward_4_38)
+            convert_forward(
+                model,
+                transformers.models.llama.modeling_llama.LlamaSdpaAttention,
+                llama_attention_forward_4_38)
         else:
             from ipex_llm.transformers.models.llama import llama_model_forward_4_36
             convert_forward(
@@ -1009,6 +1020,10 @@ def _optimize_post(model, lightweight_bmm=False):
                 model,
                 transformers.models.llama.modeling_llama.LlamaAttention,
                 llama_attention_forward_4_38)
+            convert_forward(
+                model,
+                transformers.models.llama.modeling_llama.LlamaSdpaAttention,
+                llama_attention_forward_4_38)
         else:
             # transformers version between 4.31.0 - 4.35.2
             convert_forward(
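
The new convert_forward calls cover LlamaSdpaAttention in addition to LlamaAttention, so models loaded with the default attn_implementation="sdpa" in newer transformers releases are also routed to the optimized forward. Based only on the call sites above, a convert_forward-style helper can be pictured roughly as follows; the real ipex_llm implementation may differ.

```python
import types
import torch.nn as nn

def convert_forward_sketch(model: nn.Module, target_cls, new_forward) -> None:
    # Bind new_forward as the forward method of every module of the target
    # class, so calling it once per class patches both LlamaAttention and
    # LlamaSdpaAttention instances.
    for module in model.modules():
        if module.__class__ == target_cls:
            module.forward = types.MethodType(new_forward, module)
```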


@@ -49,12 +49,13 @@ from ipex_llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
 from ipex_llm.transformers.models.utils import use_flash_attention, use_sdp, use_sdp_fp8, \
     use_sdp_causal
 from ipex_llm.transformers.models.utils import mlp_fusion_check, fp16_fusion_check
-from ipex_llm.transformers.models.utils import use_decoding_fast_path
+from ipex_llm.transformers.models.utils import use_decoding_fast_path, get_q_proj_or_qkv_proj
 from transformers.modeling_outputs import BaseModelOutputWithPast
-from transformers.models.llama.modeling_llama import LlamaModel
+from transformers.models.llama.modeling_llama import LlamaModel, LlamaAttention
 from ipex_llm.transformers.low_bit_linear import SYM_INT4, FP8E5, IQ2_XXS, FP4
 from ipex_llm.ggml.quantize import ggml_tensor_qtype
 from ipex_llm.utils.common import invalidInputError
+from ipex_llm.transformers.models.common import merge_qkv_base, fuse_mlp_base

 try:
     from transformers.cache_utils import Cache, DynamicCache
@@ -66,6 +67,10 @@ from transformers import logging
 logger = logging.get_logger(__name__)

+def merge_qkv(module: torch.nn.Module):
+    return merge_qkv_base(module, LlamaAttention)
+
 def llama_decoding_fast_path_qtype_check(proj):
     # IQ2_XXS only can be used in Llama-like model
     qtype = getattr(proj, "qtype", None)
@@ -406,6 +411,9 @@ def fuse_qkv_weight_xetla(q_proj, k_proj, v_proj, qtype):
 def should_use_xetla_mm_qkv(self, device):
+    if not hasattr(self, "q_proj"):
+        # TODO: how to support xetla_mm_qkv for merged_qkv
+        return False
     full_attn = self.q_proj.out_len == self.k_proj.out_len == self.v_proj.out_len
     supported_qtype = self.q_proj.qtype == SYM_INT4 and full_attn
     supported_qtype = supported_qtype or self.q_proj.qtype == FP8E5
@@ -428,7 +436,8 @@ def llama_attention_forward_4_31(
     cache_position: Optional[torch.LongTensor] = None,
     **kwargs,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-    if use_quantize_kv_cache(self.q_proj, hidden_states, self.num_key_value_groups):
+    if use_quantize_kv_cache(get_q_proj_or_qkv_proj(self), hidden_states,
+                             self.num_key_value_groups):
         forward_function = llama_attention_forward_4_31_quantized
     else:
         forward_function = llama_attention_forward_4_31_original
@@ -466,7 +475,7 @@ def llama_attention_forward_4_31_quantized(
     use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
     enough_kv_room = is_enough_kv_cache_room_4_31(past_key_value, seq_len=q_len)
     no_tp = not self.config.pretraining_tp > 1
-    decoding_fast_path = use_decoding_fast_path(self.q_proj,
+    decoding_fast_path = use_decoding_fast_path(getattr(self, "q_proj", None),
                                                 use_fuse_rope,
                                                 enough_kv_room,
                                                 bsz * q_len,
@@ -500,9 +509,16 @@ def llama_attention_forward_4_31_quantized(
                                                 self.head_dim,
                                                 self.rotary_emb.base,)
     else:
-        query_states = self.q_proj(hidden_states)
-        key_states = self.k_proj(hidden_states)
-        value_states = self.v_proj(hidden_states)
+        if hasattr(self, "q_proj"):
+            query_states = self.q_proj(hidden_states)
+            key_states = self.k_proj(hidden_states)
+            value_states = self.v_proj(hidden_states)
+        else:
+            qkv = self.qkv_proj(hidden_states)
+            qkv = qkv.view(bsz, q_len, self.num_heads + 2 * self.num_key_value_heads, self.head_dim)
+            query_states, key_states, value_states = qkv.split([self.num_heads,
+                                                                self.num_key_value_heads,
+                                                                self.num_key_value_heads], dim=2)
         query_states = query_states.view(bsz, q_len,
                                          self.num_heads, self.head_dim).transpose(1, 2)
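
As a sanity check on the merged-QKV path above: a single matmul now produces num_heads + 2 * num_key_value_heads head-sized chunks, which view and split carve back into Q, K and V. Below is a small self-contained illustration with made-up GQA sizes and a plain nn.Linear standing in for ipex_llm's low-bit linear.

```python
import torch
import torch.nn as nn

# Hypothetical GQA sizes chosen only for this example.
bsz, q_len, head_dim = 2, 5, 64
num_heads, num_key_value_heads = 8, 2
hidden_size = num_heads * head_dim

qkv_proj = nn.Linear(hidden_size, (num_heads + 2 * num_key_value_heads) * head_dim, bias=False)
hidden_states = torch.randn(bsz, q_len, hidden_size)

qkv = qkv_proj(hidden_states)
qkv = qkv.view(bsz, q_len, num_heads + 2 * num_key_value_heads, head_dim)
query_states, key_states, value_states = qkv.split(
    [num_heads, num_key_value_heads, num_key_value_heads], dim=2)

print(query_states.shape)  # torch.Size([2, 5, 8, 64])
print(key_states.shape)    # torch.Size([2, 5, 2, 64])
print(value_states.shape)  # torch.Size([2, 5, 2, 64])
```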
@@ -516,12 +532,9 @@ def llama_attention_forward_4_31_quantized(
         kv_seq_len += past_key_value[0].shape[-2]
     if use_fuse_rope:
-        rope_theta = self.rotary_emb.base
-        query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
-                                                                     key_states,
-                                                                     position_ids,
-                                                                     "llama",
-                                                                     rope_theta=rope_theta)
+        import xe_addons
+        xe_addons.rotary_half_inplaced(self.rotary_emb.inv_freq, position_ids,
+                                       query_states, key_states)
     else:
         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
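
The fused-rope path now passes the rotary inverse frequencies straight to xe_addons.rotary_half_inplaced instead of going through apply_rotary_pos_emb_no_cache_xpu with an explicit rope_theta. Both paths compute the same "rotate-half" rotary embedding; the reference below restates that math in plain PyTorch, mirroring the standard Hugging Face formulation rather than the in-place XPU kernel.

```python
import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Swap and negate the two halves of the head dimension.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rope_reference(q, k, inv_freq, position_ids):
    # q, k: [bsz, num_heads, seq_len, head_dim]; position_ids: [bsz, seq_len]
    freqs = position_ids[:, :, None].float() * inv_freq[None, None, :]  # [bsz, seq, head_dim/2]
    emb = torch.cat((freqs, freqs), dim=-1)                             # [bsz, seq, head_dim]
    cos = emb.cos()[:, None, :, :]
    sin = emb.sin()[:, None, :, :]
    q_rot = q * cos + rotate_half(q) * sin
    k_rot = k * cos + rotate_half(k) * sin
    return q_rot, k_rot
```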
@@ -604,7 +617,7 @@ def llama_attention_forward_4_31_original(
     use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
     enough_kv_room = is_enough_kv_cache_room_4_31(past_key_value, seq_len=q_len)
     no_tp = not self.config.pretraining_tp > 1
-    decoding_fast_path = use_decoding_fast_path(self.q_proj,
+    decoding_fast_path = use_decoding_fast_path(getattr(self, "q_proj", None),
                                                 use_fuse_rope,
                                                 enough_kv_room,
                                                 bsz * q_len,
@@ -654,7 +667,7 @@ def llama_attention_forward_4_31_original(
                             for i in range(self.config.pretraining_tp)]
             value_states = torch.cat(value_states, dim=-1)
         else:
-            if fp16_fusion_check(self.q_proj, hidden_states, self.training) and \
+            if fp16_fusion_check(getattr(self, "q_proj", None), hidden_states, self.training) and \
                     hidden_size == 4096 and self.q_proj.out_features == self.k_proj.out_features:
                 # only use mm_qkv_out on pvc for llama-7b
                 if not hasattr(self, "qkv_proj_weight"):
@@ -692,9 +705,19 @@ def llama_attention_forward_4_31_original(
                 key_states = qkv_states[:, :, q_out_len:q_out_len + k_out_len]
                 value_states = qkv_states[:, :, q_out_len + k_out_len:]
             else:
-                query_states = self.q_proj(hidden_states)
-                key_states = self.k_proj(hidden_states)
-                value_states = self.v_proj(hidden_states)
+                if hasattr(self, "q_proj"):
+                    query_states = self.q_proj(hidden_states)
+                    key_states = self.k_proj(hidden_states)
+                    value_states = self.v_proj(hidden_states)
+                else:
+                    qkv = self.qkv_proj(hidden_states)
+                    qkv = qkv.view(bsz, q_len, self.num_heads + 2 * self.num_key_value_heads,
+                                   self.head_dim)
+                    query_states, key_states, value_states = \
+                        qkv.split([self.num_heads,
+                                   self.num_key_value_heads,
+                                   self.num_key_value_heads],
+                                  dim=2)
         query_states = query_states.view(bsz, q_len,
                                          self.num_heads, self.head_dim).transpose(1, 2)
@@ -708,12 +731,9 @@ def llama_attention_forward_4_31_original(
         kv_seq_len += past_key_value[0].shape[-2]
     if use_fuse_rope:
-        rope_theta = self.rotary_emb.base
-        query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
-                                                                     key_states,
-                                                                     position_ids,
-                                                                     "llama",
-                                                                     rope_theta=rope_theta)
+        import xe_addons
+        xe_addons.rotary_half_inplaced(self.rotary_emb.inv_freq, position_ids,
+                                       query_states, key_states)
     else:
         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
@@ -839,7 +859,7 @@ def llama_attention_selective_batching_forward_4_31(
     use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
     enough_kv_room = past_key_value is not None and is_enough_kv_cache_room_4_31(past_key_value[0])
     no_tp = not self.config.pretraining_tp > 1
-    decoding_fast_path = use_decoding_fast_path(self.q_proj,
+    decoding_fast_path = use_decoding_fast_path(getattr(self, "q_proj", None),
                                                 use_fuse_rope,
                                                 enough_kv_room,
                                                 bsz * q_len,
@@ -886,9 +906,18 @@ def llama_attention_selective_batching_forward_4_31(
         if self.config.pretraining_tp > 1:
             invalidInputError(False, f"vLLM: config.pretraining_tp > 1 not supported yet")
         else:
-            query_states = self.q_proj(hidden_states)
-            key_states = self.k_proj(hidden_states)
-            value_states = self.v_proj(hidden_states)
+            if hasattr(self, "q_proj"):
+                query_states = self.q_proj(hidden_states)
+                key_states = self.k_proj(hidden_states)
+                value_states = self.v_proj(hidden_states)
+            else:
+                qkv = self.qkv_proj(hidden_states)
+                qkv = qkv.view(bsz, q_len, self.num_heads + 2 * self.num_key_value_heads,
+                               self.head_dim)
+                query_states, key_states, value_states = qkv.split([self.num_heads,
+                                                                    self.num_key_value_heads,
+                                                                    self.num_key_value_heads],
+                                                                   dim=2)
         query_states = query_states.view(bsz, q_len,
                                          self.num_heads, self.head_dim).transpose(1, 2)
@@ -902,12 +931,9 @@ def llama_attention_selective_batching_forward_4_31(
         kv_seq_len += max(kv_pair[0].shape[-2] for kv_pair in past_key_value)
     if use_fuse_rope:
-        rope_theta = self.rotary_emb.base
-        query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
-                                                                     key_states,
-                                                                     position_ids,
-                                                                     "llama",
-                                                                     rope_theta=rope_theta)
+        import xe_addons
+        xe_addons.rotary_half_inplaced(self.rotary_emb.inv_freq, position_ids,
+                                       query_states, key_states)
     else:
         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
@@ -1030,7 +1056,8 @@ def llama_attention_forward_4_41(
     cache_position: Optional[torch.LongTensor] = None,
     **kwargs
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[List[torch.FloatTensor]]]:
-    if use_quantize_kv_cache(self.q_proj, hidden_states, self.num_key_value_groups):
+    if use_quantize_kv_cache(get_q_proj_or_qkv_proj(self), hidden_states,
+                             self.num_key_value_groups):
         forward_function = llama_attention_forward_4_41_quantized
     else:
         forward_function = llama_attention_forward_4_41_original
@@ -1069,7 +1096,7 @@ def llama_attention_forward_4_41_quantized(
     use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
     enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx, seq_len=q_len)
     no_tp = not self.config.pretraining_tp > 1
-    decoding_fast_path = use_decoding_fast_path(self.q_proj,
+    decoding_fast_path = use_decoding_fast_path(getattr(self, "q_proj", None),
                                                 use_fuse_rope,
                                                 enough_kv_room,
                                                 bsz * q_len,
@@ -1098,9 +1125,16 @@ def llama_attention_forward_4_41_quantized(
                                                 self.head_dim,
                                                 self.rotary_emb.base,)
     else:
-        query_states = self.q_proj(hidden_states)
-        key_states = self.k_proj(hidden_states)
-        value_states = self.v_proj(hidden_states)
+        if hasattr(self, "q_proj"):
+            query_states = self.q_proj(hidden_states)
+            key_states = self.k_proj(hidden_states)
+            value_states = self.v_proj(hidden_states)
+        else:
+            qkv = self.qkv_proj(hidden_states)
+            qkv = qkv.view(bsz, q_len, self.num_heads + 2 * self.num_key_value_heads, self.head_dim)
+            query_states, key_states, value_states = qkv.split([self.num_heads,
+                                                                self.num_key_value_heads,
+                                                                self.num_key_value_heads], dim=2)
         query_states = query_states.view(bsz, q_len,
                                          self.num_heads, self.head_dim).transpose(1, 2)
@@ -1122,12 +1156,9 @@ def llama_attention_forward_4_41_quantized(
             )
         kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
     if use_fuse_rope:
-        rope_theta = self.rotary_emb.base
-        query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
-                                                                     key_states,
-                                                                     position_ids,
-                                                                     "llama",
-                                                                     rope_theta=rope_theta)
+        import xe_addons
+        xe_addons.rotary_half_inplaced(self.rotary_emb.inv_freq, position_ids,
+                                       query_states, key_states)
     else:
         if cache_position is not None:
             # for transformers 4.38.0
@@ -1301,7 +1332,7 @@ def llama_attention_forward_4_41_original(
     use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
     enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx, seq_len=q_len)
     no_tp = not self.config.pretraining_tp > 1
-    decoding_fast_path = use_decoding_fast_path(self.q_proj,
+    decoding_fast_path = use_decoding_fast_path(getattr(self, "q_proj", None),
                                                 use_fuse_rope,
                                                 enough_kv_room,
                                                 bsz * q_len,
@@ -1360,7 +1391,7 @@ def llama_attention_forward_4_41_original(
                             for i in range(self.config.pretraining_tp)]
             value_states = torch.cat(value_states, dim=-1)
         else:
-            if fp16_fusion_check(self.q_proj, hidden_states, self.training) and \
+            if fp16_fusion_check(getattr(self, "q_proj", None), hidden_states, self.training) and \
                     hidden_size == 4096 and self.q_proj.out_features == self.k_proj.out_features:
                 # only use mm_qkv_out on pvc for llama-7b
                 if not hasattr(self, "qkv_proj_weight"):
@@ -1399,9 +1430,20 @@ def llama_attention_forward_4_41_original(
                 key_states = qkv_states[:, :, q_out_len:q_out_len + k_out_len]
                 value_states = qkv_states[:, :, q_out_len + k_out_len:]
             else:
-                query_states = self.q_proj(hidden_states)
-                key_states = self.k_proj(hidden_states)
-                value_states = self.v_proj(hidden_states)
+                if hasattr(self, "q_proj"):
+                    query_states = self.q_proj(hidden_states)
+                    key_states = self.k_proj(hidden_states)
+                    value_states = self.v_proj(hidden_states)
+                else:
+                    qkv = self.qkv_proj(hidden_states)
+                    qkv = qkv.view(bsz, q_len,
+                                   self.num_heads + 2 * self.num_key_value_heads,
+                                   self.head_dim)
+                    query_states, key_states, value_states = \
+                        qkv.split([self.num_heads,
+                                   self.num_key_value_heads,
+                                   self.num_key_value_heads],
+                                  dim=2)
         query_states = query_states.view(bsz, q_len,
                                          self.num_heads, self.head_dim).transpose(1, 2)
@@ -1421,12 +1463,9 @@ def llama_attention_forward_4_41_original(
         kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
     if use_fuse_rope:
-        rope_theta = self.rotary_emb.base
-        query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
-                                                                     key_states,
-                                                                     position_ids,
-                                                                     "llama",
-                                                                     rope_theta=rope_theta)
+        import xe_addons
+        xe_addons.rotary_half_inplaced(self.rotary_emb.inv_freq, position_ids,
+                                       query_states, key_states)
     else:
         if cache_position is not None:
             # for transformers 4.38.0
@@ -1582,7 +1621,8 @@ def llama_attention_forward_4_38(
     cache_position: Optional[torch.LongTensor] = None,
     **kwargs
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[List[torch.FloatTensor]]]:
-    if use_quantize_kv_cache(self.q_proj, hidden_states, self.num_key_value_groups):
+    if use_quantize_kv_cache(get_q_proj_or_qkv_proj(self), hidden_states,
+                             self.num_key_value_groups):
         forward_function = llama_attention_forward_4_38_quantized
     else:
         forward_function = llama_attention_forward_4_38_original
@@ -1621,7 +1661,7 @@ def llama_attention_forward_4_38_quantized(
     use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
     enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx, seq_len=q_len)
     no_tp = not self.config.pretraining_tp > 1
-    decoding_fast_path = use_decoding_fast_path(self.q_proj,
+    decoding_fast_path = use_decoding_fast_path(getattr(self, "q_proj", None),
                                                 use_fuse_rope,
                                                 enough_kv_room,
                                                 bsz * q_len,
@@ -1650,9 +1690,16 @@ def llama_attention_forward_4_38_quantized(
                                                 self.head_dim,
                                                 self.rotary_emb.base,)
     else:
-        query_states = self.q_proj(hidden_states)
-        key_states = self.k_proj(hidden_states)
-        value_states = self.v_proj(hidden_states)
+        if hasattr(self, "q_proj"):
+            query_states = self.q_proj(hidden_states)
+            key_states = self.k_proj(hidden_states)
+            value_states = self.v_proj(hidden_states)
+        else:
+            qkv = self.qkv_proj(hidden_states)
+            qkv = qkv.view(bsz, q_len, self.num_heads + 2 * self.num_key_value_heads, self.head_dim)
+            query_states, key_states, value_states = qkv.split([self.num_heads,
+                                                                self.num_key_value_heads,
+                                                                self.num_key_value_heads], dim=2)
         query_states = query_states.view(bsz, q_len,
                                          self.num_heads, self.head_dim).transpose(1, 2)
@@ -1674,12 +1721,9 @@ def llama_attention_forward_4_38_quantized(
             )
         kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
     if use_fuse_rope:
-        rope_theta = self.rotary_emb.base
-        query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
-                                                                     key_states,
-                                                                     position_ids,
-                                                                     "llama",
-                                                                     rope_theta=rope_theta)
+        import xe_addons
+        xe_addons.rotary_half_inplaced(self.rotary_emb.inv_freq, position_ids,
+                                       query_states, key_states)
     else:
         if cache_position is not None:
             # for transformers 4.38.0
@@ -1853,7 +1897,7 @@ def llama_attention_forward_4_38_original(
     use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
     enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx, seq_len=q_len)
     no_tp = not self.config.pretraining_tp > 1
-    decoding_fast_path = use_decoding_fast_path(self.q_proj,
+    decoding_fast_path = use_decoding_fast_path(getattr(self, "q_proj", None),
                                                 use_fuse_rope,
                                                 enough_kv_room,
                                                 bsz * q_len,
@@ -1911,7 +1955,7 @@ def llama_attention_forward_4_38_original(
                             for i in range(self.config.pretraining_tp)]
             value_states = torch.cat(value_states, dim=-1)
         else:
-            if fp16_fusion_check(self.q_proj, hidden_states, self.training) and \
+            if fp16_fusion_check(getattr(self, "q_proj", None), hidden_states, self.training) and \
                     hidden_size == 4096 and self.q_proj.out_features == self.k_proj.out_features:
                 # only use mm_qkv_out on pvc for llama-7b
                 if not hasattr(self, "qkv_proj_weight"):
@@ -1950,9 +1994,20 @@ def llama_attention_forward_4_38_original(
                 key_states = qkv_states[:, :, q_out_len:q_out_len + k_out_len]
                 value_states = qkv_states[:, :, q_out_len + k_out_len:]
             else:
-                query_states = self.q_proj(hidden_states)
-                key_states = self.k_proj(hidden_states)
-                value_states = self.v_proj(hidden_states)
+                if hasattr(self, "q_proj"):
+                    query_states = self.q_proj(hidden_states)
+                    key_states = self.k_proj(hidden_states)
+                    value_states = self.v_proj(hidden_states)
+                else:
+                    qkv = self.qkv_proj(hidden_states)
+                    qkv = qkv.view(bsz, q_len,
+                                   self.num_heads + 2 * self.num_key_value_heads,
+                                   self.head_dim)
+                    query_states, key_states, value_states = \
+                        qkv.split([self.num_heads,
+                                   self.num_key_value_heads,
+                                   self.num_key_value_heads],
+                                  dim=2)
         query_states = query_states.view(bsz, q_len,
                                          self.num_heads, self.head_dim).transpose(1, 2)
@@ -1972,12 +2027,9 @@ def llama_attention_forward_4_38_original(
         kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
     if use_fuse_rope:
-        rope_theta = self.rotary_emb.base
-        query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
-                                                                     key_states,
-                                                                     position_ids,
-                                                                     "llama",
-                                                                     rope_theta=rope_theta)
+        import xe_addons
+        xe_addons.rotary_half_inplaced(self.rotary_emb.inv_freq, position_ids,
+                                       query_states, key_states)
    else:
        if cache_position is not None:
            # for transformers 4.38.0
@@ -2413,9 +2465,16 @@ def llama_attention_fast_forward(
            value_states = torch.cat(value_states, dim=-1)
     else:
-        query_states = self.q_proj(hidden_states)
-        key_states = self.k_proj(hidden_states)
-        value_states = self.v_proj(hidden_states)
+        if hasattr(self, "q_proj"):
+            query_states = self.q_proj(hidden_states)
+            key_states = self.k_proj(hidden_states)
+            value_states = self.v_proj(hidden_states)
+        else:
+            qkv = self.qkv_proj(hidden_states)
+            qkv = qkv.view(bsz, q_len, self.num_heads + 2 * self.num_key_value_heads, self.head_dim)
+            query_states, key_states, value_states = qkv.split([self.num_heads,
+                                                                self.num_key_value_heads,
+                                                                self.num_key_value_heads], dim=2)

     query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
     key_states = key_states.view(bsz, q_len, self.num_key_value_heads,


@@ -377,6 +377,8 @@ def use_decoding_fast_path(proj,
                            enough_kv_room,
                            bs,
                            qtype_check=decoding_fast_path_qtype_check):
+    if proj is None:
+        return False
     device = get_xpu_device_type(proj.weight)
     if not qtype_check(proj):
         return False
@@ -419,6 +421,8 @@ def use_fused_layer_norm(x: torch.Tensor, training: bool):
 def fp16_fusion_check(proj, x, training):
     # only use fp16 fusion on PVC inference
+    if proj is None:
+        return False
     if not hasattr(proj, "qtype"):
         return False
     if proj.qtype != ggml_tensor_qtype["fp16"]:
@@ -491,3 +495,11 @@ def should_use_compresskv(x: torch.Tensor, prompt_len: int):
         )
     else:
         return x.device.type == 'xpu' and use_compress_kv == "1"
+
+
+def get_q_proj_or_qkv_proj(self):
+    if hasattr(self, "q_proj"):
+        proj = self.q_proj
+    elif hasattr(self, "qkv_proj"):
+        proj = self.qkv_proj
+    return proj
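
Combined with the proj is None guards added to use_decoding_fast_path and fp16_fusion_check, this helper lets callers such as use_quantize_kv_cache inspect whichever projection a layer actually carries, before or after merge_qkv has run. A tiny illustration with hypothetical stand-in modules (DummySeparate and DummyFused are invented for this example); note that the helper assumes at least one of the two attributes exists:

```python
import torch.nn as nn

class DummySeparate(nn.Module):
    def __init__(self):
        super().__init__()
        self.q_proj = nn.Linear(16, 16)
        self.k_proj = nn.Linear(16, 16)
        self.v_proj = nn.Linear(16, 16)

class DummyFused(nn.Module):
    def __init__(self):
        super().__init__()
        self.qkv_proj = nn.Linear(16, 48)

def get_q_proj_or_qkv_proj(self):
    if hasattr(self, "q_proj"):
        proj = self.q_proj
    elif hasattr(self, "qkv_proj"):
        proj = self.qkv_proj
    return proj

print(get_q_proj_or_qkv_proj(DummySeparate()))  # Linear(in_features=16, out_features=16, ...)
print(get_q_proj_or_qkv_proj(DummyFused()))     # Linear(in_features=16, out_features=48, ...)
```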


@@ -152,6 +152,7 @@ def test_optimize_model(Model, Tokenizer, model_path, prompt):
     tokenizer = Tokenizer.from_pretrained(model_path, trust_remote_code=True)
     input_ids = tokenizer.encode(prompt, return_tensors="pt")
+    with torch.inference_mode():
         model = Model.from_pretrained(model_path,
                                       load_in_4bit=True,
                                       optimize_model=False,
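
The CPU unit-test fix wraps model loading and generation in torch.inference_mode(), which turns off autograd tracking so the test uses less memory and runs faster. A minimal sketch of the pattern, with placeholder tensors rather than the real test fixtures:

```python
import torch

with torch.inference_mode():
    # Anything created here skips gradient bookkeeping entirely,
    # which is all a generate-and-compare unit test needs.
    dummy_input = torch.randn(1, 8)
    dummy_output = torch.nn.Linear(8, 4)(dummy_input)
    assert not dummy_output.requires_grad
```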