refactor mistral and phi3 (#12605)

parent 45f8f72a28
commit 073f936c37

5 changed files with 99 additions and 1367 deletions

@@ -1031,6 +1031,9 @@ def _optimize_pre(model, qtype=None):
     elif model.config.model_type == "mllama":
         from ipex_llm.transformers.models.mllama import merge_qkv
         model.apply(merge_qkv)
+    elif model.config.model_type == "mistral":
+        from ipex_llm.transformers.models.mistral import merge_qkv
+        model.apply(merge_qkv)
     elif model.config.model_type == "minicpm":
         from ipex_llm.transformers.models.minicpm import merge_qkv, apply_residual_scale
         model.apply(merge_qkv)

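A note on the `_optimize_pre` change: `merge_qkv` is applied to every submodule through `model.apply`. The mistral implementation itself is not part of this diff, so the following is only a minimal sketch of the usual q/k/v fusion pattern it is assumed to follow; every name except `merge_qkv` is illustrative.

import torch
from torch import nn

def merge_qkv_sketch(module: nn.Module):
    # Illustrative only: fuse separate q/k/v projections into one linear layer
    # so a single GEMM produces all three; the attention forward then splits
    # the fused output back into query, key and value.
    if hasattr(module, "q_proj") and hasattr(module, "k_proj") and hasattr(module, "v_proj"):
        q, k, v = module.q_proj, module.k_proj, module.v_proj
        qkv = nn.Linear(q.in_features,
                        q.out_features + k.out_features + v.out_features,
                        bias=q.bias is not None)
        qkv.weight.data = torch.cat([q.weight.data, k.weight.data, v.weight.data], dim=0)
        if q.bias is not None:
            qkv.bias.data = torch.cat([q.bias.data, k.bias.data, v.bias.data], dim=0)
        module.qkv_proj = qkv
        del module.q_proj, module.k_proj, module.v_proj

# model.apply(merge_qkv_sketch) visits every submodule, mirroring model.apply(merge_qkv) above.

Fusing the three projections lets one GEMM produce Q, K and V together, which is the usual motivation for this kind of pre-optimization.
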
@@ -1901,43 +1904,17 @@ def _optimize_post(model, lightweight_bmm=False):
         else:
             modeling_module_name = model.__class__.__module__
             module = importlib.import_module(modeling_module_name)
-            if version.parse(trans_version) >= version.parse("4.36.0"):
-                from ipex_llm.transformers.models.mistral import mistral_model_forward_4_36
-                if version.parse(trans_version) >= version.parse("4.39.0"):
-                    from ipex_llm.transformers.models.mistral import \
-                        mistral_attention_forward_4_39
-                    convert_forward(model,
-                                    module.MistralAttention,
-                                    mistral_attention_forward_4_39
-                                    )
-                else:
-                    from ipex_llm.transformers.models.mistral import mistral_attention_forward_4_36
-                    convert_forward(model,
-                                    module.MistralAttention,
-                                    mistral_attention_forward_4_36
-                                    )
-                convert_forward(model,
-                                module.MistralModel,
-                                mistral_model_forward_4_36
-                                )
-                convert_forward(model,
-                                module.MistralRMSNorm,
-                                llama_rms_norm_forward)
-                convert_forward(model,
-                                module.MistralMLP,
-                                llama_mlp_forward)
-            else:
-                from ipex_llm.transformers.models.mistral import mistral_attention_forward
-                convert_forward(model,
-                                module.MistralAttention,
-                                mistral_attention_forward
-                                )
-                convert_forward(model,
-                                module.MistralRMSNorm,
-                                llama_rms_norm_forward)
-                convert_forward(model,
-                                module.MistralMLP,
-                                llama_mlp_forward)
+            from ipex_llm.transformers.models.mistral import mistral_model_forward
+            from ipex_llm.transformers.models.mistral import mistral_attention_forward
+            from ipex_llm.transformers.models.common import rms_norm_forward
+            from ipex_llm.transformers.models.common import mlp_silu_forward
+            convert_forward(model, module.MistralModel, mistral_model_forward)
+            convert_forward(model, module.MistralAttention, mistral_attention_forward)
+            convert_forward(model, module.MistralSdpaAttention, mistral_attention_forward)
+            convert_forward(model, module.MistralRMSNorm, rms_norm_forward)
+            convert_forward(model, module.MistralMLP, mlp_silu_forward)
     elif model.config.model_type == "gemma":
         modeling_module_name = model.__class__.__module__
         module = importlib.import_module(modeling_module_name)

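The refactor collapses the version-gated Mistral branches into a single set of `convert_forward` calls against the shared forwards in ipex_llm.transformers.models.common and .mistral. `convert_forward` itself is defined outside this diff; it is assumed to rebind `forward` on every instance of the given class, roughly like the sketch below (the helper name is hypothetical).

import types
from torch import nn

def convert_forward_sketch(model: nn.Module, target_class: type, new_forward) -> None:
    # Walk the module tree and rebind `forward` on every instance of
    # target_class, so later calls go through the optimized implementation.
    for sub_module in model.modules():
        if isinstance(sub_module, target_class):
            sub_module.forward = types.MethodType(new_forward, sub_module)
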
@@ -2078,8 +2055,8 @@ def _optimize_post(model, lightweight_bmm=False):
         convert_forward(model, module.Phi3Attention, attention_forward)
         from ipex_llm.transformers.models.phi3 import mlp_forward
         convert_forward(model, module.Phi3MLP, mlp_forward)
-        from ipex_llm.transformers.models.phi3 import phi3_rms_norm_forward
-        convert_forward(model, module.Phi3RMSNorm, phi3_rms_norm_forward)
+        from ipex_llm.transformers.models.common import rms_norm_forward
+        convert_forward(model, module.Phi3RMSNorm, rms_norm_forward)
         if model.config.model_type == "phi3":
             from ipex_llm.transformers.models.phi3 import phi3_model_forward_wrapper
             model_forward = phi3_model_forward_wrapper(module.Phi3Model.forward)

@@ -281,8 +281,13 @@ def scaled_dot_product_attention(query: torch.Tensor, key: torch.Tensor,
     key = repeat_kv(key, n_heads // n_kv_heads)
     value = repeat_kv(value, n_heads // n_kv_heads)

-    attn_output = torch.nn.functional.scaled_dot_product_attention(
-        query, key, value, mask, is_causal=is_causal, scale=scale
-    )
+    if is_causal and mask is None:
+        attn_output = torch.nn.functional.scaled_dot_product_attention(
+            query, key, value, is_causal=is_causal, scale=scale
+        )
+    else:
+        attn_output = torch.nn.functional.scaled_dot_product_attention(
+            query, key, value, mask, scale=scale
+        )
     attn_output = attn_output.to(dtype)    # workaround ipex 2.1's bug
     return attn_output

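The new branch reflects a constraint of `torch.nn.functional.scaled_dot_product_attention`: recent PyTorch releases reject a call that sets both an explicit `attn_mask` and `is_causal=True`, so the helper now passes one or the other. A small illustration (shapes are arbitrary):

import torch
import torch.nn.functional as F

q = torch.randn(1, 8, 16, 64)
k = torch.randn(1, 8, 16, 64)
v = torch.randn(1, 8, 16, 64)
mask = torch.zeros(1, 1, 16, 16)      # additive mask; zeros mean "attend everywhere"

causal_out = F.scaled_dot_product_attention(q, k, v, is_causal=True)   # causal path: no mask
masked_out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)   # masked path: not causal

Splitting the call also lets the causal path use the kernel's built-in causal handling instead of materializing a triangular mask.
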
File diff suppressed because it is too large.

@@ -35,12 +35,12 @@ import os
 import math
 import torch
 import warnings
-from torch import nn

 from ipex_llm.transformers.models.common import attention_softmax
+from ipex_llm.transformers.models.common import scaled_dot_product_attention
 from ipex_llm.transformers.models.utils import should_use_fuse_rope, rotate_half
 from ipex_llm.transformers.models.utils import mlp_fusion_check, SILU
-from ipex_llm.transformers.models.utils import use_sdp, use_sdp_causal, get_compresskv_attn_mask
+from ipex_llm.transformers.models.utils import use_sdp, use_sdp_causal
 from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache
 from ipex_llm.transformers.models.utils import should_use_compresskv, is_enough_kv_cache_room_4_36
 from ipex_llm.transformers.kv import DynamicNormalCache, DynamicFp8Cache, \

@@ -149,28 +149,20 @@ def attention_forward(
         key_states, value_states = past_key_value.update(key_states, value_states,
                                                          self.layer_idx, None)

+    attn_weights = None
     if use_sdp(q_len, kv_seq_len, self.head_dim, query_states):
-        # [CompressKV]
-        if use_compresskv:
-            attention_mask = get_compresskv_attn_mask(key_states, attention_mask)
-        import xe_addons
-        if use_quantizekv:
-            attn_output = xe_addons.sdp_fp8(query_states, key_states, value_states,
-                                            attention_mask)
-        else:
-            attn_output = xe_addons.sdp(query_states, key_states, value_states,
-                                        attention_mask)
+        attn_output = scaled_dot_product_attention(
+            query_states, key_states, value_states,
+            attention_mask, False
+        )
     elif (
         use_sdp_causal(q_len, kv_seq_len, self.head_dim, query_states, self.training)
         and os.environ.get("IPEX_LLM_LOW_MEM", "0") == "1"
     ):
-        import xe_addons
-        if isinstance(past_key_value, DynamicFp8Cache):
-            attn_output = xe_addons.sdp_fp8_causal(query_states, key_states,
-                                                   value_states, attention_mask)
-        else:
-            attn_output = xe_addons.sdp_causal(query_states, key_states,
-                                               value_states, attention_mask)
+        attn_output = scaled_dot_product_attention(
+            query_states, key_states, value_states,
+            attention_mask, True
+        )
     else:
         if use_quantizekv:
             key_states, value_states = restore_fp8_kv_cache(key_states, value_states,

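Here the hand-written xe_addons branches (fp8 vs. regular kv cache, causal vs. non-causal) are replaced by the shared `scaled_dot_product_attention` helper imported above, with the last argument selecting the causal path. The sketch below is only a guess at the dispatch that helper is presumed to centralize, built from the xe_addons calls visible in the removed lines; the real implementation in ipex_llm.transformers.models.common may differ.

def sdp_dispatch_sketch(query, key, value, mask, is_causal, kv_is_fp8):
    # Hypothetical: pick the XPU kernel the removed code selected by hand.
    import xe_addons  # IPEX-LLM XPU extension, as used in the removed branches
    if kv_is_fp8:
        fn = xe_addons.sdp_fp8_causal if is_causal else xe_addons.sdp_fp8
    else:
        fn = xe_addons.sdp_causal if is_causal else xe_addons.sdp
    return fn(query, key, value, mask)
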
@@ -334,17 +326,3 @@ def phi3v_model_forward_wrapper(origin_model_forward):
             return_dict=return_dict,
         )
     return model_forward
-
-
-def phi3_rms_norm_forward(self, hidden_states):
-    if hidden_states.device.type == "xpu" and not (self.training and hidden_states.requires_grad):
-        import xe_addons
-        x_2d = hidden_states.reshape(-1, hidden_states.size(-1)).contiguous()
-        output = xe_addons.rms_norm(self.weight, x_2d, self.variance_epsilon)
-        return output.reshape(hidden_states.shape)
-
-    input_dtype = hidden_states.dtype
-    hidden_states = hidden_states.to(torch.float32)
-    variance = hidden_states.pow(2).mean(-1, keepdim=True)
-    hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
-    return self.weight * hidden_states.to(input_dtype)

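The deleted `phi3_rms_norm_forward` pairs with the convert change above, which now registers the shared `rms_norm_forward` for `Phi3RMSNorm`. For reference, the removed CPU fallback computes plain RMSNorm; a standalone equivalent of those lines:

import torch

def rms_norm_reference(weight: torch.Tensor, hidden_states: torch.Tensor,
                       variance_epsilon: float) -> torch.Tensor:
    # Same computation as the removed fallback: scale by the reciprocal root
    # mean square of the last dimension, then apply the learned weight.
    input_dtype = hidden_states.dtype
    hidden_states = hidden_states.to(torch.float32)
    variance = hidden_states.pow(2).mean(-1, keepdim=True)
    hidden_states = hidden_states * torch.rsqrt(variance + variance_epsilon)
    return weight * hidden_states.to(input_dtype)
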
@@ -556,9 +556,6 @@ def qwen2_attention_forward(
     if past_key_value is not None:
         kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)

-    if attention_mask is not None:
-        attention_mask = attention_mask[:, :, :, :kv_seq_len]
-
     if should_use_fuse_rope(hidden_states, position_ids, self.training):
         import xe_addons
         xe_addons.rotary_half_inplaced(self.rotary_emb.inv_freq, position_ids,

@@ -584,6 +581,8 @@ def qwen2_attention_forward(

     attn_weights = None
     if use_flash_attention(query_states, key_states, attention_mask):
+        if attention_mask is not None:
+            attention_mask = attention_mask[:, :, :, :kv_seq_len]
         # repeat k/v heads if n_kv_heads < n_heads
         key_states = repeat_kv(key_states, self.num_key_value_groups)
         value_states = repeat_kv(value_states, self.num_key_value_groups)
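
The mask slicing removed earlier in this file now happens only on the flash-attention branch; it trims the 4-D additive mask so its last dimension matches `kv_seq_len` (new tokens plus cached ones). Illustrative shapes only:

import torch

kv_seq_len = 12
attention_mask = torch.zeros(1, 1, 1, 32)                 # over-allocated additive mask
attention_mask = attention_mask[:, :, :, :kv_seq_len]
print(attention_mask.shape)                               # torch.Size([1, 1, 1, 12])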