remove bmm, which is only required in ipex 2.0 (#12630)
parent f17ccfa61a
commit 2d08155513
7 changed files with 13 additions and 87 deletions

@@ -3,7 +3,7 @@
## Optimize Model

You can run any PyTorch model with `optimize_model` through only one-line code change to benefit from IPEX-LLM optimization, regardless of the library or API you are using.

-### `ipex_llm.optimize_model`_`(model, low_bit='sym_int4', optimize_llm=True, modules_to_not_convert=None, cpu_embedding=False, lightweight_bmm=False, **kwargs)`_
+### `ipex_llm.optimize_model`_`(model, low_bit='sym_int4', optimize_llm=True, modules_to_not_convert=None, cpu_embedding=False, **kwargs)`_

A method to optimize any pytorch model.

@@ -19,8 +19,6 @@ A method to optimize any pytorch model.

- **cpu_embedding**: Whether to replace the Embedding layer, may need to set it to `True` when running IPEX-LLM on GPU. Default to be `False`.

-- **lightweight_bmm**: Whether to replace the `torch.bmm` ops, may need to set it to `True` when running IPEX-LLM on GPU on Windows. Default to be `False`.
-
- **Returns**: The optimized model.

- **Example**:
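
The hunk cuts off before the documentation's own example, so here is a minimal stand-in sketch based only on the signature above; the checkpoint name and the Transformers loading code are illustrative assumptions, not part of this commit.

```python
from transformers import AutoModelForCausalLM
from ipex_llm import optimize_model

# Load any PyTorch model the usual way; the checkpoint name here is only
# an example and does not come from the original documentation.
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf")

# One-line change: apply IPEX-LLM low-bit optimization.
# Note: `lightweight_bmm` is no longer part of the signature after this commit.
model = optimize_model(model, low_bit='sym_int4', cpu_embedding=False)
```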

@@ -76,4 +74,4 @@ Load the optimized pytorch model.
from ipex_llm.optimize import load_low_bit
model = whisper.load_model('tiny') # A model instance through traditional loading method
model = load_low_bit(model, saved_dir) # Load the optimized model
```
```
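
For context around the `load_low_bit` snippet above, a plausible save-then-load round trip might look like the following sketch; `saved_dir` is an assumed path, and `model.save_low_bit` is the method that `optimize_model` attaches dynamically (see the `optimize.py` hunk further down).

```python
import whisper
from ipex_llm import optimize_model
from ipex_llm.optimize import load_low_bit

saved_dir = "./whisper-tiny-sym-int4"   # assumed output directory

# First run: optimize and persist the low-bit weights.
model = optimize_model(whisper.load_model('tiny'))
model.save_low_bit(saved_dir)           # save_low_bit is attached by optimize_model

# Later runs: rebuild the plain model, then load the optimized weights.
model = whisper.load_model('tiny')
model = load_low_bit(model, saved_dir)
```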

@@ -29,8 +29,6 @@ Three new arguments are added to extend Hugging Face’s from_pretrained method

- **cpu_embedding**: Whether to replace the Embedding layer, may need to set it to `True` when running IPEX-LLM on GPU. Default to be `False`.

-- **lightweight_bmm**: Whether to replace the torch.bmm ops, may need to set it to `True` when running IPEX-LLM on GPU on Windows. Default to be `False`.
-
- **imatrix**: `str` value, represent filename of importance matrix pretrained on specific datasets for use with the improved quantization methods recently added to llama.cpp.

- **model_hub**: `str` value, options are `'huggingface'` and `'modelscope'`, specify the model hub. Default to be `'huggingface'`.
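
As an illustration of the arguments listed above, a minimal call through the IPEX-LLM `transformers`-style API might look like this; the checkpoint name and the `load_in_low_bit` value are assumptions, and `lightweight_bmm` is simply no longer passed after this commit.

```python
from ipex_llm.transformers import AutoModelForCausalLM

# Low-bit loading with the extended from_pretrained arguments documented above.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-chat-hf",   # illustrative checkpoint
    load_in_low_bit="sym_int4",        # low-bit format
    cpu_embedding=True,                # keep the Embedding layer on CPU (useful on GPU)
    model_hub="huggingface",           # or 'modelscope'
)
```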

@@ -48,7 +46,7 @@ Three new arguments are added to extend Hugging Face’s from_pretrained method
Load gguf model and tokenizer and convert it to bigdl-llm model and huggingface tokenzier

- **Parameters**:

- **fpath**: Path to gguf model file

- **optimize_model**: Whether to further optimize llm model, defaults to `True`
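
The method name itself is not visible in this hunk; assuming it is the GGUF loader exposed as a `from_gguf` class method on the same Auto classes, usage might look like the sketch below. The file path is purely illustrative.

```python
from ipex_llm.transformers import AutoModelForCausalLM

# Assumed entry point: a class method that loads a GGUF file and returns
# both the converted low-bit model and a Hugging Face tokenizer.
model, tokenizer = AutoModelForCausalLM.from_gguf(
    fpath="./llama-2-7b-chat.Q4_0.gguf",  # illustrative GGUF file path
    optimize_model=True,                  # apply further LLM-specific optimizations
)
```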

@@ -64,7 +62,7 @@ Load gguf model and tokenizer and convert it to bigdl-llm model and huggingface
Load a low bit optimized model (including INT4, INT5 and INT8) from a saved ckpt.

- **Parameters**:

- **pretrained_model_name_or_path**: `str` value, Path to load the optimized model ckpt.

- **optimize_model**: `boolean` value, Whether to further optimize the low_bit llm model.
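
A minimal sketch of loading a previously saved low-bit checkpoint with the parameters described above; `saved_dir` is an assumed path produced earlier by `save_low_bit()`.

```python
from ipex_llm.transformers import AutoModelForCausalLM

saved_dir = "./llama-2-7b-sym-int4"  # assumed path created earlier by save_low_bit()

# Restore the low-bit model directly from the saved checkpoint.
model = AutoModelForCausalLM.load_low_bit(saved_dir, optimize_model=True)
```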

@@ -195,7 +195,7 @@ def load_low_bit(model, model_path):


def optimize_model(model, low_bit='sym_int4', optimize_llm=True, modules_to_not_convert=None,
-                   cpu_embedding=False, lightweight_bmm=False, **kwargs):
+                   cpu_embedding=False, **kwargs):
    """
    A method to optimize any pytorch model.


@@ -211,8 +211,6 @@ def optimize_model(model, low_bit='sym_int4', optimize_llm=True, modules_to_not_
           when conducting model optimizations. Default to be ``None``.
    :param cpu_embedding: Whether to replace the Embedding layer, may need to set it
           to ``True`` when running BigDL-LLM on GPU on Windows. Default to be ``False``.
-    :param lightweight_bmm: Whether to replace the torch.bmm ops, may need to set it
-           to ``True`` when running BigDL-LLM on GPU on Windows. Default to be ``False``.

    :return: The optimized model.


@@ -256,8 +254,7 @@ def optimize_model(model, low_bit='sym_int4', optimize_llm=True, modules_to_not_
                                 torch_dtype=torch_dtype,
                                 optimize_model=optimize_llm,
                                 modules_to_not_convert=modules_to_not_convert,
-                                 cpu_embedding=cpu_embedding,
-                                 lightweight_bmm=lightweight_bmm)
+                                 cpu_embedding=cpu_embedding)
    # add save_low_bit to pretrained model dynamically
    import types
    model._bigdl_config = dict()

@@ -1,45 +0,0 @@
-#
-# Copyright 2016 The BigDL Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-
-
-import torch
-import xe_linear
-
-torch_bmm_old_ = torch.bmm
-
-
-def torch_bmm(a, b):
-    if a.device.type == 'cpu':
-        return torch_bmm_old_(a, b)
-
-    batch, A_rows, common = a.size()
-    B_cols = b.size(2)
-    C = torch.empty((batch, A_rows, B_cols), device=a.device)
-    if a.size(1) == 1:
-        torch_bmm_old_(a, b, out=C)
-    else:
-        xe_linear.bmm(a.contiguous(), b.contiguous(), C)
-    return C
-
-
-class SafeBMM:
-    def __init__(self):
-        self._old_bmm = torch_bmm_old_
-
-    def __enter__(self):
-        torch.bmm = torch_bmm
-
-    def __exit__(self, *args, **kwargs):
-        torch.bmm = self._old_bmm
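
For readers wondering what the deleted module did: `SafeBMM` temporarily rebinds `torch.bmm` to an `xe_linear`-backed implementation and restores the original on exit. Below is a minimal sketch of how it was used before this commit (the tensor shapes and the XPU device are illustrative assumptions).

```python
import torch
from ipex_llm.transformers.bmm import SafeBMM   # module deleted by this commit

a = torch.randn(8, 1, 64, device="xpu")
b = torch.randn(8, 64, 128, device="xpu")

# Inside the context, torch.bmm is replaced by the patched version that
# routes eligible XPU batches through xe_linear.bmm; the original torch.bmm
# is restored when the `with` block exits.
with SafeBMM():
    c = torch.bmm(a, b)
```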

@@ -1078,7 +1078,7 @@ def ggml_convert_low_bit(model, qtype, optimize_model=True,
                         convert_shape_only=False, device="cpu",
                         modules_to_not_convert=None,
                         cpu_embedding=False,
-                         lightweight_bmm=False, torch_dtype="auto",
+                         torch_dtype="auto",
                         imatrix_data=None,
                         embedding_qtype=None,
                         mixed_precision=False):

@@ -1146,7 +1146,7 @@ def ggml_convert_low_bit(model, qtype, optimize_model=True,
        pass

    if optimize_model:
-        model = _optimize_post(model, lightweight_bmm)
+        model = _optimize_post(model)

    if hasattr(model, "config") and hasattr(model.config, "model_type") and \
            model.config.model_type == "qwen" and hasattr(model.config, "visual"):

@@ -1247,7 +1247,7 @@ def _optimize_ipex(model, qtype=ggml_tensor_qtype["bf16"]):
    return _ipex_jit(model)


-def _optimize_post(model, lightweight_bmm=False):
+def _optimize_post(model):
    try:
        from diffusers import DiffusionPipeline, StableDiffusionXLPipeline
        if isinstance(model, DiffusionPipeline):

@@ -1627,7 +1627,7 @@ def _optimize_post(model, lightweight_bmm=False):
        vision_embedding._get_pos_embed = MethodType(_get_pos_embed, vision_embedding)
        vision_module = importlib.import_module(vision_model.__class__.__module__)
        convert_forward(vision_model, vision_module.InternAttention, intern_attention_forward)
-        _optimize_post(model.language_model, lightweight_bmm=lightweight_bmm)
+        _optimize_post(model.language_model)
    elif model.config.model_type == "qwen":
        if hasattr(model.config, "visual"):
            # for Qwen-VL-Chat

@@ -1731,7 +1731,7 @@ def _optimize_post(model, lightweight_bmm=False):
                        module.Qwen2MoeSdpaAttention,
                        qwen2_attention_forward)
    elif model.config.model_type == "qwen2_audio":
-        _optimize_post(model.language_model, lightweight_bmm=lightweight_bmm)
+        _optimize_post(model.language_model)
    elif model.config.model_type == "qwen2_vl":
        modeling_module_name = model.__class__.__module__
        module = importlib.import_module(modeling_module_name)

@@ -1875,20 +1875,6 @@ def _optimize_post(model, lightweight_bmm=False):
        modeling_module_name = model.__class__.__module__
        module = importlib.import_module(modeling_module_name)
        convert_forward(model, module.YiRMSNorm, rms_norm_forward)
-    elif model.config.model_type == "whisper" and lightweight_bmm:
-        if platform.system().lower() == 'windows':
-            from ipex_llm.transformers.bmm import SafeBMM
-            modeling_module_name = model.__class__.__module__
-            module = importlib.import_module(modeling_module_name)
-            old_fwd = module.WhisperAttention.forward
-
-            def safe_bmm_fwd(*args, **kwargs):
-                with SafeBMM():
-                    return old_fwd(*args, **kwargs)
-
-            convert_forward(model,
-                            module.WhisperAttention,
-                            safe_bmm_fwd)
    elif model.config.model_type == "rwkv":
        # rwkv v4
        modeling_module_name = model.__class__.__module__

@@ -2081,7 +2067,7 @@ def _optimize_post(model, lightweight_bmm=False):
        elif model.config.hidden_size == 1536 and model.config.vocab_size == 73464:
            # MiniCPM-V ?
            model.llm.config.model_type = "minicpm"
-        _optimize_post(model.llm, lightweight_bmm=lightweight_bmm)
+        _optimize_post(model.llm)
        model.llm.config.model_type = "minicpmv"

        vpm_modeling_module_name = model.vpm.__class__.__module__

@@ -2135,7 +2121,7 @@ def _optimize_post(model, lightweight_bmm=False):
        # llm
        model.llm.config.model_type = "llama"
        model.llm.config.rope_scaling = {"rope_type": "default"}
-        _optimize_post(model.llm, lightweight_bmm=lightweight_bmm)
+        _optimize_post(model.llm)
        model.llm.config.model_type = "megrezo"

    return model

@@ -147,8 +147,6 @@ class _BaseAutoModelClass:
               to ``True`` when running BigDL-LLM on GPU on Windows. Default to be ``False``.
        :param disk_embedding: Whether to put the Embedding layer on disk to save memory.
               Default to be ``False``.
-        :param lightweight_bmm: Whether to replace the torch.bmm ops, may need to set it
-               to ``True`` when running BigDL-LLM on GPU on Windows. Default to be ``False``.
        :param imatrix: str value, represent filename of importance matrix pretrained on
               specific datasets for use with the improved quantization methods recently
               added to llama.cpp.

@@ -441,7 +439,6 @@ class _BaseAutoModelClass:
                          " please use cpu_embedding instead.", FutureWarning)
            cpu_embedding = True
        disk_embedding = kwargs.pop("disk_embedding", False)
-        lightweight_bmm = kwargs.pop("lightweight_bmm", False)
        quant_config = kwargs.pop("quantization_config", None)
        imatrix_data = kwargs.pop("imatrix_data", None)
        embedding_qtype = kwargs.pop("embedding_qtype", None)

@@ -513,7 +510,6 @@ class _BaseAutoModelClass:
            model = ggml_convert_low_bit(model, qtype, optimize_model,
                                         modules_to_not_convert=modules_to_not_convert,
                                         cpu_embedding=cpu_embedding,
-                                         lightweight_bmm=lightweight_bmm,
                                         torch_dtype=kwargs.get("torch_dtype", 'auto'),
                                         imatrix_data=imatrix_data,
                                         embedding_qtype=embedding_qtype,

@@ -576,7 +572,6 @@ class _BaseAutoModelClass:
                          " please use cpu_embedding instead.", FutureWarning)
            cpu_embedding = True
        disk_embedding = kwargs.pop("disk_embedding", False)
-        lightweight_bmm = kwargs.pop("lightweight_bmm", False)
        # Autofactory
        trust_remote_code = kwargs.pop("trust_remote_code", None)
        kwargs_orig = copy.deepcopy(kwargs)

@@ -713,7 +708,6 @@ class _BaseAutoModelClass:
        model = ggml_convert_low_bit(model, qtype, optimize_model, device=quant_device,
                                     modules_to_not_convert=modules_to_not_convert,
                                     cpu_embedding=cpu_embedding,
-                                     lightweight_bmm=lightweight_bmm,
                                     embedding_qtype=embedding_qtype, torch_dtype=torch_dtype)

        if is_sharded:

@@ -116,7 +116,6 @@ class _BaseAutoModelClass:

        # ignore following arguments
        ignore_argument(kwargs, "model_hub")
-        ignore_argument(kwargs, "lightweight_bmm")
        ignore_argument(kwargs, "load_in_4bit")
        ignore_argument(kwargs, "load_in_8bit")
        ignore_argument(kwargs, "imatrix")

@@ -365,7 +364,6 @@ class _BaseAutoModelClass:
    def load_low_bit(cls, pretrained_model_name_or_path: str, *model_args, **kwargs):
        # ignore following arguments
        ignore_argument(kwargs, "model_hub")
-        ignore_argument(kwargs, "lightweight_bmm")
        ignore_argument(kwargs, "cpu_embedding")
        ignore_argument(kwargs, "embedding_qtype")
        ignore_argument(kwargs, "speculative")