Optimize transformer int4 loading memory (#8400)
* Optimize transformer int4 loading memory
* move cast to convert
* default setting low_cpu_mem_usage
This commit is contained in:
parent 2da21163f8
commit 449aea7ffc

3 changed files with 8 additions and 3 deletions
@@ -71,6 +71,8 @@ def _replace_with_int4_linear(model, modules_to_not_convert=None, current_key_na
                 # Force requires grad to False to avoid unexpected errors
                 model._modules[name].requires_grad_(False)
 
+                module.weight = None
+
         # Remove the last key for recursion
         if len(list(module.children())) > 0:
             _, has_been_replaced = _replace_with_int4_linear(
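The added `module.weight = None` is the memory optimization in this hunk: once the quantized replacement layer is installed, the module's reference to the original fp32 weight is dropped so it can be garbage-collected right away instead of staying alive until the whole model has been converted. A minimal sketch of that pattern (the `FakeInt4Linear` class and the int8 packing below are stand-ins, not the library's code):

```python
import torch
import torch.nn as nn

class FakeInt4Linear(nn.Module):
    """Stand-in for the real int4 layer; only the memory pattern matters here."""
    def __init__(self, linear: nn.Linear):
        super().__init__()
        # pretend this is the packed low-bit buffer; the real code quantizes here
        self.register_buffer("packed", linear.weight.detach().to(torch.int8))

def replace_linears(model: nn.Module):
    for name, module in model.named_children():
        if isinstance(module, nn.Linear):
            model._modules[name] = FakeInt4Linear(module)
            # drop the fp32 weight right away instead of keeping every original
            # weight alive until the full model has been converted
            module.weight = None
        elif len(list(module.children())) > 0:
            replace_linears(module)
    return model
```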
@@ -93,4 +95,6 @@ def ggml_convert_int4(model):
             "instead of Linear layers. Please double check your model architecture, or submit "
             "an issue on github if you think this is a bug."
         )
+    else:
+        model.to(torch.float32)
     return model
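This `else` branch is where the float32 cast lands after being moved out of `from_pretrained` (the "move cast to convert" bullet). One plausible reading of the new ordering: when the checkpoint weights arrive in half precision, casting after `_replace_with_int4_linear` only promotes the modules that were not replaced, instead of first materialising an fp32 copy of every large Linear weight. A rough, synthetic illustration of the difference (int8 stands in for the real int4 packing):

```python
import torch
import torch.nn as nn

layer = nn.Linear(4096, 4096, bias=False).to(torch.bfloat16)   # ~32 MiB of weights

# old order (sketch): promote to fp32 first, then quantize
fp32_copy = layer.weight.to(torch.float32)   # transient ~64 MiB copy of the big weight
packed_old = fp32_copy.to(torch.int8)

# new order (sketch): quantize straight from the loaded dtype; only the layers
# that were *not* replaced get promoted to fp32 afterwards
packed_new = layer.weight.to(torch.int8)     # no fp32 copy of the big weight
```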
@@ -197,4 +197,4 @@ class LinearInt4(nn.Linear):
         if self.bias is not None:
             result += self.bias
 
-        return result
+        return result.to(x.dtype)
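The one-line change in `LinearInt4.forward` casts the output back to the input's dtype, so whatever precision the surrounding model runs in is preserved across the quantized layer even if the internal computation happens in float32. A minimal illustration of the same pattern (not the actual LinearInt4 kernel):

```python
import torch
from typing import Optional

def forward_like(x: torch.Tensor, weight_fp32: torch.Tensor,
                 bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    # compute in float32 internally, e.g. after dequantizing the packed weights
    result = torch.matmul(x.to(torch.float32), weight_fp32.t())
    if bias is not None:
        result += bias
    # hand back whatever dtype the caller passed in (the change in this hunk)
    return result.to(x.dtype)
```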
@@ -27,13 +27,14 @@ class _BaseAutoModelClass:
                         *args,
                         **kwargs):
         load_in_4bit = kwargs.pop("load_in_4bit", False)
+        if load_in_4bit:
+            kwargs["low_cpu_mem_usage"] = True
         model = cls.HF_Model.from_pretrained(*args, **kwargs)
 
         if load_in_4bit:
             from .convert import ggml_convert_int4
-            model = model.to("cpu", torch.float32)
+            model = model.to("cpu")
             model = ggml_convert_int4(model)
 
         return model
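Taken together, `from_pretrained` now defaults `low_cpu_mem_usage=True` whenever `load_in_4bit=True`, moves the model to CPU without the up-front float32 cast, and leaves that cast to `ggml_convert_int4`. A usage sketch (the import path and model id below are assumptions, not part of this diff):

```python
# hypothetical import path; adjust to wherever _BaseAutoModelClass's subclasses live
from bigdl.llm.transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "some-org/some-llm",   # placeholder model id
    load_in_4bit=True,     # triggers low_cpu_mem_usage=True and the int4 conversion
)
```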