add sdp for gemma2 (#11677)
Parent: c11d5301d7
Commit: 6f999e6e90
3 changed files with 87 additions and 23 deletions
@@ -1512,9 +1512,12 @@ def _optimize_post(model, lightweight_bmm=False):
         module = importlib.import_module(modeling_module_name)
         from ipex_llm.transformers.models.gemma import gemma_rms_norm_forward
         from ipex_llm.transformers.models.gemma2 import gemma2_attention_forward
+        from ipex_llm.transformers.models.gemma2 import gemma2_model_forward
         from transformers.models.gemma2.modeling_gemma2 import Gemma2RMSNorm, Gemma2Attention
+        from transformers.models.gemma2.modeling_gemma2 import Gemma2Model
         convert_forward(model, Gemma2RMSNorm, gemma_rms_norm_forward)
         convert_forward(model, Gemma2Attention, gemma2_attention_forward)
+        convert_forward(model, Gemma2Model, gemma2_model_forward)
     elif model.config.model_type == "Yi":
         modeling_module_name = model.__class__.__module__
         module = importlib.import_module(modeling_module_name)
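The three added lines follow the existing registration pattern in _optimize_post: import the optimized forward functions and bind them to the matching Hugging Face module classes via convert_forward. A minimal sketch of that pattern (an illustrative assumption, not ipex-llm's actual convert_forward implementation):

    import types
    import torch

    def convert_forward_sketch(model: torch.nn.Module, target_cls, new_forward):
        # Rebind `forward` on every submodule of the target class so later calls
        # go through the optimized implementation instead of the original one.
        for module in model.modules():
            if isinstance(module, target_cls):
                module.forward = types.MethodType(new_forward, module)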
@@ -31,14 +31,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Optional, Tuple
-
 import torch
-from ipex_llm.utils.common import invalidInputError
 
+from typing import Optional, Tuple
 from ipex_llm.transformers.models.common import merge_qkv_base
-from ipex_llm.transformers.models.utils import should_use_fuse_rope
+from ipex_llm.transformers.models.utils import should_use_fuse_rope, use_sdp, use_sdp_causal
 from transformers.cache_utils import Cache
-from transformers.models.gemma2.modeling_gemma2 import Gemma2Attention
+from transformers.models.gemma2.modeling_gemma2 import Gemma2Model, Gemma2Attention
 from transformers.models.gemma2.modeling_gemma2 import repeat_kv, apply_rotary_pos_emb
 
@@ -46,6 +45,46 @@ def merge_qkv(module: torch.nn.Module):
     return merge_qkv_base(module, Gemma2Attention)
 
 
+def gemma2_model_forward(
+    self,
+    input_ids: torch.LongTensor = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_values: Optional[Cache] = None,
+    inputs_embeds: Optional[torch.FloatTensor] = None,
+    use_cache: Optional[bool] = None,
+    output_attentions: Optional[bool] = None,
+    output_hidden_states: Optional[bool] = None,
+    return_dict: Optional[bool] = None,
+    cache_position: Optional[torch.LongTensor] = None,
+):
+    # ipex-llm change start: add kv_seq_len in past_key_values
+    if past_key_values is not None:
+        if cache_position is not None:
+            kv_seq_len = cache_position[-1].item() + 1
+        else:
+            if input_ids is not None:
+                kv_seq_len = input_ids.size(1)
+            else:
+                kv_seq_len = inputs_embeds.size(1)
+        past_key_values.kv_seq_len = kv_seq_len
+    # ipex-llm change end
+
+    return Gemma2Model.forward(
+        self=self,
+        input_ids=input_ids,
+        attention_mask=attention_mask,
+        position_ids=position_ids,
+        past_key_values=past_key_values,
+        inputs_embeds=inputs_embeds,
+        use_cache=use_cache,
+        output_attentions=output_attentions,
+        output_hidden_states=output_hidden_states,
+        return_dict=return_dict,
+        cache_position=cache_position
+    )
+
+
 def gemma2_attention_forward(
     self,
     hidden_states: torch.Tensor,
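The added gemma2_model_forward is a thin wrapper: before delegating to the stock Gemma2Model.forward, it records the number of valid key/value positions (kv_seq_len) on the cache object, taken from cache_position when available and from the input length otherwise, so the attention code below can slice a pre-allocated cache down to its valid prefix. A toy illustration of that bookkeeping (ToyCache is a stand-in for the Hugging Face Cache, for illustration only):

    import torch

    class ToyCache:
        kv_seq_len = 0          # attribute the wrapper attaches to the real Cache

    cache = ToyCache()
    cache_position = torch.arange(12)             # positions written so far: 0..11
    cache.kv_seq_len = cache_position[-1].item() + 1
    assert cache.kv_seq_len == 12                 # attention can now slice [:, :, :12, :]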
@@ -86,6 +125,28 @@ def gemma2_attention_forward(
         key_states, value_states = past_key_value.update(key_states, value_states,
                                                          self.layer_idx, cache_kwargs)
 
+    # IPEX_LLM OPT: sdp
+    kv_seq_len = q_len if past_key_value is None else past_key_value.kv_seq_len
+    if (use_sdp_causal(q_len, kv_seq_len, -1, query_states, self.training)
+            and kv_seq_len <= key_states.size(2)):
+        import xe_addons
+        attn_weights = None
+        attn_output = xe_addons.gemma2_sdp_causal(query_states,
+                                                  key_states[:, :, :kv_seq_len, :],
+                                                  value_states[:, :, :kv_seq_len, :],
+                                                  attention_mask[:, :, :q_len, :kv_seq_len],
+                                                  self.config.attn_logit_softcapping,
+                                                  self.scaling)
+    elif use_sdp(q_len, kv_seq_len, -1, query_states):
+        import xe_addons
+        attn_weights = None
+        attn_output = xe_addons.gemma2_sdp(query_states,
+                                           key_states[:, :, :kv_seq_len, :],
+                                           value_states[:, :, :kv_seq_len, :],
+                                           attention_mask[:, :, :q_len, :kv_seq_len],
+                                           self.config.attn_logit_softcapping,
+                                           self.scaling)
+    else:
         key_states = repeat_kv(key_states, self.num_key_value_groups)
         value_states = repeat_kv(value_states, self.num_key_value_groups)
 
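The new block dispatches to fused SDP kernels on XPU: use_sdp_causal gates the prefill path (first token, full-length query), use_sdp gates the decode path (short next-token query), and -1 is passed for head_dim to skip the head-size whitelist; anything that fails both gates falls through to the eager math below. As a hedged reference for what the fused kernels are expected to compute, here is a plain-PyTorch sketch of scaled dot-product attention with Gemma2's attention-logit soft-capping; this is an assumption about the kernel semantics rather than the xe_addons source, and it assumes key/value states have already been expanded to the query head count with repeat_kv:

    import torch

    def gemma2_sdp_reference(q, k, v, mask, softcap, scaling):
        # q, k, v: [batch, heads, seq, head_dim]; mask: additive causal mask or None
        attn = torch.matmul(q, k.transpose(2, 3)) * scaling
        attn = softcap * torch.tanh(attn / softcap)        # Gemma2 logit soft-capping
        if mask is not None:
            attn = attn + mask
        attn = torch.softmax(attn, dim=-1, dtype=torch.float32).to(q.dtype)
        return torch.matmul(attn, v)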
@@ -101,10 +162,10 @@ def gemma2_attention_forward(
         attn_weights = attn_weights + causal_mask
 
         # upcast attention to fp32
-        attn_weights = torch.nn.functional.softmax(attn_weights,
-                                                   dim=-1, dtype=torch.float32).to(query_states.dtype)
-        attn_weights = torch.nn.functional.dropout(attn_weights,
-                                                   p=self.attention_dropout, training=self.training)
+        attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1,
+                                                   dtype=torch.float32).to(query_states.dtype)
+        attn_weights = torch.nn.functional.dropout(attn_weights, p=self.attention_dropout,
+                                                   training=self.training)
         attn_output = torch.matmul(attn_weights, value_states)
 
     attn_output = attn_output.transpose(1, 2).contiguous()
@@ -329,7 +329,7 @@ def use_sdp(q_len, kv_len, head_dim, query_states):
     return (
         query_states.device.type == "xpu"
        and query_states.dtype in [torch.float, torch.half]  # fp32/fp16
-        and head_dim in [64, 80, 96, 128]
+        and head_dim in [-1, 64, 80, 96, 128]
        and q_len != kv_len  # next token
        and q_len <= 32  # lookup
    )
@@ -347,7 +347,7 @@ def use_sdp_fp8(q_len, kv_len, query_states):
 def use_sdp_causal(q_len, kv_len, head_dim, query_states, training):
     return (
         q_len == kv_len  # first token
-        and head_dim in [64, 80, 96, 128]  # for now
+        and head_dim in [-1, 64, 80, 96, 128]  # for now
        and query_states.device.type == "xpu"  # GPU
        and query_states.dtype in [torch.float, torch.half]  # fp32/fp16
        and not query_states.requires_grad and not training  # not training
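Both gates now accept -1 as a sentinel head size: the Gemma2 attention path above passes head_dim=-1, so it bypasses the head-dimension whitelist regardless of the model's actual head size, while callers passing a real head size are checked exactly as before. A quick check of the resulting dispatch (assumes an ipex-llm XPU build so an "xpu" tensor can be created):

    import torch

    q = torch.randn(1, 8, 1, 256, dtype=torch.half, device="xpu")      # single decode-step query
    print(use_sdp(q_len=1, kv_len=33, head_dim=-1, query_states=q))    # True  -> fused decode kernel
    print(use_sdp(q_len=1, kv_len=33, head_dim=256, query_states=q))   # False -> eager fallback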