diff --git a/python/llm/src/bigdl/llm/transformers/convert.py b/python/llm/src/bigdl/llm/transformers/convert.py
index ca26f416..41b66bff 100644
--- a/python/llm/src/bigdl/llm/transformers/convert.py
+++ b/python/llm/src/bigdl/llm/transformers/convert.py
@@ -48,6 +48,7 @@ from typing import Union
 import numpy as np
 import os
 from bigdl.llm.utils.common import invalidInputError
+from typing import List, Optional, Tuple, Union
 
 
 def is_auto_gptq_available():
@@ -528,6 +529,13 @@ def ggml_convert_low_bit(model, qtype, optimize_model=True,
         # Do nothing here for weights are empty.
         pass
 
+    _enable_ipex = os.getenv("BIGDL_OPT_IPEX")
+    _enable_ipex = (_enable_ipex is not None) and (_enable_ipex.lower() == "true")
+    _enable_ipex = _enable_ipex and (qtype == ggml_tensor_qtype["bf16"])
+    logger.info(f"BIGDL_OPT_IPEX: {_enable_ipex}")
+    if _enable_ipex:
+        model = _optimize_ipex(model)
+        return model
     if optimize_model:
         model = _optimize_post(model, lightweight_bmm)
     return model
@@ -560,6 +568,28 @@ def replace_func(m, target_m, func_name, new_func):
         replace_func(sub_m, target_m, func_name, new_func)
 
 
+def _optimize_ipex(model):
+    import intel_extension_for_pytorch as ipex
+    from intel_extension_for_pytorch.transformers.optimize import model_convert_reference
+    from intel_extension_for_pytorch.transformers.models.reference.models import output_hook
+    from transformers.modeling_attn_mask_utils import AttentionMaskConverter
+    from bigdl.llm.transformers.convert_ipex import (
+        _ipex_optimize_attention, _ipex_optimize_decoder, _ipex_jit, _make_causal_mask,
+        _llama_model_forward_4_35
+    )
+
+    AttentionMaskConverter._make_causal_mask = _make_causal_mask
+    convert_forward(model, transformers.models.llama.modeling_llama.LlamaModel, _llama_model_forward_4_35)  # noqa
+    model = model_convert_reference(model)
+
+    _ipex_optimize_attention(model, transformers.models.llama.modeling_llama.LlamaAttention)
+    _ipex_optimize_decoder(model, transformers.models.llama.modeling_llama.LlamaDecoderLayer)
+
+    model.register_forward_hook(output_hook, with_kwargs=True)
+
+    return _ipex_jit(model)
+
+
 def _optimize_post(model, lightweight_bmm=False):
     from packaging import version
     from bigdl.llm.transformers.models.llama import llama_attention_forward_4_31
diff --git a/python/llm/src/bigdl/llm/transformers/convert_ipex.py b/python/llm/src/bigdl/llm/transformers/convert_ipex.py
new file mode 100644
index 00000000..b185aba0
--- /dev/null
+++ b/python/llm/src/bigdl/llm/transformers/convert_ipex.py
@@ -0,0 +1,302 @@
+#
+# Copyright 2016 The BigDL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+
+# Some parts of this file are adapted from
+# https://github.com/huggingface/transformers/blob/v4.35.2/src/transformers/models/llama/modeling_llama.py  # noqa
+# and https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_utils.py
+# which is licensed under Apache License 2.0:
+#
+# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+from bigdl.llm.utils.common import invalidInputError
+from typing import List, Optional, Tuple, Union
+
+
+def lowering_class_cpu(m, target_m, new_class, config, tpp=False, woq=False):
+    for name, sub_m in m.named_children():
+        if isinstance(sub_m, target_m):
+            new_m = new_class(sub_m, config, tpp, woq)
+            setattr(m, name, new_m)
+        lowering_class_cpu(sub_m, target_m, new_class, config, tpp, woq)
+
+
+def convert_class(m, target_m, new_class, config, distributed=False):
+    for name, sub_m in m.named_children():
+        if isinstance(sub_m, target_m):
+            new_m = new_class(sub_m, config, distributed)
+            setattr(m, name, new_m)
+        convert_class(sub_m, target_m, new_class, config, distributed)
+
+
+def _set_optimized_model_for_generation(
+    model,
+    optimized_model,
+    first_token_optimized_model=None,
+):
+    from intel_extension_for_pytorch.transformers.models.reference.models import (
+        IPEX_LLM_Model_Return
+    )
+    if first_token_optimized_model is not None:
+        model.trace_graph_first = IPEX_LLM_Model_Return(
+            model, first_token_optimized_model
+        ).forward
+
+    model.trace_graph = IPEX_LLM_Model_Return(model, optimized_model).forward
+    print(
+        "ipex.llm.optimize has set the optimized or quantization model for model.generate()"
+    )
+    return model
+
+
+def _ipex_optimize_decoder(model, decoder_layer):
+    from intel_extension_for_pytorch.transformers.models.reference.modules.decoder import (
+        _IPEXDecoderLayerRef
+    )
+    from intel_extension_for_pytorch.transformers.models.cpu.modules.decoder import (
+        _IPEXDecoderLayerCPU
+    )
+    for supported_mlp_class in [_IPEXDecoderLayerRef]:
+        lowering_class_cpu(
+            model,
+            supported_mlp_class,
+            _IPEXDecoderLayerCPU,
+            model.config,
+            tpp=False,
+            woq=False,
+        )
+    convert_class(
+        model,
+        decoder_layer,
+        _IPEXDecoderLayerRef,
+        model.config,
+        distributed=True,
+    )
+
+
+def _ipex_optimize_attention(model, attention_layer):
+    from intel_extension_for_pytorch.transformers.models.reference.modules.attentions import (
+        _IPEXAttentionRef
+    )
+    from intel_extension_for_pytorch.transformers.models.cpu.modules.attentions import (
+        _IPEXAttentionCPU
+    )
+    for supported_mha_class in [_IPEXAttentionRef]:
+        lowering_class_cpu(
+            model,
+            supported_mha_class,
+            _IPEXAttentionCPU,
+            model.config,
+            tpp=False,
+            woq=False,
+        )
+    convert_class(
+        model,
+        attention_layer,
+        _IPEXAttentionRef,
+        model.config,
+        distributed=True,
+    )
+
+
+def _ipex_jit(model):
+    from intel_extension_for_pytorch.transformers.optimize import get_dummy_input
+    sample_inputs = (
+        get_dummy_input(model, return_dict=True)
+    )
+    with torch.no_grad(), torch.cpu.amp.autocast(
+        enabled=True
+    ):
+        trace_model = torch.jit.trace(
+            model,
+            example_kwarg_inputs=sample_inputs,
+            strict=False,
+            check_trace=False,
+        )
+        trace_model = torch.jit.freeze(trace_model)
+        model = _set_optimized_model_for_generation(
+            model, optimized_model=trace_model
+        )
+
+    return model.eval()
+
+
+@staticmethod
+def _make_causal_mask(
+    input_ids_shape: torch.Size,
+    dtype: torch.dtype,
+    device: torch.device,
+    past_key_values_length: int = 0,
+    sliding_window: Optional[int] = None,
+):
+    """
+    Make causal mask used for bi-directional self-attention.
+    """
+    bsz, tgt_len = input_ids_shape
+    mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
+    mask_cond = torch.arange(mask.size(-1), device=device)
+    mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
+
+    import os
+    _enable_ipex = os.getenv("BIGDL_OPT_IPEX")
+    _enable_ipex = (_enable_ipex is not None) and (_enable_ipex.lower() == "true")
+    if _enable_ipex or past_key_values_length > 0:
+        mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)  # noqa
+
+    # add lower triangular sliding window mask if necessary
+    if sliding_window is not None:
+        diagonal = past_key_values_length - sliding_window + 1
+
+        context_mask = 1 - torch.triu(torch.ones_like(mask, dtype=torch.int), diagonal=diagonal)
+        mask.masked_fill_(context_mask.bool(), torch.finfo(dtype).min)
+
+    return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
+
+
+from transformers.modeling_outputs import BaseModelOutputWithPast
+from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
+
+
+def _llama_model_forward_4_35(
+    self,
+    input_ids: torch.LongTensor = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_values: Optional[List[torch.FloatTensor]] = None,
+    inputs_embeds: Optional[torch.FloatTensor] = None,
+    use_cache: Optional[bool] = None,
+    output_attentions: Optional[bool] = None,
+    output_hidden_states: Optional[bool] = None,
+    return_dict: Optional[bool] = None,
+) -> Union[Tuple, BaseModelOutputWithPast]:
+    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions  # noqa
+    output_hidden_states = (
+        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states  # noqa
+    )
+    use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+    # retrieve input_ids and inputs_embeds
+    if input_ids is not None and inputs_embeds is not None:
+        raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")  # noqa
+    elif input_ids is not None:
+        batch_size, seq_length = input_ids.shape[:2]
+    elif inputs_embeds is not None:
+        batch_size, seq_length = inputs_embeds.shape[:2]
+    else:
+        raise ValueError("You have to specify either input_ids or inputs_embeds")  # noqa
+
+    past_key_values_length = 0
+    if past_key_values is not None:
+        past_key_values_length = past_key_values[0][0].shape[2]
+
+    if position_ids is None:
+        device = input_ids.device if input_ids is not None else inputs_embeds.device
+        position_ids = torch.arange(
+            past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device  # noqa
+        )
+        position_ids = position_ids.unsqueeze(0)
+
+    if inputs_embeds is None:
+        inputs_embeds = self.embed_tokens(input_ids)
+
+    if getattr(self.config, "_flash_attn_2_enabled", False):
+        # 2d mask is passed through the layers
+        attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None  # noqa
+    else:
+        # 4d mask is passed through the layers
+        attention_mask = _prepare_4d_causal_attention_mask(
+            attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
+        )
+
+    # embed positions
+    hidden_states = inputs_embeds
+
+    if self.gradient_checkpointing and self.training:
+        if use_cache:
+            logger.warning_once(
+                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."  # noqa
+            )
+            use_cache = False
+
+    # decoder layers
+    all_hidden_states = () if output_hidden_states else None
+    all_self_attns = () if output_attentions else None
+    next_decoder_cache = () if use_cache else None
+
+    for idx, decoder_layer in enumerate(self.layers):
+        if output_hidden_states:
+            all_hidden_states += (hidden_states,)
+
+        past_key_value = past_key_values[idx] if past_key_values is not None else None
+
+        if self.gradient_checkpointing and self.training:
+            layer_outputs = self._gradient_checkpointing_func(
+                decoder_layer.__call__,
+                hidden_states,
+                attention_mask,
+                position_ids,
+                past_key_value,
+                output_attentions,
+                use_cache,
+            )
+        else:
+            layer_outputs = decoder_layer(
+                hidden_states,
+                attention_mask=attention_mask,
+                position_ids=position_ids,
+                past_key_value=past_key_value,
+                output_attentions=output_attentions,
+                use_cache=use_cache,
+            )
+
+        hidden_states = layer_outputs[0]
+
+        if use_cache:
+            next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
+
+        if output_attentions:
+            all_self_attns += (layer_outputs[1],)
+
+    hidden_states = self.norm(hidden_states)
+
+    # add hidden states from the last decoder layer
+    if output_hidden_states:
+        all_hidden_states += (hidden_states,)
+
+    next_cache = next_decoder_cache if use_cache else None
+    if not return_dict:
+        return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)  # noqa
+    return BaseModelOutputWithPast(
+        last_hidden_state=hidden_states,
+        past_key_values=next_cache,
+        hidden_states=all_hidden_states,
+        attentions=all_self_attns,
+    )
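Usage note: the new path is opt-in. ggml_convert_low_bit only routes through _optimize_ipex when the BIGDL_OPT_IPEX environment variable is "true" and the model is converted to bf16; every other qtype still takes the existing _optimize_post path. The snippet below is not part of the patch; it is a minimal sketch of how the path could be exercised, assuming bigdl.llm.transformers.AutoModelForCausalLM accepts load_in_low_bit="bf16" (mapping to ggml_tensor_qtype["bf16"]) and using an illustrative Llama-2 checkpoint name.

import os

# BIGDL_OPT_IPEX must be set before conversion: it is checked in
# ggml_convert_low_bit and again at runtime in the patched _make_causal_mask.
os.environ["BIGDL_OPT_IPEX"] = "true"

import torch
from transformers import AutoTokenizer
from bigdl.llm.transformers import AutoModelForCausalLM

model_path = "meta-llama/Llama-2-7b-chat-hf"  # illustrative checkpoint path

# load_in_low_bit="bf16" makes qtype == ggml_tensor_qtype["bf16"], so the IPEX
# reference-module conversion and TorchScript tracing in _optimize_ipex apply.
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    load_in_low_bit="bf16",
    use_cache=True,
)
tokenizer = AutoTokenizer.from_pretrained(model_path)

with torch.inference_mode():
    input_ids = tokenizer("What is AI?", return_tensors="pt").input_ids
    # generate() dispatches to the frozen trace_graph set by
    # _set_optimized_model_for_generation.
    output = model.generate(input_ids, max_new_tokens=32)
    print(tokenizer.decode(output[0], skip_special_tokens=True))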