Optimize transformer int4 loading memory (#8400)
* Optimize transformer int4 loading memory
* move cast to convert
* default setting low_cpu_mem_usage
This commit is contained in:
parent 2da21163f8
commit 449aea7ffc
3 changed files with 8 additions and 3 deletions
@@ -71,6 +71,8 @@ def _replace_with_int4_linear(model, modules_to_not_convert=None, current_key_name=None):
                     # Force requires grad to False to avoid unexpected errors
                     model._modules[name].requires_grad_(False)
+
+                    module.weight = None
 
         # Remove the last key for recursion
         if len(list(module.children())) > 0:
             _, has_been_replaced = _replace_with_int4_linear(
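Setting `module.weight = None` right after the swap drops the last reference to the original fp32 weight tensor, so it can be garbage-collected immediately instead of lingering until the whole conversion pass finishes. A minimal sketch of the idea, with the int4 layer construction elided (`make_int4_replacement` is a hypothetical stand-in, not part of this commit):

import torch.nn as nn

def _replace_and_release(model: nn.Module, make_int4_replacement):
    # Walk the module tree, swap each nn.Linear for its int4 counterpart,
    # and drop the fp32 weight as soon as the swap is done.
    for name, module in model.named_children():
        if isinstance(module, nn.Linear):
            model._modules[name] = make_int4_replacement(module)
            model._modules[name].requires_grad_(False)
            module.weight = None  # release the original weight tensor now
        else:
            _replace_and_release(module, make_int4_replacement)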
@@ -93,4 +95,6 @@ def ggml_convert_int4(model):
             "instead of Linear layers. Please double check your model architecture, or submit "
             "an issue on github if you think this is a bug."
         )
+    else:
+        model.to(torch.float32)
     return model
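The new `else` branch is the "move cast to convert" half of the change: the float32 upcast now runs only after the Linear layers have been replaced, so it touches just the remaining small modules (embeddings, layer norms, biases) rather than materializing a full fp32 copy of the model first. A rough sketch of the ordering, assuming `quantize_linears` plays the role of `_replace_with_int4_linear` and returns a replaced flag:

import torch
import torch.nn as nn

def convert_then_cast(model: nn.Module, quantize_linears) -> nn.Module:
    # quantize_linears is a hypothetical stand-in that swaps Linear layers
    # for int4 modules and reports whether anything was replaced.
    model, has_been_replaced = quantize_linears(model)
    if has_been_replaced:
        # The big Linear weights are already int4 here, so the upcast
        # only allocates fp32 copies of what is left.
        model.to(torch.float32)
    return model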
@@ -197,4 +197,4 @@ class LinearInt4(nn.Linear):
         if self.bias is not None:
             result += self.bias
 
-        return result
+        return result.to(x.dtype)
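Casting the result back to `x.dtype` keeps the layer transparent to its callers: activations leave in whatever dtype they arrived in (fp16, bf16, fp32), even though the int4 matmul's intermediate result may be produced in a different dtype. An illustrative forward, not the project's actual `LinearInt4`:

import torch
import torch.nn as nn
import torch.nn.functional as F

class CastBackLinear(nn.Linear):
    # Illustration only: compute in float32, then return in the caller's
    # dtype, mirroring the `result.to(x.dtype)` change above.
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        result = F.linear(x.float(), self.weight.float())
        if self.bias is not None:
            result += self.bias.float()
        return result.to(x.dtype)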
@@ -27,13 +27,14 @@ class _BaseAutoModelClass:
                         *args,
                         **kwargs):
         load_in_4bit = kwargs.pop("load_in_4bit", False)
+        if load_in_4bit:
+            kwargs["low_cpu_mem_usage"] = True
         model = cls.HF_Model.from_pretrained(*args, **kwargs)
 
         if load_in_4bit:
             from .convert import ggml_convert_int4
-            model = model.to("cpu", torch.float32)
+            model = model.to("cpu")
             model = ggml_convert_int4(model)
 
         return model
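Defaulting `low_cpu_mem_usage=True` whenever `load_in_4bit` is set makes Hugging Face load checkpoint weights lazily instead of first building a fully-initialized model, and dropping `torch.float32` from the `.to("cpu")` call defers the upcast to `ggml_convert_int4` (previous hunk). A usage sketch; the import path is assumed from BigDL's transformers wrapper and the model path is a placeholder:

from bigdl.llm.transformers import AutoModelForCausalLM  # assumed import path

# low_cpu_mem_usage=True is now implied by load_in_4bit=True, so the
# checkpoint is materialized tensor by tensor before int4 conversion.
model = AutoModelForCausalLM.from_pretrained("path/to/model",  # placeholder
                                             load_in_4bit=True)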