Replace torch.bmm with safe_bmm (#9519)

* replace bmm with safe one
* rename args and add a deprecation warning

This commit is contained in:
parent b3178d449f
commit 42b7a16bc5

4 changed files with 95 additions and 15 deletions
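In short: `optimize_model` and the `from_pretrained` path now take `cpu_embedding` (replacing the deprecated `replace_embedding`) plus a new `lightweight_bmm` flag that swaps `torch.bmm` for a safer implementation on Windows GPU. A minimal sketch of the new call, assuming a PyTorch model `model` already loaded on CPU:

from bigdl.llm import optimize_model

# Both flags default to False; names as introduced by this commit.
model = optimize_model(model, low_bit='sym_int4',
                       cpu_embedding=True,    # formerly replace_embedding
                       lightweight_bmm=True)  # patch torch.bmm via SafeBMM (Windows GPU)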
@@ -26,6 +26,7 @@ from bigdl.llm.ggml.quantize import ggml_tensor_qtype
 from bigdl.llm.utils.common import invalidInputError
 from bigdl.llm.transformers.utils import extract_local_archive_file, get_local_shard_files
 import transformers
+import warnings
 from transformers import PreTrainedModel
 from .utils.common import MuteHFLogger
 from .utils.lazy_load_torch import LazyLoadTensors
@@ -193,7 +194,7 @@ def load_low_bit(model, model_path):
 def optimize_model(model, low_bit='sym_int4', optimize_llm=True, modules_to_not_convert=None,
-                   replace_embedding=False):
+                   cpu_embedding=False, lightweight_bmm=False, **kwargs):
     """
     A method to optimize any pytorch model.

@@ -203,7 +204,9 @@ def optimize_model(model, low_bit='sym_int4', optimize_llm=True, modules_to_not_
     :param optimize_llm: Whether to further optimize llm model.
     :param modules_to_not_convert: list of str value, modules (nn.Module) that are skipped
            when conducting model optimizations. Default to be None.
-    :param replace_embedding: Whether to replace the Embedding layer, may need to set it
+    :param cpu_embedding: Whether to replace the Embedding layer, may need to set it
+           to `True` when running BigDL-LLM on GPU on Windows. Default to be `False`.
+    :param lightweight_bmm: Whether to replace the torch.bmm ops, may need to set it
            to `True` when running BigDL-LLM on GPU on Windows. Default to be `False`.

     :return: The optimized model.
@@ -226,12 +229,17 @@ def optimize_model(model, low_bit='sym_int4', optimize_llm=True, modules_to_not_
     invalidInputError(model.device.type == 'cpu',
                       "Expect model on device `cpu`, "
                       f"but got device type {model.device.type}")
+    if kwargs.pop("replace_embedding", False):
+        warnings.warn("replace_embedding is deprecated and will be removed in a future version,"
+                      " please use cpu_embedding instead.", FutureWarning)
+        cpu_embedding = True
     qtype = ggml_tensor_qtype[low_bit]
     model = ggml_convert_low_bit(model,
                                  qtype=qtype,
                                  optimize_model=optimize_llm,
                                  modules_to_not_convert=modules_to_not_convert,
-                                 replace_embedding=replace_embedding)
+                                 cpu_embedding=cpu_embedding,
+                                 lightweight_bmm=lightweight_bmm)
     # add save_low_bit to pretrained model dynamically
     import types
     model._bigdl_config = dict()
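For reference, a sketch of how the deprecation shim above behaves from the caller's side (the model `m` is hypothetical; the old keyword still works but warns):

import warnings
from bigdl.llm import optimize_model

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    m = optimize_model(m, replace_embedding=True)  # deprecated spelling
# The shim pops replace_embedding from **kwargs, emits a FutureWarning,
# and proceeds as if cpu_embedding=True had been passed.
assert any(issubclass(w.category, FutureWarning) for w in caught)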
python/llm/src/bigdl/llm/transformers/bmm.py  (new file, 45 lines)
@@ -0,0 +1,45 @@
+#
+# Copyright 2016 The BigDL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+
+import torch
+import linear_q4_0
+
+torch_bmm_old_ = torch.bmm
+
+
+def torch_bmm(a, b):
+    if a.device.type == 'cpu':
+        return torch_bmm_old_(a, b)
+
+    batch, A_rows, common = a.size()
+    B_cols = b.size(2)
+    C = torch.empty((batch, A_rows, B_cols), device=a.device)
+    if a.size(1) == 1:
+        torch_bmm_old_(a, b, out=C)
+    else:
+        linear_q4_0.bmm(a.contiguous(), b.contiguous(), C)
+    return C
+
+
+class SafeBMM:
+    def __init__(self):
+        self._old_bmm = torch_bmm_old_
+
+    def __enter__(self):
+        torch.bmm = torch_bmm
+
+    def __exit__(self, *args, **kwargs):
+        torch.bmm = self._old_bmm
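A minimal usage sketch for the new SafeBMM context manager (shapes are illustrative; `linear_q4_0` is BigDL-LLM's native extension, so this assumes an XPU build where that module is importable):

import torch
from bigdl.llm.transformers.bmm import SafeBMM

a = torch.randn(4, 8, 16, device="xpu")   # assumes an Intel XPU device is available
b = torch.randn(4, 16, 32, device="xpu")

with SafeBMM():           # torch.bmm now routes non-CPU inputs through linear_q4_0.bmm
    c = torch.bmm(a, b)
# __exit__ restores the original torch.bmm
assert c.shape == (4, 8, 32)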
@@ -172,7 +172,7 @@ def convert_gptq(module, awq=False):
 def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                                  current_key_name=None, convert_shape_only=False,
-                                 replace_embedding=False):
+                                 cpu_embedding=False):
     from bigdl.llm.transformers.low_bit_linear import LowBitLinear, FP4Params, FP16Linear
     from bigdl.llm.transformers.embedding import LLMEmbedding
     has_been_replaced = False
@@ -265,7 +265,7 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                 model._modules[name].requires_grad_(False)

                 module.weight = None
-        elif replace_embedding and type(module) == nn.Embedding:
+        elif cpu_embedding and type(module) == nn.Embedding:
             # skip user-defined Embedding layer
             if platform.system().lower() == 'windows':
                 model._modules[name] = LLMEmbedding(
@@ -287,7 +287,7 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                 modules_to_not_convert,
                 current_key_name,
                 convert_shape_only,
-                replace_embedding,
+                cpu_embedding,
             )
             has_been_replaced = _flag or has_been_replaced
     return model, has_been_replaced
@@ -321,7 +321,8 @@ def _optimize_pre(model):
 def ggml_convert_low_bit(model, qtype, optimize_model=True,
                          convert_shape_only=False, device="cpu",
-                         modules_to_not_convert=None, replace_embedding=False):
+                         modules_to_not_convert=None, cpu_embedding=False,
+                         lightweight_bmm=False):
     logger.info(f"Converting the current model to "
                 f"{list(ggml_tensor_qtype.keys())[list(ggml_tensor_qtype.values()).index(qtype)]} "
                 f"format......")
@@ -332,7 +333,7 @@ def ggml_convert_low_bit(model, qtype, optimize_model=True,

     model, has_been_replaced = _replace_with_low_bit_linear(
         model, qtype, modules_to_not_convert,
-        None, convert_shape_only, replace_embedding,
+        None, convert_shape_only, cpu_embedding,
     )
     if not has_been_replaced:
         warnings.warn(
@@ -349,7 +350,7 @@ def ggml_convert_low_bit(model, qtype, optimize_model=True,
         pass

     if optimize_model:
-        model = _optimize_post(model)
+        model = _optimize_post(model, lightweight_bmm)
     return model

@@ -361,7 +362,7 @@ def convert_forward(m, target_m, new_forward):
         convert_forward(sub_m, target_m, new_forward)


-def _optimize_post(model):
+def _optimize_post(model, lightweight_bmm=False):
     from packaging import version
     from bigdl.llm.transformers.models.llama import llama_attention_forward_4_31
     from bigdl.llm.transformers.models.llama import llama_rms_norm_forward
@@ -593,4 +594,18 @@ def _optimize_post(model):
         convert_forward(model,
                         module.MistralRMSNorm,
                         llama_rms_norm_forward)
+    elif model.config.model_type == "whisper" and lightweight_bmm:
+        if platform.system().lower() == 'windows':
+            from bigdl.llm.transformers.bmm import SafeBMM
+            modeling_module_name = model.__class__.__module__
+            module = importlib.import_module(modeling_module_name)
+            old_fwd = module.WhisperAttention.forward
+
+            def safe_bmm_fwd(*args, **kwargs):
+                with SafeBMM():
+                    return old_fwd(*args, **kwargs)
+
+            convert_forward(model,
+                            module.WhisperAttention,
+                            safe_bmm_fwd)
     return model
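A note on the pattern above: wrapping the original forward in a closure and re-entering SafeBMM on each call keeps the patch scoped to WhisperAttention instead of patching torch.bmm globally. A self-contained sketch of the same wrapping idea (the Attn class is a stand-in, not the real WhisperAttention):

import torch

class Attn(torch.nn.Module):          # stand-in for module.WhisperAttention
    def forward(self, x):
        return torch.bmm(x, x.transpose(1, 2))

old_fwd = Attn.forward

def safe_fwd(*args, **kwargs):
    # in the real code this body is `with SafeBMM(): return old_fwd(...)`
    return old_fwd(*args, **kwargs)

Attn.forward = safe_fwd               # what convert_forward effectively does
print(Attn()(torch.randn(2, 3, 4)).shape)  # torch.Size([2, 3, 3])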
@@ -98,7 +98,9 @@ class _BaseAutoModelClass:
                Default to be True.
        :param modules_to_not_convert: list of str value, modules (nn.Module) that are skipped when
               conducting model optimizations. Default to be None.
-       :param replace_embedding: Whether to replace the Embedding layer, may need to set it
+       :param cpu_embedding: Whether to replace the Embedding layer, may need to set it
+              to `True` when running BigDL-LLM on GPU on Windows. Default to be `False`.
+       :param lightweight_bmm: Whether to replace the torch.bmm ops, may need to set it
               to `True` when running BigDL-LLM on GPU on Windows. Default to be `False`.
        :return: a model instance
        """
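The same flags flow through the `transformers`-style loaders; a hedged example (the checkpoint id is illustrative, and `lightweight_bmm` only takes effect for whisper models on Windows, per `_optimize_post` above):

from bigdl.llm.transformers import AutoModelForSpeechSeq2Seq

model = AutoModelForSpeechSeq2Seq.from_pretrained("openai/whisper-tiny",
                                                  load_in_4bit=True,
                                                  cpu_embedding=True,
                                                  lightweight_bmm=True)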
@@ -201,7 +203,12 @@ class _BaseAutoModelClass:
        # `from_pretrained`` may pop items out in dict
        # and lead to args missing.
        modules_to_not_convert = kwargs.pop("modules_to_not_convert", None)
-       replace_embedding = kwargs.pop("replace_embedding", False)
+       cpu_embedding = kwargs.pop("cpu_embedding", False)
+       if kwargs.pop("replace_embedding", False):
+           warnings.warn("replace_embedding is deprecated and will be removed in a future version,"
+                         " please use cpu_embedding instead.", FutureWarning)
+           cpu_embedding = True
+       lightweight_bmm = kwargs.pop("lightweight_bmm", False)
        quant_config = kwargs.pop("quantization_config", None)
        _args = copy.deepcopy(args)
        _kwargs = copy.deepcopy(kwargs)
@@ -262,7 +269,7 @@ class _BaseAutoModelClass:
            model = model.to("cpu")
            model = ggml_convert_low_bit(model, qtype, optimize_model,
                                         modules_to_not_convert=modules_to_not_convert,
-                                        replace_embedding=replace_embedding)
+                                        cpu_embedding=cpu_embedding, lightweight_bmm=lightweight_bmm)
            model.config.update({"bigdl_transformers_low_bit": q_k})
            model.config.update({"tie_word_embeddings": False})
@@ -299,7 +306,12 @@ class _BaseAutoModelClass:
        import os

        modules_to_not_convert = kwargs.pop("modules_to_not_convert", None)
-       replace_embedding = kwargs.pop("replace_embedding", False)
+       cpu_embedding = kwargs.pop("cpu_embedding", False)
+       if kwargs.pop("replace_embedding", False):
+           warnings.warn("replace_embedding is deprecated and will be removed in a future version,"
+                         " please use cpu_embedding instead.", FutureWarning)
+           cpu_embedding = True
+       lightweight_bmm = kwargs.pop("lightweight_bmm", False)
        # Autofactory
        trust_remote_code = kwargs.pop("trust_remote_code", None)
        kwargs_orig = copy.deepcopy(kwargs)
@@ -411,7 +423,7 @@ class _BaseAutoModelClass:
        quant_device = "meta" if bigdl_lcmu_enabled else "cpu"
        model = ggml_convert_low_bit(model, qtype, optimize_model, device=quant_device,
                                     modules_to_not_convert=modules_to_not_convert,
-                                    replace_embedding=replace_embedding)
+                                    cpu_embedding=cpu_embedding, lightweight_bmm=lightweight_bmm)

        if is_sharded:
            loaded_state_dict_keys = sharded_metadata["all_checkpoint_keys"]