Optimize transformer int4 loading memory (#8400)

* Optimize transformer int4 loading memory

* Move the float32 cast into the convert step

* Set low_cpu_mem_usage by default when loading in 4-bit
Author: Yang Wang, 2023-06-30 20:12:12 -07:00 (committed by GitHub)
Parent: 2da21163f8
Commit: 449aea7ffc
3 changed files with 8 additions and 3 deletions


@@ -71,6 +71,8 @@ def _replace_with_int4_linear(model, modules_to_not_convert=None, current_key_na
                 # Force requires grad to False to avoid unexpected errors
                 model._modules[name].requires_grad_(False)
+                module.weight = None
         # Remove the last key for recursion
         if len(list(module.children())) > 0:
             _, has_been_replaced = _replace_with_int4_linear(
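
Note: the added module.weight = None is what releases each original fp32 Parameter as soon as its int4 replacement is in place, instead of keeping a full extra copy of every layer alive until conversion finishes. A minimal standalone sketch of the effect (the layer size and the int8 cast standing in for real int4 packing are illustrative, not the repo's code):

    import gc
    import torch
    import torch.nn as nn

    lin = nn.Linear(4096, 4096, bias=False)       # ~64 MB of fp32 weight
    packed = lin.weight.detach().to(torch.int8)   # stand-in for int4 packing
    lin.weight = None   # nn.Module accepts re-registering a parameter as None
    gc.collect()        # the fp32 storage is now unreferenced and can be freed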
@@ -93,4 +95,6 @@ def ggml_convert_int4(model):
             "instead of Linear layers. Please double check your model architecture, or submit "
             "an issue on github if you think this is a bug."
         )
+    else:
+        model.to(torch.float32)
     return model
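
This is the "move cast to convert" part of the commit: the fp32 cast now runs after the Linear layers have been replaced (and their weights dropped), so it only upcasts the small modules that remain, such as embeddings and layer norms. A hedged toy illustration of why the ordering matters (the two-layer model is a stand-in for a real fp16 checkpoint):

    import torch
    import torch.nn as nn

    model = nn.Sequential(nn.Embedding(1000, 64), nn.Linear(64, 64)).half()
    model[1].weight = None     # what _replace_with_int4_linear now does per layer
    model.to(torch.float32)    # the cast ggml_convert_int4 performs in its else branch
    # No fp32 copy of the Linear weight is ever materialized; casting the whole
    # fp16 model to fp32 *before* conversion would first build full-precision
    # copies of every weight, including the ones about to be quantized away.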


@@ -197,4 +197,4 @@ class LinearInt4(nn.Linear):
         if self.bias is not None:
             result += self.bias
-        return result
+        return result.to(x.dtype)
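
Since the module is kept in float32 by ggml_convert_int4, LinearInt4 appears to compute in fp32, so the raw result no longer matches an fp16 input's dtype; the added .to(x.dtype) restores the caller's precision at the boundary. A small sketch of that contract, with a plain fp32 weight standing in for the dequantized int4 one:

    import torch
    import torch.nn.functional as F

    def int4_linear_sketch(x: torch.Tensor, w_fp32: torch.Tensor) -> torch.Tensor:
        result = F.linear(x.to(torch.float32), w_fp32)  # compute in fp32
        return result.to(x.dtype)                       # return in the input dtype

    x = torch.randn(2, 64, dtype=torch.float16)
    assert int4_linear_sketch(x, torch.randn(64, 64)).dtype == torch.float16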


@@ -27,13 +27,14 @@ class _BaseAutoModelClass:
                         *args,
                         **kwargs):
         load_in_4bit = kwargs.pop("load_in_4bit", False)
+        if load_in_4bit:
+            kwargs["low_cpu_mem_usage"] = True
         model = cls.HF_Model.from_pretrained(*args, **kwargs)
         if load_in_4bit:
             from .convert import ggml_convert_int4
-            model = model.to("cpu", torch.float32)
+            model = model.to("cpu")
             model = ggml_convert_int4(model)
         return model
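
With this change, passing load_in_4bit=True also injects low_cpu_mem_usage=True, so Hugging Face Transformers loads checkpoint weights incrementally instead of materializing a second full copy in memory, and the fp32 cast before conversion is gone. A usage sketch, assuming the wrapper above is exposed as an AutoModelForCausalLM class (both the import path and the checkpoint path are placeholders here):

    from bigdl.llm.transformers import AutoModelForCausalLM  # assumed import path

    model = AutoModelForCausalLM.from_pretrained(
        "path/to/llama-7b",  # placeholder checkpoint
        load_in_4bit=True,   # now implies low_cpu_mem_usage=True
    )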