[LLM] Use IPEX Optimization for BF16 Model (#9988)

Enable IPEX optimization for BF16 models via the environment variable BIGDL_OPT_IPEX=true.
Xiangyu Tian 2024-01-29 11:28:25 +08:00 committed by GitHub
parent 440cfe18ed
commit f37e4702bc
2 changed files with 332 additions and 0 deletions
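
For context, a minimal usage sketch of the new switch (a sketch only: it assumes the bigdl-llm transformers-style AutoModelForCausalLM API and uses a placeholder model id; the IPEX path is taken only when BIGDL_OPT_IPEX=true and the model is loaded as BF16):

    import os
    os.environ["BIGDL_OPT_IPEX"] = "true"  # must be set before the model is converted

    from bigdl.llm.transformers import AutoModelForCausalLM

    # load_in_low_bit="bf16" makes qtype == ggml_tensor_qtype["bf16"], so
    # ggml_convert_low_bit() below routes through _optimize_ipex() instead of _optimize_post().
    model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Llama-2-7b-chat-hf",  # placeholder model id
        load_in_low_bit="bf16",
    )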

bigdl/llm/transformers/convert.py

@@ -48,6 +48,7 @@ from typing import Union
import numpy as np
import os
from bigdl.llm.utils.common import invalidInputError
from typing import List, Optional, Tuple, Union


def is_auto_gptq_available():
@@ -528,6 +529,13 @@ def ggml_convert_low_bit(model, qtype, optimize_model=True,
        # Do nothing here for weights are empty.
        pass

    _enable_ipex = os.getenv("BIGDL_OPT_IPEX")
    _enable_ipex = (_enable_ipex is not None) and (_enable_ipex.lower() == "true")
    _enable_ipex = _enable_ipex and (qtype == ggml_tensor_qtype["bf16"])
    logger.info(f"BIGDL_OPT_IPEX: {_enable_ipex}")
    if _enable_ipex:
        model = _optimize_ipex(model)
        return model

    if optimize_model:
        model = _optimize_post(model, lightweight_bmm)
    return model
@@ -560,6 +568,28 @@ def replace_func(m, target_m, func_name, new_func):
        replace_func(sub_m, target_m, func_name, new_func)


def _optimize_ipex(model):
    import intel_extension_for_pytorch as ipex
    from intel_extension_for_pytorch.transformers.optimize import model_convert_reference
    from intel_extension_for_pytorch.transformers.models.reference.models import output_hook
    from transformers.modeling_attn_mask_utils import AttentionMaskConverter
    from bigdl.llm.transformers.convert_ipex import (
        _ipex_optimize_attention, _ipex_optimize_decoder, _ipex_jit, _make_causal_mask,
        _llama_model_forward_4_35
    )

    # Patch the causal-mask helper and the LlamaModel forward (transformers 4.35),
    # then let IPEX convert the model to its reference implementation.
    AttentionMaskConverter._make_causal_mask = _make_causal_mask
    convert_forward(model, transformers.models.llama.modeling_llama.LlamaModel, _llama_model_forward_4_35)  # noqa
    model = model_convert_reference(model)

    # Replace Llama attention and decoder layers with their IPEX CPU modules,
    # register the output hook, and JIT-trace the result for generate().
    _ipex_optimize_attention(model, transformers.models.llama.modeling_llama.LlamaAttention)
    _ipex_optimize_decoder(model, transformers.models.llama.modeling_llama.LlamaDecoderLayer)
    model.register_forward_hook(output_hook, with_kwargs=True)

    return _ipex_jit(model)


def _optimize_post(model, lightweight_bmm=False):
    from packaging import version
    from bigdl.llm.transformers.models.llama import llama_attention_forward_4_31

bigdl/llm/transformers/convert_ipex.py (new file)

@@ -0,0 +1,302 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Some parts of this file are adapted from
# https://github.com/huggingface/transformers/blob/v4.35.2/src/transformers/models/llama/modeling_llama.py # noqa
# and https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_utils.py
# which are licensed under Apache License 2.0:
#
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from bigdl.llm.utils.common import invalidInputError
from typing import List, Optional, Tuple, Union


def lowering_class_cpu(m, target_m, new_class, config, tpp=False, woq=False):
    for name, sub_m in m.named_children():
        if isinstance(sub_m, target_m):
            new_m = new_class(sub_m, config, tpp, woq)
            setattr(m, name, new_m)
        lowering_class_cpu(sub_m, target_m, new_class, config, tpp, woq)


def convert_class(m, target_m, new_class, config, distributed=False):
    for name, sub_m in m.named_children():
        if isinstance(sub_m, target_m):
            new_m = new_class(sub_m, config, distributed)
            setattr(m, name, new_m)
        convert_class(sub_m, target_m, new_class, config, distributed)
def _set_optimized_model_for_generation(
    model,
    optimized_model,
    first_token_optimized_model=None,
):
    from intel_extension_for_pytorch.transformers.models.reference.models import (
        IPEX_LLM_Model_Return
    )
    if first_token_optimized_model is not None:
        model.trace_graph_first = IPEX_LLM_Model_Return(
            model, first_token_optimized_model
        ).forward

    model.trace_graph = IPEX_LLM_Model_Return(model, optimized_model).forward
    print(
        "ipex.llm.optimize has set the optimized or quantization model for model.generate()"
    )
    return model
def _ipex_optimize_decoder(model, decoder_layer):
    from intel_extension_for_pytorch.transformers.models.reference.modules.decoder import (
        _IPEXDecoderLayerRef
    )
    from intel_extension_for_pytorch.transformers.models.cpu.modules.decoder import (
        _IPEXDecoderLayerCPU
    )
    for supported_mlp_class in [_IPEXDecoderLayerRef]:
        lowering_class_cpu(
            model,
            supported_mlp_class,
            _IPEXDecoderLayerCPU,
            model.config,
            tpp=False,
            woq=False,
        )
    convert_class(
        model,
        decoder_layer,
        _IPEXDecoderLayerRef,
        model.config,
        distributed=True,
    )
def _ipex_optimize_attention(model, attention_layer):
    from intel_extension_for_pytorch.transformers.models.reference.modules.attentions import (
        _IPEXAttentionRef
    )
    from intel_extension_for_pytorch.transformers.models.cpu.modules.attentions import (
        _IPEXAttentionCPU
    )
    for supported_mha_class in [_IPEXAttentionRef]:
        lowering_class_cpu(
            model,
            supported_mha_class,
            _IPEXAttentionCPU,
            model.config,
            tpp=False,
            woq=False,
        )
    convert_class(
        model,
        attention_layer,
        _IPEXAttentionRef,
        model.config,
        distributed=True,
    )
def _ipex_jit(model):
    from intel_extension_for_pytorch.transformers.optimize import get_dummy_input
    sample_inputs = (
        get_dummy_input(model, return_dict=True)
    )
    # Trace the model on IPEX dummy inputs under BF16 autocast, freeze the graph,
    # and register the traced graph so it is used by model.generate().
    with torch.no_grad(), torch.cpu.amp.autocast(
        enabled=True
    ):
        trace_model = torch.jit.trace(
            model,
            example_kwarg_inputs=sample_inputs,
            strict=False,
            check_trace=False,
        )
        trace_model = torch.jit.freeze(trace_model)
        model = _set_optimized_model_for_generation(
            model, optimized_model=trace_model
        )

    return model.eval()
# Decorated with @staticmethod because this function is assigned to
# AttentionMaskConverter._make_causal_mask in _optimize_ipex().
@staticmethod
def _make_causal_mask(
    input_ids_shape: torch.Size,
    dtype: torch.dtype,
    device: torch.device,
    past_key_values_length: int = 0,
    sliding_window: Optional[int] = None,
):
    """
    Make causal mask used for bi-directional self-attention.
    """
    bsz, tgt_len = input_ids_shape
    mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
    mask_cond = torch.arange(mask.size(-1), device=device)
    mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)

    import os
    _enable_ipex = os.getenv("BIGDL_OPT_IPEX")
    _enable_ipex = (_enable_ipex is not None) and (_enable_ipex.lower() == "true")
    # Unlike the stock transformers implementation, the concat also runs when
    # BIGDL_OPT_IPEX is enabled, even if past_key_values_length == 0.
    if _enable_ipex or past_key_values_length > 0:
        mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)  # noqa

    # add lower triangular sliding window mask if necessary
    if sliding_window is not None:
        diagonal = past_key_values_length - sliding_window + 1
        context_mask = 1 - torch.triu(torch.ones_like(mask, dtype=torch.int), diagonal=diagonal)
        mask.masked_fill_(context_mask.bool(), torch.finfo(dtype).min)

    return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask


def _llama_model_forward_4_35(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions  # noqa
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states  # noqa
    )
    use_cache = use_cache if use_cache is not None else self.config.use_cache
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    # retrieve input_ids and inputs_embeds
    if input_ids is not None and inputs_embeds is not None:
        raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")  # noqa
    elif input_ids is not None:
        batch_size, seq_length = input_ids.shape[:2]
    elif inputs_embeds is not None:
        batch_size, seq_length = inputs_embeds.shape[:2]
    else:
        raise ValueError("You have to specify either input_ids or inputs_embeds")  # noqa

    past_key_values_length = 0
    if past_key_values is not None:
        past_key_values_length = past_key_values[0][0].shape[2]

    if position_ids is None:
        device = input_ids.device if input_ids is not None else inputs_embeds.device
        position_ids = torch.arange(
            past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device  # noqa
        )
        position_ids = position_ids.unsqueeze(0)

    if inputs_embeds is None:
        inputs_embeds = self.embed_tokens(input_ids)

    if getattr(self.config, "_flash_attn_2_enabled", False):
        # 2d mask is passed through the layers
        attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None  # noqa
    else:
        # 4d mask is passed through the layers
        attention_mask = _prepare_4d_causal_attention_mask(
            attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
        )

    # embed positions
    hidden_states = inputs_embeds

    if self.gradient_checkpointing and self.training:
        if use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."  # noqa
            )
            use_cache = False
    # decoder layers
    all_hidden_states = () if output_hidden_states else None
    all_self_attns = () if output_attentions else None
    next_decoder_cache = () if use_cache else None

    for idx, decoder_layer in enumerate(self.layers):
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        past_key_value = past_key_values[idx] if past_key_values is not None else None

        if self.gradient_checkpointing and self.training:
            layer_outputs = self._gradient_checkpointing_func(
                decoder_layer.__call__,
                hidden_states,
                attention_mask,
                position_ids,
                past_key_value,
                output_attentions,
                use_cache,
            )
        else:
            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
            )

        hidden_states = layer_outputs[0]

        if use_cache:
            next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)

        if output_attentions:
            all_self_attns += (layer_outputs[1],)

    hidden_states = self.norm(hidden_states)

    # add hidden states from the last decoder layer
    if output_hidden_states:
        all_hidden_states += (hidden_states,)

    next_cache = next_decoder_cache if use_cache else None
    if not return_dict:
        return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)  # noqa
    return BaseModelOutputWithPast(
        last_hidden_state=hidden_states,
        past_key_values=next_cache,
        hidden_states=all_hidden_states,
        attentions=all_self_attns,
    )