LLM: Enable attempting loading method automatically (#8841)
* enable auto load method
* warning error
* logger info

Co-authored-by: leonardozcm <leonardozcm@gmail.com>
parent bba73ec9d2 · commit 731916c639
2 changed files with 29 additions and 5 deletions
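At a user level, the change means a quantizing from_pretrained() call no longer fails outright when a model does not support the low_cpu_mem_usage loading path; it retries with a regular load instead. A hedged usage sketch, assuming the bigdl.llm.transformers AutoModelForCausalLM wrapper of this era; the model id is only an example:

from bigdl.llm.transformers import AutoModelForCausalLM

# load_in_low_bit accepts sym_int4, asym_int4, sym_int5, asym_int5 or sym_int8
# (see the error string in the second changed file below). If the underlying
# transformers loader raises NotImplementedError for the low_cpu_mem_usage path,
# the wrapper now falls back to a plain load automatically.
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf",
                                             load_in_low_bit="sym_int4")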
@@ -85,12 +85,13 @@ def _replace_with_quant_linear(model, qtype, modules_to_not_convert=None,
 
         # Remove the last key for recursion
         if len(list(module.children())) > 0:
-            _, has_been_replaced = _replace_with_quant_linear(
+            _, _flag = _replace_with_quant_linear(
                 module,
                 qtype,
                 modules_to_not_convert,
                 current_key_name,
             )
+            has_been_replaced = _flag or has_been_replaced
     return model, has_been_replaced
 
 
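The point of this hunk: the recursive call used to overwrite has_been_replaced, so a replacement found in an earlier child module could be forgotten when a later subtree contained none. A minimal standalone sketch of the or-accumulation pattern, with illustrative names that are not the BigDL code itself:

import torch.nn as nn

def contains_linear(model: nn.Module) -> bool:
    # Returns True if this module or any descendant is an nn.Linear.
    found = isinstance(model, nn.Linear)
    for child in model.children():
        # Accumulate with `or` instead of assigning: a False result from one
        # subtree must not erase a True found in an earlier one.
        flag = contains_linear(child)
        found = flag or found
    return found

model = nn.Sequential(nn.Linear(4, 4), nn.ReLU())
print(contains_linear(model))  # True; plain assignment would print False here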
@@ -22,6 +22,11 @@ from .utils import extract_local_archive_file, \
 from bigdl.llm.ggml.quantize import ggml_tensor_qtype
 from bigdl.llm.utils.common import invalidInputError
 import torch
+import copy
+import logging
+
+logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s', level=logging.INFO)
+logger = logging.getLogger(__name__)
 
 
 def save_low_bit(self, *args, **kwargs):
@@ -98,7 +103,19 @@ class _BaseAutoModelClass:
                               f"Unknown load_in_low_bit value: {q_k}, expected:"
                               f" sym_int4, asym_int4, sym_int5, asym_int5 or sym_int8.")
             qtype = ggml_tensor_qtype[q_k]
-            model = cls.HF_Model.from_pretrained(*args, **kwargs)
+            # In case it needs a second try,
+            # `from_pretrained` may pop items out in dict
+            # and lead to args missing.
+            _args = copy.deepcopy(args)
+            _kwargs = copy.deepcopy(kwargs)
+            try:
+                model = cls.HF_Model.from_pretrained(*args, **kwargs)
+            except NotImplementedError:
+                logger.info("Failed to load models with `low_cpu_mem_usage` specified, "
+                            "will fall back to traditional load method with higher memory consumption.")
+                _kwargs["low_cpu_mem_usage"] = False
+                model = cls.HF_Model.from_pretrained(*_args, **_kwargs)
+                model.config.update({"bigdl_lcmu_enabled": False})
             model = model.to("cpu")
             model = ggml_convert_quant(model, qtype, optimize_model)
             model.config.update({"bigdl_transformers_low_bit": q_k})
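This hunk is the heart of the change. The deepcopy is needed because the first from_pretrained attempt can consume (pop) entries from kwargs, so retrying with the same dict would silently drop arguments. A minimal standalone sketch of the pattern; load_with_fallback and loader are illustrative names, not part of bigdl.llm:

import copy
import logging

logger = logging.getLogger(__name__)

def load_with_fallback(loader, *args, **kwargs):
    # Snapshot the arguments before the first attempt: the loader may mutate
    # kwargs (e.g. pop entries), so a retry needs the pristine copies.
    _args = copy.deepcopy(args)
    _kwargs = copy.deepcopy(kwargs)
    try:
        return loader(*args, **kwargs)
    except NotImplementedError:
        logger.info("low_cpu_mem_usage path not supported, retrying with a regular load.")
        _kwargs["low_cpu_mem_usage"] = False
        return loader(*_args, **_kwargs)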
@@ -151,6 +168,7 @@ class _BaseAutoModelClass:
 
         config_dict, _ = PretrainedConfig.get_config_dict(pretrained_model_name_or_path)
         bigdl_transformers_low_bit = config_dict.pop("bigdl_transformers_low_bit", False)
+        bigdl_lcmu_enabled = config_dict.pop("bigdl_lcmu_enabled", True)
 
         invalidInputError(bigdl_transformers_low_bit,
                           "Detect this model is not a low-bit model, Please use from_pretrained"
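The flag written in the earlier hunk via model.config.update({"bigdl_lcmu_enabled": False}) is persisted into the saved model's config.json, which is why it can be popped back out here before the model class is even instantiated. A small hedged sketch of that round trip; the directory path is hypothetical:

from transformers import PretrainedConfig

# Hypothetical directory produced by save_low_bit(); not a real checkpoint.
saved_path = "./llama-2-7b-sym_int4"

# get_config_dict() only reads config.json, so the BigDL-specific keys can be
# inspected (and removed) cheaply before any weights are touched.
config_dict, _ = PretrainedConfig.get_config_dict(saved_path)
low_bit = config_dict.pop("bigdl_transformers_low_bit", False)  # e.g. "sym_int4"
lcmu_enabled = config_dict.pop("bigdl_lcmu_enabled", True)      # False if the fallback was used
print(low_bit, lcmu_enabled)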
@@ -226,10 +244,15 @@ class _BaseAutoModelClass:
         init_contexts = [no_init_weights(_enable=_fast_init)]
         init_contexts.append(init_empty_weights())
 
-        with ContextManagers(init_contexts):
-            model = model_class(config, *model_args, **kwargs)
+        if bigdl_lcmu_enabled:
+            with ContextManagers(init_contexts):
+                model = model_class(config, *model_args, **kwargs)
+        else:
+            model = model_class(config, *model_args, **kwargs)
 
-        model = ggml_convert_quant(model, qtype, optimize_model, device="meta")
+        # Loading args may differ based on their usage
+        quant_device = "meta" if bigdl_lcmu_enabled else "cpu"
+        model = ggml_convert_quant(model, qtype, optimize_model, device=quant_device)
 
         if is_sharded:
             loaded_state_dict_keys = sharded_metadata["all_checkpoint_keys"]
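Why the device passed to ggml_convert_quant now depends on the flag: under init_empty_weights() the freshly constructed model's parameters live on the meta device (shape and dtype only, no storage), while the else branch builds ordinary CPU tensors. A minimal sketch, assuming accelerate is installed and with nn.Linear standing in for a full model:

import torch.nn as nn
from accelerate import init_empty_weights

# Inside the context manager, parameters are allocated on the "meta" device:
# they carry shape/dtype metadata but no storage, so construction is instant.
with init_empty_weights():
    empty_model = nn.Linear(4096, 4096)
print(empty_model.weight.device)   # meta

# Without it, the weights are materialised on CPU straight away, which is why
# the non-lcmu branch above must quantize with device="cpu".
real_model = nn.Linear(4096, 4096)
print(real_model.weight.device)    # cpu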
@@ -260,7 +283,7 @@ class _BaseAutoModelClass:
             pretrained_model_name_or_path,
             sharded_metadata=sharded_metadata,
             _fast_init=_fast_init,
-            low_cpu_mem_usage=True,
+            low_cpu_mem_usage=bigdl_lcmu_enabled,
             offload_folder=offload_folder,
             offload_state_dict=offload_state_dict,
             dtype=torch_dtype,
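Taken together, the hunks make a saved low-bit checkpoint remember which loading strategy actually worked. A hedged end-to-end usage sketch; the model id is only an example, and load_low_bit is assumed to be the matching reload entry point in this bigdl-llm version:

from bigdl.llm.transformers import AutoModelForCausalLM

# The first load quantizes from the original checkpoint. If low_cpu_mem_usage is
# not supported by the model, the new fallback kicks in and bigdl_lcmu_enabled=False
# is recorded in the model config instead of raising.
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf",
                                             load_in_low_bit="sym_int4")

# Persist the quantized weights, then reload; the reload path shown above picks the
# meta/cpu init strategy and the low_cpu_mem_usage value from the saved flag.
model.save_low_bit("./llama-2-7b-sym_int4")
model = AutoModelForCausalLM.load_low_bit("./llama-2-7b-sym_int4")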