[LLM] Use IPEX Optimization for BF16 Model (#9988)

Use IPEX optimization for BF16 models when the environment variable BIGDL_OPT_IPEX=true is set (usage sketch below).
Xiangyu Tian 2024-01-29 11:28:25 +08:00 committed by GitHub
parent 440cfe18ed
commit f37e4702bc
2 changed files with 332 additions and 0 deletions
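For context, an illustrative usage sketch (not part of the diff): the bigdl-llm HuggingFace-style loading API and the example model id below are assumptions, shown only to indicate how the new path is reached. The IPEX route is taken when the model is converted to BF16 and the environment variable is set.

```python
# Hypothetical usage: enable the IPEX path for a BF16 model.
import os
os.environ["BIGDL_OPT_IPEX"] = "true"   # checked in ggml_convert_low_bit below

from bigdl.llm.transformers import AutoModelForCausalLM

# Loading with load_in_low_bit="bf16" (assumed option) makes qtype equal
# ggml_tensor_qtype["bf16"], which together with the env var routes the
# model through _optimize_ipex instead of _optimize_post.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-chat-hf",   # example model id
    load_in_low_bit="bf16",
)
```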

python/llm/src/bigdl/llm/transformers/convert.py

@@ -48,6 +48,7 @@ from typing import Union
import numpy as np
import os
from bigdl.llm.utils.common import invalidInputError
from typing import List, Optional, Tuple, Union


def is_auto_gptq_available():
@@ -528,6 +529,13 @@ def ggml_convert_low_bit(model, qtype, optimize_model=True,
            # Do nothing here for weights are empty.
            pass

    _enable_ipex = os.getenv("BIGDL_OPT_IPEX")
    _enable_ipex = (_enable_ipex is not None) and (_enable_ipex.lower() == "true")
    _enable_ipex = _enable_ipex and (qtype == ggml_tensor_qtype["bf16"])
    logger.info(f"BIGDL_OPT_IPEX: {_enable_ipex}")
    if _enable_ipex:
        model = _optimize_ipex(model)
        return model

    if optimize_model:
        model = _optimize_post(model, lightweight_bmm)
    return model
@@ -560,6 +568,28 @@ def replace_func(m, target_m, func_name, new_func):
        replace_func(sub_m, target_m, func_name, new_func)


def _optimize_ipex(model):
    import intel_extension_for_pytorch as ipex
    from intel_extension_for_pytorch.transformers.optimize import model_convert_reference
    from intel_extension_for_pytorch.transformers.models.reference.models import output_hook
    from transformers.modeling_attn_mask_utils import AttentionMaskConverter
    from bigdl.llm.transformers.convert_ipex import (
        _ipex_optimize_attention, _ipex_optimize_decoder, _ipex_jit, _make_causal_mask,
        _llama_model_forward_4_35
    )

    AttentionMaskConverter._make_causal_mask = _make_causal_mask
    convert_forward(model, transformers.models.llama.modeling_llama.LlamaModel, _llama_model_forward_4_35)  # noqa
    model = model_convert_reference(model)
    _ipex_optimize_attention(model, transformers.models.llama.modeling_llama.LlamaAttention)
    _ipex_optimize_decoder(model, transformers.models.llama.modeling_llama.LlamaDecoderLayer)
    model.register_forward_hook(output_hook, with_kwargs=True)

    return _ipex_jit(model)


def _optimize_post(model, lightweight_bmm=False):
    from packaging import version
    from bigdl.llm.transformers.models.llama import llama_attention_forward_4_31

python/llm/src/bigdl/llm/transformers/convert_ipex.py (new file)

@@ -0,0 +1,302 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Some parts of this file are adapted from
# https://github.com/huggingface/transformers/blob/v4.35.2/src/transformers/models/llama/modeling_llama.py # noqa
# and https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_utils.py
# which is licensed under Apache License 2.0:
#
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from bigdl.llm.utils.common import invalidInputError
from typing import List, Optional, Tuple, Union


def lowering_class_cpu(m, target_m, new_class, config, tpp=False, woq=False):
    for name, sub_m in m.named_children():
        if isinstance(sub_m, target_m):
            new_m = new_class(sub_m, config, tpp, woq)
            setattr(m, name, new_m)
        lowering_class_cpu(sub_m, target_m, new_class, config, tpp, woq)


def convert_class(m, target_m, new_class, config, distributed=False):
    for name, sub_m in m.named_children():
        if isinstance(sub_m, target_m):
            new_m = new_class(sub_m, config, distributed)
            setattr(m, name, new_m)
        convert_class(sub_m, target_m, new_class, config, distributed)
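
Aside (not part of the commit): `lowering_class_cpu` and `convert_class` are the same recursive module-swap pattern, differing only in the constructor arguments they forward. A minimal self-contained sketch of that pattern, with a made-up `Wrapped` class standing in for the IPEX reference/CPU classes:

```python
# Illustrative sketch only; `Wrapped` is a hypothetical stand-in.
import torch.nn as nn

class Wrapped(nn.Module):
    def __init__(self, orig, config, distributed=False):
        super().__init__()
        self.inner = orig          # keep the original module and its weights

    def forward(self, x):
        return self.inner(x)

def swap(m, target_m, new_class, config, distributed=False):
    # same shape as convert_class above: replace matches, then recurse
    for name, sub_m in m.named_children():
        if isinstance(sub_m, target_m):
            setattr(m, name, new_class(sub_m, config, distributed))
        swap(sub_m, target_m, new_class, config, distributed)

toy = nn.Sequential(nn.Linear(4, 4), nn.Sequential(nn.ReLU(), nn.Linear(4, 2)))
swap(toy, nn.Linear, Wrapped, config=None)
print(toy)   # both nn.Linear layers are now wrapped in Wrapped
```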

def _set_optimized_model_for_generation(
    model,
    optimized_model,
    first_token_optimized_model=None,
):
    from intel_extension_for_pytorch.transformers.models.reference.models import (
        IPEX_LLM_Model_Return
    )
    if first_token_optimized_model is not None:
        model.trace_graph_first = IPEX_LLM_Model_Return(
            model, first_token_optimized_model
        ).forward

    model.trace_graph = IPEX_LLM_Model_Return(model, optimized_model).forward
    print(
        "ipex.llm.optimize has set the optimized or quantization model for model.generate()"
    )
    return model


def _ipex_optimize_decoder(model, decoder_layer):
    from intel_extension_for_pytorch.transformers.models.reference.modules.decoder import (
        _IPEXDecoderLayerRef
    )
    from intel_extension_for_pytorch.transformers.models.cpu.modules.decoder import (
        _IPEXDecoderLayerCPU
    )
    for supported_mlp_class in [_IPEXDecoderLayerRef]:
        lowering_class_cpu(
            model,
            supported_mlp_class,
            _IPEXDecoderLayerCPU,
            model.config,
            tpp=False,
            woq=False,
        )
    convert_class(
        model,
        decoder_layer,
        _IPEXDecoderLayerRef,
        model.config,
        distributed=True,
    )


def _ipex_optimize_attention(model, attention_layer):
    from intel_extension_for_pytorch.transformers.models.reference.modules.attentions import (
        _IPEXAttentionRef
    )
    from intel_extension_for_pytorch.transformers.models.cpu.modules.attentions import (
        _IPEXAttentionCPU
    )
    for supported_mha_class in [_IPEXAttentionRef]:
        lowering_class_cpu(
            model,
            supported_mha_class,
            _IPEXAttentionCPU,
            model.config,
            tpp=False,
            woq=False,
        )
    convert_class(
        model,
        attention_layer,
        _IPEXAttentionRef,
        model.config,
        distributed=True,
    )


def _ipex_jit(model):
    from intel_extension_for_pytorch.transformers.optimize import get_dummy_input

    sample_inputs = (
        get_dummy_input(model, return_dict=True)
    )
    with torch.no_grad(), torch.cpu.amp.autocast(
        enabled=True
    ):
        trace_model = torch.jit.trace(
            model,
            example_kwarg_inputs=sample_inputs,
            strict=False,
            check_trace=False,
        )
        trace_model = torch.jit.freeze(trace_model)
        model = _set_optimized_model_for_generation(
            model, optimized_model=trace_model
        )

    return model.eval()
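
Aside (not part of the commit): the same trace-and-freeze flow can be exercised on a toy module. The module, shapes, and inputs below are made up for illustration; `example_kwarg_inputs` requires PyTorch 2.x.

```python
# Minimal sketch of the trace + freeze pattern used in _ipex_jit.
import torch

class Toy(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(8, 8)

    def forward(self, x, mask):
        return self.linear(x) * mask

toy = Toy().eval()
sample = {"x": torch.randn(1, 8), "mask": torch.ones(1, 8)}

with torch.no_grad(), torch.cpu.amp.autocast(enabled=True):
    traced = torch.jit.trace(
        toy, example_kwarg_inputs=sample, strict=False, check_trace=False
    )
    traced = torch.jit.freeze(traced)   # fold weights/constants into the graph

print(traced(torch.randn(1, 8), torch.ones(1, 8)).shape)  # torch.Size([1, 8])
```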

@staticmethod
def _make_causal_mask(
    input_ids_shape: torch.Size,
    dtype: torch.dtype,
    device: torch.device,
    past_key_values_length: int = 0,
    sliding_window: Optional[int] = None,
):
    """
    Make causal mask used for bi-directional self-attention.
    """
    bsz, tgt_len = input_ids_shape
    mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
    mask_cond = torch.arange(mask.size(-1), device=device)
    mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)

    import os
    _enable_ipex = os.getenv("BIGDL_OPT_IPEX")
    _enable_ipex = (_enable_ipex is not None) and (_enable_ipex.lower() == "true")
    if _enable_ipex or past_key_values_length > 0:
        mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)  # noqa

    # add lower triangular sliding window mask if necessary
    if sliding_window is not None:
        diagonal = past_key_values_length - sliding_window + 1
        context_mask = 1 - torch.triu(torch.ones_like(mask, dtype=torch.int), diagonal=diagonal)
        mask.masked_fill_(context_mask.bool(), torch.finfo(dtype).min)

    return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
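
For intuition (an illustrative check, not part of the diff), the snippet below reproduces the concatenation step for a 3-token query with 2 cached tokens; the result is a `(tgt_len, tgt_len + past_key_values_length)` mask whose disallowed positions hold `torch.finfo(dtype).min`.

```python
# Reproduces the mask layout for tgt_len=3, past_key_values_length=2.
import torch

dtype = torch.float32
tgt_len, past = 3, 2
mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min)
cond = torch.arange(tgt_len)
mask.masked_fill_(cond < (cond + 1).view(tgt_len, 1), 0)          # causal part
mask = torch.cat([torch.zeros(tgt_len, past, dtype=dtype), mask], dim=-1)

print(mask.shape)  # torch.Size([3, 5])
# Row i attends to the 2 cached positions plus query positions <= i;
# the remaining entries are torch.finfo(float32).min (effectively -inf).
```
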
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask


def _llama_model_forward_4_35(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions  # noqa
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states  # noqa
    )
    use_cache = use_cache if use_cache is not None else self.config.use_cache
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    # retrieve input_ids and inputs_embeds
    if input_ids is not None and inputs_embeds is not None:
        raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")  # noqa
    elif input_ids is not None:
        batch_size, seq_length = input_ids.shape[:2]
    elif inputs_embeds is not None:
        batch_size, seq_length = inputs_embeds.shape[:2]
    else:
        raise ValueError("You have to specify either input_ids or inputs_embeds")  # noqa

    past_key_values_length = 0
    if past_key_values is not None:
        past_key_values_length = past_key_values[0][0].shape[2]

    if position_ids is None:
        device = input_ids.device if input_ids is not None else inputs_embeds.device
        position_ids = torch.arange(
            past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device  # noqa
        )
        position_ids = position_ids.unsqueeze(0)

    if inputs_embeds is None:
        inputs_embeds = self.embed_tokens(input_ids)

    if getattr(self.config, "_flash_attn_2_enabled", False):
        # 2d mask is passed through the layers
        attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None  # noqa
    else:
        # 4d mask is passed through the layers
        attention_mask = _prepare_4d_causal_attention_mask(
            attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
        )

    # embed positions
    hidden_states = inputs_embeds

    if self.gradient_checkpointing and self.training:
        if use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."  # noqa
            )
            use_cache = False

    # decoder layers
    all_hidden_states = () if output_hidden_states else None
    all_self_attns = () if output_attentions else None
    next_decoder_cache = () if use_cache else None

    for idx, decoder_layer in enumerate(self.layers):
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        past_key_value = past_key_values[idx] if past_key_values is not None else None

        if self.gradient_checkpointing and self.training:
            layer_outputs = self._gradient_checkpointing_func(
                decoder_layer.__call__,
                hidden_states,
                attention_mask,
                position_ids,
                past_key_value,
                output_attentions,
                use_cache,
            )
        else:
            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
            )

        hidden_states = layer_outputs[0]

        if use_cache:
            next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)

        if output_attentions:
            all_self_attns += (layer_outputs[1],)

    hidden_states = self.norm(hidden_states)

    # add hidden states from the last decoder layer
    if output_hidden_states:
        all_hidden_states += (hidden_states,)

    next_cache = next_decoder_cache if use_cache else None
    if not return_dict:
        return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)  # noqa
    return BaseModelOutputWithPast(
        last_hidden_state=hidden_states,
        past_key_values=next_cache,
        hidden_states=all_hidden_states,
        attentions=all_self_attns,
    )