#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

# Some parts of this file are adapted from
# https://github.com/huggingface/transformers/blob/v4.35.2/src/transformers/models/llama/modeling_llama.py # noqa
# and https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_utils.py
# which are licensed under the Apache License 2.0:
#
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch

from ipex_llm.utils.common import invalidInputError
from typing import List, Optional, Tuple, Union
from intel_extension_for_pytorch.transformers.optimize import (
    lowering_class_cpu,
    convert_class,
)
from intel_extension_for_pytorch.cpu._auto_kernel_selection import (
    _enable_tpp,
    _using_tpp,
    _disable_tpp
)
from ipex_llm.ggml.quantize import ggml_tensor_qtype
from ipex_llm.transformers.convert import get_enable_ipex
import os

# logger with `warning_once` support, used by `_llama_model_forward_4_35` below
from transformers.utils import logging

logger = logging.get_logger(__name__)


def _ipex_optimize_rmsnorm(_model, supported_classes, is_tpp=False, is_woq=False):
    try:
        # older IPEX versions use the name `_IPEXRMSNorm`
        from intel_extension_for_pytorch.transformers.models.cpu.fusions.mha_fusion \
            import _IPEXRMSNorm
    except ImportError:
        # newer IPEX versions use the name `_IPEXRMSNormCPU`
        from intel_extension_for_pytorch.transformers.models.cpu.fusions.mha_fusion \
            import _IPEXRMSNormCPU as _IPEXRMSNorm
    for supported_class in supported_classes:
        lowering_class_cpu(
            _model,
            supported_class,
            _IPEXRMSNorm,
            _model.config,
            tpp=is_tpp,
            woq=is_woq,
        )


def _ipex_optimize_decoder(model, is_tpp=False, is_woq=False):
    from intel_extension_for_pytorch.transformers.models.reference.modules.decoder import (
        _IPEXDecoderLayerRef
    )
    from intel_extension_for_pytorch.transformers.models.cpu.modules.decoder import (
        _IPEXDecoderLayerCPU
    )
    for supported_mlp_class in [_IPEXDecoderLayerRef]:
        lowering_class_cpu(
            model,
            supported_mlp_class,
            _IPEXDecoderLayerCPU,
            model.config,
            tpp=is_tpp,
            woq=is_woq,
        )


def _ipex_optimize_attention(model, is_tpp=False, is_woq=False):
    from intel_extension_for_pytorch.transformers.models.reference.modules.attentions import (
        _IPEXAttentionRef
    )
    from intel_extension_for_pytorch.transformers.models.cpu.modules.attentions import (
        _IPEXAttentionCPU
    )
    for supported_mha_class in [_IPEXAttentionRef]:
        lowering_class_cpu(
            model,
            supported_mha_class,
            _IPEXAttentionCPU,
            model.config,
            tpp=is_tpp,
            woq=is_woq,
        )


def _ipex_optimize_model(model, rms_classes, qtype):
    import intel_extension_for_pytorch as ipex
    from intel_extension_for_pytorch.transformers.models.reference.models import output_hook
    from intel_extension_for_pytorch.transformers.optimize import ipex_quantization_flow

    is_woq = False
    is_quantization = False
    _disable_tpp()
    if qtype == ggml_tensor_qtype["bf16"]:
        _enable_tpp()
        model = ipex.optimize(model.eval(), dtype=torch.bfloat16, inplace=True).eval()
    elif qtype == ggml_tensor_qtype["sym_int4"]:
        is_quantization = True
        is_woq = True
        act_quant_mode_dict = {
            "PER_TENSOR": ipex.quantization.WoqActQuantMode.PER_TENSOR,
            "PER_IC_BLOCK": ipex.quantization.WoqActQuantMode.PER_IC_BLOCK,
            "PER_BATCH": ipex.quantization.WoqActQuantMode.PER_BATCH,
            "PER_BATCH_IC_BLOCK": ipex.quantization.WoqActQuantMode.PER_BATCH_IC_BLOCK,
        }
        qconfig = ipex.quantization.get_weight_only_quant_qconfig_mapping(
            weight_dtype=torch.quint4x2,  # INT4
            lowp_mode=ipex.quantization.WoqLowpMode.INT8,
            act_quant_mode=act_quant_mode_dict["PER_IC_BLOCK"],
            group_size=-1,
        )
        model = ipex_quantization_flow(model, torch.bfloat16, None, qconfig, None)
    elif qtype == ggml_tensor_qtype["sym_int8"]:
        is_quantization = True
        is_woq = True
        act_quant_mode_dict = {
            "PER_TENSOR": ipex.quantization.WoqActQuantMode.PER_TENSOR,
            "PER_IC_BLOCK": ipex.quantization.WoqActQuantMode.PER_IC_BLOCK,
            "PER_BATCH": ipex.quantization.WoqActQuantMode.PER_BATCH,
            "PER_BATCH_IC_BLOCK": ipex.quantization.WoqActQuantMode.PER_BATCH_IC_BLOCK,
        }
        qconfig = ipex.quantization.get_weight_only_quant_qconfig_mapping(
            weight_dtype=torch.qint8,  # INT8
            lowp_mode=ipex.quantization.WoqLowpMode.BF16,
            act_quant_mode=act_quant_mode_dict["PER_IC_BLOCK"],
            group_size=-1,
        )
        model = ipex_quantization_flow(model, torch.bfloat16, None, qconfig, None)

    is_tpp = _using_tpp()

    _ipex_optimize_rmsnorm(model, rms_classes, is_tpp=is_tpp, is_woq=is_woq)
    _ipex_optimize_attention(model, is_tpp=is_tpp, is_woq=is_woq)
    _ipex_optimize_decoder(model, is_tpp=is_tpp, is_woq=is_woq)

    # need to register_forward_hook after torch.jit.trace
    # model.register_forward_hook(output_hook, with_kwargs=True)
    return model


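# Summary of `_ipex_optimize_model` above: "bf16" enables TPP kernels and runs
# `ipex.optimize` in bfloat16; "sym_int4" / "sym_int8" build a weight-only
# quantization qconfig (INT4 weights with INT8 low-precision compute, or INT8
# weights with BF16 compute) and pass it through `ipex_quantization_flow`.
# In every case the RMSNorm, attention and decoder layers are then lowered to
# their IPEX CPU counterparts.
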
def _ipex_jit(model):
    from intel_extension_for_pytorch.transformers.optimize import (
        get_dummy_input,
        _set_optimized_model_for_generation
    )
    sample_inputs = (
        get_dummy_input(model, return_dict=True)
    )
    if "return_last_logit" in sample_inputs:
        sample_inputs["return_last_logit"] = torch.tensor(False)
    with torch.no_grad(), torch.cpu.amp.autocast(
        enabled=True
    ):
        trace_model = torch.jit.trace(
            model,
            example_kwarg_inputs=sample_inputs,
            strict=False,
            check_trace=False,
        )
        trace_model = torch.jit.freeze(trace_model)
        model = _set_optimized_model_for_generation(
            model, optimized_model=trace_model
        )
    from intel_extension_for_pytorch.transformers.models.reference.models import output_hook
    model.register_forward_hook(output_hook, with_kwargs=True)

    return model


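# Illustrative usage sketch (not invoked anywhere in this file): the intended
# flow is to optimize the model with IPEX and then JIT-trace it. `LlamaRMSNorm`
# is only an example of the model-specific RMSNorm classes a caller would pass
# as `rms_classes`; the actual classes depend on the model being optimized.
#
#     from transformers.models.llama.modeling_llama import LlamaRMSNorm
#     model = _ipex_optimize_model(model, rms_classes=[LlamaRMSNorm],
#                                  qtype=ggml_tensor_qtype["bf16"])
#     model = _ipex_jit(model)
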
def convert_function(m, func_name, new_function):
    bound_method = new_function.__get__(m, m.__class__)
    setattr(m, func_name, bound_method)


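# Illustrative usage sketch (assumed model layout, not taken from this file):
# `convert_function` binds a free function as a method on a single module
# instance, e.g. to patch a ChatGLM model's mask construction with
# `GLM_get_masks` defined below:
#
#     convert_function(model.transformer, "get_masks", GLM_get_masks)
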
def GLM_get_masks(self, input_ids, past_key_values, padding_mask=None):
    batch_size, seq_length = input_ids.shape
    full_attention_mask = torch.ones(
        batch_size, seq_length, seq_length, device=input_ids.device
    )
    full_attention_mask.tril_()
    past_length = 0
    if past_key_values:
        if len(past_key_values[0]) != 4:  # not discrete kv cache
            past_length = past_key_values[0][0].shape[0]
        else:  # discrete kv cache
            past_length = past_key_values[0][0].shape[-2]

    _enable_ipex = get_enable_ipex()
    # always call for jit
    if past_length or _enable_ipex:
        full_attention_mask = torch.cat(
            (
                torch.ones(
                    batch_size, seq_length, past_length, device=input_ids.device
                ),
                full_attention_mask,
            ),
            dim=-1,
        )
    if padding_mask is not None:
        full_attention_mask = full_attention_mask * padding_mask.unsqueeze(1)
    # if not past_length and padding_mask is not None:
    #     full_attention_mask -= padding_mask.unsqueeze(-1) - 1
    full_attention_mask = (full_attention_mask < 0.5).bool()
    full_attention_mask.unsqueeze_(1)
    return full_attention_mask


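# Note on the mask returned above: allowed positions carry the value 1.0 while
# the mask is built, then the `< 0.5` comparison inverts it, so the final
# boolean tensor of shape (batch, 1, seq, past + seq) holds True where
# attention must be blocked and False where it is allowed.
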
@staticmethod
def _make_causal_mask(
    input_ids_shape: torch.Size,
    dtype: torch.dtype,
    device: torch.device,
    past_key_values_length: int = 0,
    sliding_window: Optional[int] = None,
):
    """
    Make causal mask used for bi-directional self-attention.
    """
    bsz, tgt_len = input_ids_shape
    mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
    mask_cond = torch.arange(mask.size(-1), device=device)
    mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)

    _enable_ipex = get_enable_ipex()
    if _enable_ipex or past_key_values_length > 0:
        mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)  # noqa

    # add lower triangular sliding window mask if necessary
    if sliding_window is not None:
        diagonal = past_key_values_length - sliding_window + 1

        context_mask = 1 - torch.triu(torch.ones_like(mask, dtype=torch.int), diagonal=diagonal)
        mask.masked_fill_(context_mask.bool(), torch.finfo(dtype).min)

    return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)


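# Worked example for the mask above (illustrative): with tgt_len=3,
# past_key_values_length=2 and no sliding window, each query row may attend to
# every past position plus positions up to itself (min = torch.finfo(dtype).min):
#
#     [[0, 0, 0, min, min],
#      [0, 0, 0, 0,   min],
#      [0, 0, 0, 0,   0  ]]
#
# which the final expand broadcasts to shape (bsz, 1, 3, 5).
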
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask


def _llama_model_forward_4_35(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions  # noqa
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states  # noqa
    )
    use_cache = use_cache if use_cache is not None else self.config.use_cache

    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    # retrieve input_ids and inputs_embeds
    if input_ids is not None and inputs_embeds is not None:
        raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")  # noqa
    elif input_ids is not None:
        batch_size, seq_length = input_ids.shape[:2]
    elif inputs_embeds is not None:
        batch_size, seq_length = inputs_embeds.shape[:2]
    else:
        raise ValueError("You have to specify either input_ids or inputs_embeds")  # noqa

    past_key_values_length = 0
    if past_key_values is not None:
        past_key_values_length = past_key_values[0][0].shape[2]

    if position_ids is None:
        device = input_ids.device if input_ids is not None else inputs_embeds.device
        position_ids = torch.arange(
            past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device  # noqa
        )
        position_ids = position_ids.unsqueeze(0)

    if inputs_embeds is None:
        inputs_embeds = self.embed_tokens(input_ids)

    if getattr(self.config, "_flash_attn_2_enabled", False):
        # 2d mask is passed through the layers
        attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None  # noqa
    else:
        # 4d mask is passed through the layers
        attention_mask = _prepare_4d_causal_attention_mask(
            attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
        )

    # embed positions
    hidden_states = inputs_embeds

    if self.gradient_checkpointing and self.training:
        if use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."  # noqa
            )
            use_cache = False

    # decoder layers
    all_hidden_states = () if output_hidden_states else None
    all_self_attns = () if output_attentions else None
    next_decoder_cache = () if use_cache else None

    for idx, decoder_layer in enumerate(self.layers):
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        past_key_value = past_key_values[idx] if past_key_values is not None else None

        if self.gradient_checkpointing and self.training:
            layer_outputs = self._gradient_checkpointing_func(
                decoder_layer.__call__,
                hidden_states,
                attention_mask,
                position_ids,
                past_key_value,
                output_attentions,
                use_cache,
            )
        else:
            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
            )

        hidden_states = layer_outputs[0]

        if use_cache:
            next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)

        if output_attentions:
            all_self_attns += (layer_outputs[1],)

    hidden_states = self.norm(hidden_states)

    # add hidden states from the last decoder layer
    if output_hidden_states:
        all_hidden_states += (hidden_states,)

    next_cache = next_decoder_cache if use_cache else None
    if not return_dict:
        return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)  # noqa
    return BaseModelOutputWithPast(
        last_hidden_state=hidden_states,
        past_key_values=next_cache,
        hidden_states=all_hidden_states,
        attentions=all_self_attns,
    )