parent
0e53f20edb
commit
c5e8b90c8d
2 changed files with 139 additions and 6 deletions
@@ -1243,13 +1243,20 @@ def _optimize_post(model, lightweight_bmm=False):
             modeling_module_name = model.__class__.__module__
             module = importlib.import_module(modeling_module_name)
             from ipex_llm.transformers.models.qwen import qwen_attention_forward
+            from ipex_llm.transformers.models.qwen import qwen_attention_forward_registered
             from ipex_llm.transformers.models.qwen import qwen_mlp_forward
             from ipex_llm.transformers.models.chatglm2 import chatglm_rms_norm_forward
             from ipex_llm.transformers.models.qwen import qwen_model_forward
-            convert_forward(model,
-                            module.QWenAttention,
-                            qwen_attention_forward
-                            )
+            if model.config.max_position_embeddings == 8192:
+                convert_forward(model,
+                                module.QWenAttention,
+                                qwen_attention_forward_registered
+                                )
+            else:
+                convert_forward(model,
+                                module.QWenAttention,
+                                qwen_attention_forward
+                                )
             convert_forward(model,
                             module.RMSNorm,
                             chatglm_rms_norm_forward)
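The hunk above dispatches on model.config.max_position_embeddings and relies on convert_forward to swap the forward of every QWenAttention instance for an optimized one. Below is a minimal, standalone sketch of that monkey-patching pattern; patch_forward, ToyAttention, and the two toy forwards are illustrative stand-ins rather than ipex-llm APIs, and rebinding via types.MethodType is an assumption about how such a helper is typically written.

# Illustrative sketch only -- not ipex-llm's convert_forward.
import types

import torch
import torch.nn as nn


class ToyAttention(nn.Module):                      # stand-in for module.QWenAttention
    def forward(self, x):
        return x


def toy_attention_forward_registered(self, x):      # stand-in for the 8192-context path
    return x + 1


def toy_attention_forward(self, x):                 # stand-in for the default path
    return x * 2


def patch_forward(model: nn.Module, target_cls: type, new_forward) -> None:
    # Rebind `forward` as a bound method on every instance of target_cls.
    for sub_module in model.modules():
        if isinstance(sub_module, target_cls):
            sub_module.forward = types.MethodType(new_forward, sub_module)


model = nn.Sequential(ToyAttention(), nn.Linear(4, 4))
max_position_embeddings = 8192                      # e.g. read from model.config

if max_position_embeddings == 8192:
    patch_forward(model, ToyAttention, toy_attention_forward_registered)
else:
    patch_forward(model, ToyAttention, toy_attention_forward)

print(model[0](torch.ones(2)))                      # tensor([2., 2.]) from the patched forward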
@@ -1513,7 +1520,7 @@ def _optimize_post(model, lightweight_bmm=False):
         from ipex_llm.transformers.models.starcoder2 import model_forward
         convert_forward(model, module.Starcoder2Attention, attention_forward)
         convert_forward(model, module.Starcoder2Model, model_forward)
-    elif model.config.model_type == 'phi':
+    elif model.config.model_type in ["phi3", "phi3_v"]:
         # for phi-2
         modeling_module_name = model.__class__.__module__
         module = importlib.import_module(modeling_module_name)
@@ -1521,7 +1528,7 @@ def _optimize_post(model, lightweight_bmm=False):
         from ipex_llm.transformers.models.phi import model_forward
         convert_forward(model, module.PhiAttention, attention_forward)
         convert_forward(model, module.PhiModel, model_forward)
-    elif model.config.model_type in ["phi3", "phi3_v"]:
+    elif model.config.model_type == "phi3":
         # for phi-3
         modeling_module_name = model.__class__.__module__
         module = importlib.import_module(modeling_module_name)
@@ -180,6 +180,132 @@ def qwen_attention_forward(
     return attn_output, past_key_value


+def qwen_attention_forward_registered(
+    self,
+    hidden_states: Optional[Tuple[torch.FloatTensor]],
+    rotary_pos_emb_list: Optional[List[torch.Tensor]] = None,
+    registered_causal_mask: Optional[torch.Tensor] = None,
+    layer_past: Optional[Tuple[torch.Tensor]] = None,
+    attention_mask: Optional[torch.FloatTensor] = None,
+    head_mask: Optional[torch.FloatTensor] = None,
+    encoder_hidden_states: Optional[torch.Tensor] = None,
+    encoder_attention_mask: Optional[torch.FloatTensor] = None,
+    output_attentions: Optional[bool] = False,
+    use_cache: Optional[bool] = False,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+    # invalidInputError(not self.use_flash_attn and not self.use_cache_quantization,
+    #                   "flash attn and kv_cache quantization are not supported")
+    bsz, q_len, _ = hidden_states.size()
+    device = hidden_states.device
+    past_key_value = (None if layer_past is None
+                      else (layer_past[0].transpose(1, 2), layer_past[1].transpose(1, 2)))
+
+    qkv = self.c_attn(hidden_states)
+    qkv = qkv.view(bsz, q_len, self.num_heads * 3, self.head_dim)
+    qkv = qkv.transpose(1, 2)
+    query_states, key_states, value_states = qkv.split([self.num_heads,
+                                                        self.num_heads,
+                                                        self.num_heads], dim=1)
+
+    kv_seq_len = key_states.shape[2]
+    if past_key_value is not None:
+        kv_seq_len += past_key_value[0].shape[2]
+
+    # IPEX-LLM OPT: fuse rope
+    position_ids = rotary_pos_emb_list[-1]  # the last one is position_ids
+    inv_freq = rotary_pos_emb_list[-2]
+    rotary_pos_emb_list = rotary_pos_emb_list[:-2]
+    invalidInputError(len(rotary_pos_emb_list) == 1,
+                      "rotary_pos_emb_list's length cannot be larger than 1")
+    use_fuse_rope = should_use_fuse_rope(hidden_states, position_ids, self.training)
+    rotary_pos_emb = rotary_pos_emb_list[0]
+    if use_fuse_rope:
+        rot_dim = rotary_pos_emb[0].size(-1)
+        import linear_q4_0
+        linear_q4_0.rotary_half_inplaced(inv_freq, position_ids,
+                                         query_states[..., :rot_dim], key_states[..., :rot_dim])
+    else:
+        rotary_pos_emb = [i[:, -q_len:, :, :].transpose(1, 2) for i in rotary_pos_emb]
+        query_states = apply_rotary_pos_emb(query_states, rotary_pos_emb)
+        key_states = apply_rotary_pos_emb(key_states, rotary_pos_emb)
+
+    if kv_seq_len > self.seq_length and self.use_logn_attn and not self.training:
+        seq_start = kv_seq_len - q_len
+        seq_end = kv_seq_len
+        logn_tensor = self.logn_tensor[:, seq_start:seq_end, :, :].transpose(1, 2)
+        query_states = query_states * logn_tensor.type_as(query_states).expand_as(query_states)
+
+    # IPEX-LLM OPT: kv cache and quantize kv cache
+    use_quantize_kv = use_quantize_kv_cache(self.c_attn, hidden_states)
+    key_states, value_states = update_past_key_value(
+        past_key_value, key_states, value_states,
+        kv_seq_len, use_quantize_kv, device
+    )
+    past_key_value = (key_states.transpose(1, 2),
+                      value_states.transpose(1, 2)) if use_cache else None
+
+    # IPEX-LLM OPT: sdp
+    attn_weights = None
+    if not self.training and not hidden_states.requires_grad and \
+            use_flash_attention(query_states, key_states, attention_mask):
+        attn_output = F.scaled_dot_product_attention(query_states.to(dtype=torch.float16),
+                                                     key_states.to(dtype=torch.float16),
+                                                     value_states.to(dtype=torch.float16),
+                                                     is_causal=True).to(hidden_states.dtype)
+    elif use_sdp_causal(q_len, kv_seq_len, self.head_dim, query_states, self.training):
+        import linear_q4_0
+        if use_quantize_kv:
+            attn_output = linear_q4_0.sdp_fp8_causal(query_states, key_states, value_states)
+        else:
+            attn_output = linear_q4_0.sdp_causal(query_states, key_states, value_states)
+    else:
+        if q_len > 1:
+            causal_mask = registered_causal_mask[
+                :, :, kv_seq_len - q_len:kv_seq_len, :kv_seq_len
+            ]
+            attention_mask = torch.zeros(causal_mask.shape, dtype=query_states.dtype,
+                                         device=query_states.device)
+            attention_mask.masked_fill_(causal_mask.logical_not(),
+                                        torch.finfo(attention_mask.dtype).min)
+            attention_mask = attention_mask.expand([bsz, -1, -1, -1])
+        else:
+            attention_mask = None
+
+        if use_sdp(q_len, kv_seq_len, self.head_dim, query_states):
+            import linear_q4_0
+            if use_quantize_kv:
+                attn_output = linear_q4_0.sdp_fp8(query_states, key_states, value_states,
+                                                  attention_mask)
+            else:
+                attn_output = linear_q4_0.sdp(query_states, key_states, value_states,
+                                              attention_mask)
+        else:
+            if use_quantize_kv:
+                key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
+                                                                query_states.dtype)
+            attn_weights = torch.matmul(query_states,
+                                        key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+            if attention_mask is not None:
+                attn_weights = attn_weights + attention_mask
+            if self.softmax_in_fp32:
+                attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1,
+                                                           dtype=torch.float32).to(
+                    value_states.dtype)
+            else:
+                attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
+            attn_output = torch.matmul(attn_weights, value_states)
+
+    attn_output = attn_output.transpose(1, 2).contiguous()
+    attn_output = attn_output.view(bsz, q_len, self.hidden_size)
+
+    attn_output = self.c_proj(attn_output)
+
+    if output_attentions:
+        return attn_output, past_key_value, attn_weights
+    else:
+        return attn_output, past_key_value
+
+
 def qwen_mlp_forward(self, x: torch.Tensor) -> torch.Tensor:
     x_2d = x.view(-1, x.shape[-1])
     qtype = getattr(self.w1, "qtype", None)
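The fallback branch in the hunk above slices registered_causal_mask to the current query/key window and turns it into an additive float mask before the matmul-softmax path. A standalone sketch of that mask handling follows; it assumes the registered mask is a lower-triangular boolean buffer created at model init, which this diff does not show and which is therefore an assumption.

# Standalone illustration (not ipex-llm code). Assumes registered_causal_mask is a
# lower-triangular boolean buffer of shape [1, 1, max_positions, max_positions].
import torch

max_positions = 8
bsz, q_len, kv_seq_len = 2, 3, 5          # 3 new query tokens over 5 visible key positions

registered_causal_mask = torch.tril(
    torch.ones(max_positions, max_positions, dtype=torch.bool)
).view(1, 1, max_positions, max_positions)

# Rows for the current query window, columns for every visible key (cached + new).
causal_mask = registered_causal_mask[:, :, kv_seq_len - q_len:kv_seq_len, :kv_seq_len]

# Additive mask: 0 where attention is allowed, dtype-min where it is not.
attention_mask = torch.zeros(causal_mask.shape, dtype=torch.float32)
attention_mask.masked_fill_(causal_mask.logical_not(), torch.finfo(attention_mask.dtype).min)
attention_mask = attention_mask.expand([bsz, -1, -1, -1])

print(attention_mask.shape)               # torch.Size([2, 1, 3, 5])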