diff --git a/python/llm/src/bigdl/llm/optimize.py b/python/llm/src/bigdl/llm/optimize.py
index 6396f7fa..39b750f7 100644
--- a/python/llm/src/bigdl/llm/optimize.py
+++ b/python/llm/src/bigdl/llm/optimize.py
@@ -26,6 +26,7 @@ from bigdl.llm.ggml.quantize import ggml_tensor_qtype
 from bigdl.llm.utils.common import invalidInputError
 from bigdl.llm.transformers.utils import extract_local_archive_file, get_local_shard_files
 import transformers
+import warnings
 from transformers import PreTrainedModel
 from .utils.common import MuteHFLogger
 from .utils.lazy_load_torch import LazyLoadTensors
@@ -193,7 +194,7 @@ def load_low_bit(model, model_path):
 
 
 def optimize_model(model, low_bit='sym_int4', optimize_llm=True, modules_to_not_convert=None,
-                   replace_embedding=False):
+                   cpu_embedding=False, lightweight_bmm=False, **kwargs):
     """
     A method to optimize any pytorch model.
 
@@ -203,7 +204,9 @@
     :param optimize_llm: Whether to further optimize llm model.
     :param modules_to_not_convert: list of str value, modules (nn.Module) that are skipped when
         conducting model optimizations. Default to be None.
-    :param replace_embedding: Whether to replace the Embedding layer, may need to set it
+    :param cpu_embedding: Whether to replace the Embedding layer, may need to set it
+        to `True` when running BigDL-LLM on GPU on Windows. Default to be `False`.
+    :param lightweight_bmm: Whether to replace the torch.bmm ops, may need to set it
         to `True` when running BigDL-LLM on GPU on Windows. Default to be `False`.
 
     :return: The optimized model.
@@ -226,12 +229,17 @@
     invalidInputError(model.device.type == 'cpu',
                       "Expect model on device `cpu`, "
                       f"but got device type {model.device.type}")
+    if kwargs.pop("replace_embedding", False):
+        warnings.warn("replace_embedding is deprecated and will be removed in a future version,"
+                      " please use cpu_embedding instead.", FutureWarning)
+        cpu_embedding = True
     qtype = ggml_tensor_qtype[low_bit]
     model = ggml_convert_low_bit(model,
                                  qtype=qtype,
                                  optimize_model=optimize_llm,
                                  modules_to_not_convert=modules_to_not_convert,
-                                 replace_embedding=replace_embedding)
+                                 cpu_embedding=cpu_embedding,
+                                 lightweight_bmm=lightweight_bmm)
     # add save_low_bit to pretrained model dynamically
     import types
     model._bigdl_config = dict()
diff --git a/python/llm/src/bigdl/llm/transformers/bmm.py b/python/llm/src/bigdl/llm/transformers/bmm.py
new file mode 100644
index 00000000..56a12eef
--- /dev/null
+++ b/python/llm/src/bigdl/llm/transformers/bmm.py
@@ -0,0 +1,45 @@
+#
+# Copyright 2016 The BigDL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+
+import torch
+import linear_q4_0
+torch_bmm_old_ = torch.bmm
+
+
+def torch_bmm(a, b):
+    if a.device.type == 'cpu':
+        return torch_bmm_old_(a, b)
+
+    batch, A_rows, common = a.size()
+    B_cols = b.size(2)
+    C = torch.empty((batch, A_rows, B_cols), device=a.device)
+    if a.size(1) == 1:
+        torch_bmm_old_(a, b, out=C)
+    else:
+        linear_q4_0.bmm(a.contiguous(), b.contiguous(), C)
+    return C
+
+
+class SafeBMM:
+    def __init__(self):
+        self._old_bmm = torch_bmm_old_
+
+    def __enter__(self):
+        torch.bmm = torch_bmm
+
+    def __exit__(self, *args, **kwargs):
+        torch.bmm = self._old_bmm
diff --git a/python/llm/src/bigdl/llm/transformers/convert.py b/python/llm/src/bigdl/llm/transformers/convert.py
index b12f4db5..e0082783 100644
--- a/python/llm/src/bigdl/llm/transformers/convert.py
+++ b/python/llm/src/bigdl/llm/transformers/convert.py
@@ -172,7 +172,7 @@ def convert_gptq(module, awq=False):
 
 def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                                  current_key_name=None, convert_shape_only=False,
-                                 replace_embedding=False):
+                                 cpu_embedding=False):
     from bigdl.llm.transformers.low_bit_linear import LowBitLinear, FP4Params, FP16Linear
     from bigdl.llm.transformers.embedding import LLMEmbedding
     has_been_replaced = False
@@ -265,7 +265,7 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                     model._modules[name].requires_grad_(False)
 
                     module.weight = None
-            elif replace_embedding and type(module) == nn.Embedding:
+            elif cpu_embedding and type(module) == nn.Embedding:
                 # skip user-defined Embedding layer
                 if platform.system().lower() == 'windows':
                     model._modules[name] = LLMEmbedding(
@@ -287,7 +287,7 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                 modules_to_not_convert,
                 current_key_name,
                 convert_shape_only,
-                replace_embedding,
+                cpu_embedding,
             )
             has_been_replaced = _flag or has_been_replaced
     return model, has_been_replaced
@@ -321,7 +321,8 @@ def _optimize_pre(model):
 
 def ggml_convert_low_bit(model, qtype, optimize_model=True,
                          convert_shape_only=False, device="cpu",
-                         modules_to_not_convert=None, replace_embedding=False):
+                         modules_to_not_convert=None, cpu_embedding=False,
+                         lightweight_bmm=False):
     logger.info(f"Converting the current model to "
                 f"{list(ggml_tensor_qtype.keys())[list(ggml_tensor_qtype.values()).index(qtype)]} "
                 f"format......")
@@ -332,7 +333,7 @@ def ggml_convert_low_bit(model, qtype, optimize_model=True,
 
     model, has_been_replaced = _replace_with_low_bit_linear(
         model, qtype, modules_to_not_convert,
-        None, convert_shape_only, replace_embedding,
+        None, convert_shape_only, cpu_embedding,
     )
     if not has_been_replaced:
         warnings.warn(
@@ -349,7 +350,7 @@ def ggml_convert_low_bit(model, qtype, optimize_model=True,
         pass
 
     if optimize_model:
-        model = _optimize_post(model)
+        model = _optimize_post(model, lightweight_bmm)
     return model
 
 
@@ -361,7 +362,7 @@ def convert_forward(m, target_m, new_forward):
         convert_forward(sub_m, target_m, new_forward)
 
 
-def _optimize_post(model):
+def _optimize_post(model, lightweight_bmm=False):
     from packaging import version
     from bigdl.llm.transformers.models.llama import llama_attention_forward_4_31
     from bigdl.llm.transformers.models.llama import llama_rms_norm_forward
@@ -593,4 +594,18 @@ def _optimize_post(model):
            convert_forward(model,
                            module.MistralRMSNorm,
                            llama_rms_norm_forward)
+    elif model.config.model_type == "whisper" and lightweight_bmm:
+        if platform.system().lower() == 'windows':
+            from bigdl.llm.transformers.bmm import SafeBMM
+            modeling_module_name = model.__class__.__module__
+            module = importlib.import_module(modeling_module_name)
+            old_fwd = module.WhisperAttention.forward
+
+            def safe_bmm_fwd(*args, **kwargs):
+                with SafeBMM():
+                    return old_fwd(*args, **kwargs)
+
+            convert_forward(model,
+                            module.WhisperAttention,
+                            safe_bmm_fwd)
     return model
diff --git a/python/llm/src/bigdl/llm/transformers/model.py b/python/llm/src/bigdl/llm/transformers/model.py
index d0e40444..5f0bdfb2 100644
--- a/python/llm/src/bigdl/llm/transformers/model.py
+++ b/python/llm/src/bigdl/llm/transformers/model.py
@@ -98,7 +98,9 @@ class _BaseAutoModelClass:
             Default to be True.
         :param modules_to_not_convert: list of str value, modules (nn.Module) that are skipped when
             conducting model optimizations. Default to be None.
-        :param replace_embedding: Whether to replace the Embedding layer, may need to set it
+        :param cpu_embedding: Whether to replace the Embedding layer, may need to set it
+            to `True` when running BigDL-LLM on GPU on Windows. Default to be `False`.
+        :param lightweight_bmm: Whether to replace the torch.bmm ops, may need to set it
             to `True` when running BigDL-LLM on GPU on Windows. Default to be `False`.
         :return: a model instance
         """
@@ -201,7 +203,12 @@
         # `from_pretrained`` may pop items out in dict
         # and lead to args missing.
         modules_to_not_convert = kwargs.pop("modules_to_not_convert", None)
-        replace_embedding = kwargs.pop("replace_embedding", False)
+        cpu_embedding = kwargs.pop("cpu_embedding", False)
+        if kwargs.pop("replace_embedding", False):
+            warnings.warn("replace_embedding is deprecated and will be removed in a future version,"
+                          " please use cpu_embedding instead.", FutureWarning)
+            cpu_embedding = True
+        lightweight_bmm = kwargs.pop("lightweight_bmm", False)
         quant_config = kwargs.pop("quantization_config", None)
         _args = copy.deepcopy(args)
         _kwargs = copy.deepcopy(kwargs)
@@ -262,7 +269,7 @@
             model = model.to("cpu")
             model = ggml_convert_low_bit(model, qtype, optimize_model,
                                          modules_to_not_convert=modules_to_not_convert,
-                                         replace_embedding=replace_embedding)
+                                         cpu_embedding=cpu_embedding, lightweight_bmm=lightweight_bmm)
             model.config.update({"bigdl_transformers_low_bit": q_k})
             model.config.update({"tie_word_embeddings": False})
 
@@ -299,7 +306,12 @@
         import os
 
         modules_to_not_convert = kwargs.pop("modules_to_not_convert", None)
-        replace_embedding = kwargs.pop("replace_embedding", False)
+        cpu_embedding = kwargs.pop("cpu_embedding", False)
+        if kwargs.pop("replace_embedding", False):
+            warnings.warn("replace_embedding is deprecated and will be removed in a future version,"
+                          " please use cpu_embedding instead.", FutureWarning)
+            cpu_embedding = True
+        lightweight_bmm = kwargs.pop("lightweight_bmm", False)
         # Autofactory
         trust_remote_code = kwargs.pop("trust_remote_code", None)
         kwargs_orig = copy.deepcopy(kwargs)
@@ -411,7 +423,7 @@
         quant_device = "meta" if bigdl_lcmu_enabled else "cpu"
         model = ggml_convert_low_bit(model, qtype, optimize_model, device=quant_device,
                                      modules_to_not_convert=modules_to_not_convert,
-                                     replace_embedding=replace_embedding)
+                                     cpu_embedding=cpu_embedding, lightweight_bmm=lightweight_bmm)
 
         if is_sharded:
             loaded_state_dict_keys = sharded_metadata["all_checkpoint_keys"]
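
For reference, a minimal usage sketch of the renamed options introduced by this patch. It assumes `optimize_model` is importable from `bigdl.llm` (as in the repo's examples) and uses a placeholder Whisper checkpoint:

```python
import torch
from transformers import AutoModelForSpeechSeq2Seq
from bigdl.llm import optimize_model  # assumed public entry point for optimize.py

# Placeholder checkpoint; optimize_model expects a PyTorch model on CPU.
model = AutoModelForSpeechSeq2Seq.from_pretrained("openai/whisper-tiny")

# New keyword names added by this patch. Both default to False and are mainly
# intended for running BigDL-LLM on GPU on Windows; lightweight_bmm is only
# wired up for Whisper models in convert.py.
model = optimize_model(model, low_bit='sym_int4',
                       cpu_embedding=True, lightweight_bmm=True)

# The old keyword is still accepted, but it now emits a FutureWarning and is
# folded into cpu_embedding:
# model = optimize_model(model, replace_embedding=True)

# The transformers-style loaders in model.py pop the same kwargs, e.g.
# (sketch only, using BigDL-LLM's AutoModel wrappers):
# from bigdl.llm.transformers import AutoModelForSpeechSeq2Seq
# model = AutoModelForSpeechSeq2Seq.from_pretrained("openai/whisper-tiny",
#                                                   load_in_4bit=True,
#                                                   cpu_embedding=True,
#                                                   lightweight_bmm=True)
```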
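The new `SafeBMM` context manager works by temporarily monkey-patching `torch.bmm` around `WhisperAttention.forward`. Below is a self-contained sketch of that pattern with a hypothetical tracing stand-in instead of the XPU kernel (the real `torch_bmm` in bmm.py falls back to the original op when `a.size(1) == 1` and otherwise calls `linear_q4_0.bmm`, which is only available in a BigDL-LLM GPU environment):

```python
import torch

_original_bmm = torch.bmm  # keep a handle to the real implementation


def _traced_bmm(a, b):
    # Hypothetical stand-in for the patched bmm; it only logs and delegates.
    print(f"bmm: {tuple(a.shape)} @ {tuple(b.shape)}")
    return _original_bmm(a, b)


class PatchedBMM:
    """Same shape as SafeBMM: swap torch.bmm on enter, restore it on exit."""

    def __enter__(self):
        torch.bmm = _traced_bmm

    def __exit__(self, *exc):
        torch.bmm = _original_bmm


with PatchedBMM():
    out = torch.bmm(torch.randn(2, 3, 4), torch.randn(2, 4, 5))  # routed through _traced_bmm

assert torch.bmm is _original_bmm  # the patch is scoped to the with-block
```

Wrapping only `WhisperAttention.forward` in the real patch keeps the replacement scoped to attention computation instead of swapping `torch.bmm` globally for the whole process.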
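The backward-compatibility handling for `replace_embedding` follows a pop-warn-remap pattern; a standalone sketch of that behavior (the helper name here is illustrative only):

```python
import warnings


def _pop_deprecated(kwargs, cpu_embedding=False):
    # Mirrors the shim added in optimize.py and model.py: accept the old
    # keyword, emit a FutureWarning, and fold it into the new one.
    if kwargs.pop("replace_embedding", False):
        warnings.warn("replace_embedding is deprecated and will be removed in a future version,"
                      " please use cpu_embedding instead.", FutureWarning)
        cpu_embedding = True
    return cpu_embedding


kwargs = {"replace_embedding": True}
print(_pop_deprecated(kwargs))  # -> True, with a FutureWarning emitted
```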