[LLM] Use IPEX Optimization for BF16 Model (#9988)
Use IPEX optimization for a BF16 model when the environment variable BIGDL_OPT_IPEX=true is set.
This commit is contained in:
parent 440cfe18ed
commit f37e4702bc
2 changed files with 332 additions and 0 deletions
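For context, a minimal usage sketch of the new flag (not part of this commit): the model path is a placeholder, and load_in_low_bit="bf16" / optimize_model=True are assumed to be the usual bigdl-llm loading arguments that make the convert step see the bf16 qtype.

import os

# The flag is read inside ggml_convert_low_bit, so set it before loading;
# it only takes effect when the requested qtype is bf16.
os.environ["BIGDL_OPT_IPEX"] = "true"

from bigdl.llm.transformers import AutoModelForCausalLM
from transformers import AutoTokenizer

model_path = "meta-llama/Llama-2-7b-chat-hf"  # placeholder checkpoint
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    load_in_low_bit="bf16",   # assumed way to request the bf16 qtype
    optimize_model=True,
)
tokenizer = AutoTokenizer.from_pretrained(model_path)

inputs = tokenizer("What does IPEX optimize?", return_tensors="pt")
output_ids = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))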
@@ -48,6 +48,7 @@ from typing import Union
import numpy as np
import os
from bigdl.llm.utils.common import invalidInputError
from typing import List, Optional, Tuple, Union


def is_auto_gptq_available():
@@ -528,6 +529,13 @@ def ggml_convert_low_bit(model, qtype, optimize_model=True,
        # Do nothing here for weights are empty.
        pass

    _enable_ipex = os.getenv("BIGDL_OPT_IPEX")
    _enable_ipex = (_enable_ipex is not None) and (_enable_ipex.lower() == "true")
    _enable_ipex = _enable_ipex and (qtype == ggml_tensor_qtype["bf16"])
    logger.info(f"BIGDL_OPT_IPEX: {_enable_ipex}")
    if _enable_ipex:
        model = _optimize_ipex(model)
        return model
    if optimize_model:
        model = _optimize_post(model, lightweight_bmm)
    return model
@@ -560,6 +568,28 @@ def replace_func(m, target_m, func_name, new_func):
        replace_func(sub_m, target_m, func_name, new_func)


def _optimize_ipex(model):
    import intel_extension_for_pytorch as ipex
    from intel_extension_for_pytorch.transformers.optimize import model_convert_reference
    from intel_extension_for_pytorch.transformers.models.reference.models import output_hook
    from transformers.modeling_attn_mask_utils import AttentionMaskConverter
    from bigdl.llm.transformers.convert_ipex import (
        _ipex_optimize_attention, _ipex_optimize_decoder, _ipex_jit, _make_causal_mask,
        _llama_model_forward_4_35
    )

    AttentionMaskConverter._make_causal_mask = _make_causal_mask
    convert_forward(model, transformers.models.llama.modeling_llama.LlamaModel, _llama_model_forward_4_35)  # noqa
    model = model_convert_reference(model)

    _ipex_optimize_attention(model, transformers.models.llama.modeling_llama.LlamaAttention)
    _ipex_optimize_decoder(model, transformers.models.llama.modeling_llama.LlamaDecoderLayer)

    model.register_forward_hook(output_hook, with_kwargs=True)

    return _ipex_jit(model)


def _optimize_post(model, lightweight_bmm=False):
    from packaging import version
    from bigdl.llm.transformers.models.llama import llama_attention_forward_4_31
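Note that _optimize_ipex relies on convert_forward (defined elsewhere in convert.py and not shown in this diff) to swap LlamaModel.forward for _llama_model_forward_4_35. As a rough illustration only, such a helper might look like the sketch below; the real implementation in convert.py may differ.

def convert_forward(m, target_m, new_forward):
    # Hypothetical sketch: recursively rebind `forward` on every submodule
    # that is an instance of `target_m`, mirroring replace_func above.
    for _, sub_m in m.named_children():
        if isinstance(sub_m, target_m):
            bound_method = new_forward.__get__(sub_m, sub_m.__class__)
            setattr(sub_m, "forward", bound_method)
        convert_forward(sub_m, target_m, new_forward)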
python/llm/src/bigdl/llm/transformers/convert_ipex.py (new file, 302 additions)
@@ -0,0 +1,302 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#


# Some parts of this file is adapted from
# https://github.com/huggingface/transformers/blob/v4.35.2/src/transformers/models/llama/modeling_llama.py  # noqa
# and https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_utils.py
# which is licensed under Apache License 2.0:
#
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import torch
from bigdl.llm.utils.common import invalidInputError
from typing import List, Optional, Tuple, Union


def lowering_class_cpu(m, target_m, new_class, config, tpp=False, woq=False):
    for name, sub_m in m.named_children():
        if isinstance(sub_m, target_m):
            new_m = new_class(sub_m, config, tpp, woq)
            setattr(m, name, new_m)
        lowering_class_cpu(sub_m, target_m, new_class, config, tpp, woq)


def convert_class(m, target_m, new_class, config, distributed=False):
    for name, sub_m in m.named_children():
        if isinstance(sub_m, target_m):
            new_m = new_class(sub_m, config, distributed)
            setattr(m, name, new_m)
        convert_class(sub_m, target_m, new_class, config, distributed)


def _set_optimized_model_for_generation(
    model,
    optimized_model,
    first_token_optimized_model=None,
):
    from intel_extension_for_pytorch.transformers.models.reference.models import (
        IPEX_LLM_Model_Return
    )
    if first_token_optimized_model is not None:
        model.trace_graph_first = IPEX_LLM_Model_Return(
            model, first_token_optimized_model
        ).forward

    model.trace_graph = IPEX_LLM_Model_Return(model, optimized_model).forward
    print(
        "ipex.llm.optimize has set the optimized or quantization model for model.generate()"
    )
    return model


def _ipex_optimize_decoder(model, decoder_layer):
    from intel_extension_for_pytorch.transformers.models.reference.modules.decoder import (
        _IPEXDecoderLayerRef
    )
    from intel_extension_for_pytorch.transformers.models.cpu.modules.decoder import (
        _IPEXDecoderLayerCPU
    )
    for supported_mlp_class in [_IPEXDecoderLayerRef]:
        lowering_class_cpu(
            model,
            supported_mlp_class,
            _IPEXDecoderLayerCPU,
            model.config,
            tpp=False,
            woq=False,
        )
    convert_class(
        model,
        decoder_layer,
        _IPEXDecoderLayerRef,
        model.config,
        distributed=True,
    )


def _ipex_optimize_attention(model, attention_layer):
    from intel_extension_for_pytorch.transformers.models.reference.modules.attentions import (
        _IPEXAttentionRef
    )
    from intel_extension_for_pytorch.transformers.models.cpu.modules.attentions import (
        _IPEXAttentionCPU
    )
    for supported_mha_class in [_IPEXAttentionRef]:
        lowering_class_cpu(
            model,
            supported_mha_class,
            _IPEXAttentionCPU,
            model.config,
            tpp=False,
            woq=False,
        )
    convert_class(
        model,
        attention_layer,
        _IPEXAttentionRef,
        model.config,
        distributed=True,
    )


def _ipex_jit(model):
    from intel_extension_for_pytorch.transformers.optimize import get_dummy_input
    sample_inputs = (
        get_dummy_input(model, return_dict=True)
    )
    with torch.no_grad(), torch.cpu.amp.autocast(
        enabled=True
    ):
        trace_model = torch.jit.trace(
            model,
            example_kwarg_inputs=sample_inputs,
            strict=False,
            check_trace=False,
        )
        trace_model = torch.jit.freeze(trace_model)
        model = _set_optimized_model_for_generation(
            model, optimized_model=trace_model
        )

    return model.eval()


@staticmethod
def _make_causal_mask(
    input_ids_shape: torch.Size,
    dtype: torch.dtype,
    device: torch.device,
    past_key_values_length: int = 0,
    sliding_window: Optional[int] = None,
):
    """
    Make causal mask used for bi-directional self-attention.
    """
    bsz, tgt_len = input_ids_shape
    mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
    mask_cond = torch.arange(mask.size(-1), device=device)
    mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)

    import os
    _enable_ipex = os.getenv("BIGDL_OPT_IPEX")
    _enable_ipex = (_enable_ipex is not None) and (_enable_ipex.lower() == "true")
    if _enable_ipex or past_key_values_length > 0:
        mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)  # noqa

    # add lower triangular sliding window mask if necessary
    if sliding_window is not None:
        diagonal = past_key_values_length - sliding_window + 1

        context_mask = 1 - torch.triu(torch.ones_like(mask, dtype=torch.int), diagonal=diagonal)
        mask.masked_fill_(context_mask.bool(), torch.finfo(dtype).min)

    return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)


from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask


def _llama_model_forward_4_35(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions  # noqa
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states  # noqa
    )
    use_cache = use_cache if use_cache is not None else self.config.use_cache

    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    # retrieve input_ids and inputs_embeds
    if input_ids is not None and inputs_embeds is not None:
        raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")  # noqa
    elif input_ids is not None:
        batch_size, seq_length = input_ids.shape[:2]
    elif inputs_embeds is not None:
        batch_size, seq_length = inputs_embeds.shape[:2]
    else:
        raise ValueError("You have to specify either input_ids or inputs_embeds")  # noqa

    past_key_values_length = 0
    if past_key_values is not None:
        past_key_values_length = past_key_values[0][0].shape[2]

    if position_ids is None:
        device = input_ids.device if input_ids is not None else inputs_embeds.device
        position_ids = torch.arange(
            past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device  # noqa
        )
        position_ids = position_ids.unsqueeze(0)

    if inputs_embeds is None:
        inputs_embeds = self.embed_tokens(input_ids)

    if getattr(self.config, "_flash_attn_2_enabled", False):
        # 2d mask is passed through the layers
        attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None  # noqa
    else:
        # 4d mask is passed through the layers
        attention_mask = _prepare_4d_causal_attention_mask(
            attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
        )

    # embed positions
    hidden_states = inputs_embeds

    if self.gradient_checkpointing and self.training:
        if use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."  # noqa
            )
            use_cache = False

    # decoder layers
    all_hidden_states = () if output_hidden_states else None
    all_self_attns = () if output_attentions else None
    next_decoder_cache = () if use_cache else None

    for idx, decoder_layer in enumerate(self.layers):
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        past_key_value = past_key_values[idx] if past_key_values is not None else None

        if self.gradient_checkpointing and self.training:
            layer_outputs = self._gradient_checkpointing_func(
                decoder_layer.__call__,
                hidden_states,
                attention_mask,
                position_ids,
                past_key_value,
                output_attentions,
                use_cache,
            )
        else:
            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
            )

        hidden_states = layer_outputs[0]

        if use_cache:
            next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)

        if output_attentions:
            all_self_attns += (layer_outputs[1],)

    hidden_states = self.norm(hidden_states)

    # add hidden states from the last decoder layer
    if output_hidden_states:
        all_hidden_states += (hidden_states,)

    next_cache = next_decoder_cache if use_cache else None
    if not return_dict:
        return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)  # noqa
    return BaseModelOutputWithPast(
        last_hidden_state=hidden_states,
        past_key_values=next_cache,
        hidden_states=all_hidden_states,
        attentions=all_self_attns,
    )
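As a usage sketch of the patched mask helper (assuming transformers 4.35.x, where AttentionMaskConverter exposes _make_causal_mask, which is exactly the attribute _optimize_ipex overrides): with BIGDL_OPT_IPEX=true the concatenation branch runs even when past_key_values_length is 0, presumably so the first-token and decoding steps share one code path for JIT tracing.

import os
import torch
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
from bigdl.llm.transformers import convert_ipex

os.environ["BIGDL_OPT_IPEX"] = "true"
# Mirror what _optimize_ipex does before tracing:
AttentionMaskConverter._make_causal_mask = convert_ipex._make_causal_mask

mask = AttentionMaskConverter._make_causal_mask(
    torch.Size([1, 4]),
    dtype=torch.bfloat16,
    device=torch.device("cpu"),
    past_key_values_length=0,
)
print(mask.shape)  # torch.Size([1, 1, 4, 4]): zeros on/below the diagonal, dtype min above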