Optimize transformer int4 loading memory (#8400)
* Optimize transformer int4 loading memory * move cast to convert * default setting low_cpu_mem_usage
This commit is contained in:
		
							parent
							
								
									2da21163f8
								
							
						
					
					
						commit
						449aea7ffc
					
				
					 3 changed files with 8 additions and 3 deletions
				
			
		| 
						 | 
				
			
			@ -71,6 +71,8 @@ def _replace_with_int4_linear(model, modules_to_not_convert=None, current_key_na
 | 
			
		|||
                    # Force requires grad to False to avoid unexpected errors
 | 
			
		||||
                    model._modules[name].requires_grad_(False)
 | 
			
		||||
 | 
			
		||||
                    module.weight = None
 | 
			
		||||
 | 
			
		||||
        # Remove the last key for recursion
 | 
			
		||||
        if len(list(module.children())) > 0:
 | 
			
		||||
            _, has_been_replaced = _replace_with_int4_linear(
 | 
			
		||||
| 
						 | 
				
			
			@ -93,4 +95,6 @@ def ggml_convert_int4(model):
 | 
			
		|||
            "instead of Linear layers. Please double check your model architecture, or submit "
 | 
			
		||||
            "an issue on github if you think this is a bug."
 | 
			
		||||
        )
 | 
			
		||||
    else:
 | 
			
		||||
        model.to(torch.float32)
 | 
			
		||||
    return model
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -197,4 +197,4 @@ class LinearInt4(nn.Linear):
 | 
			
		|||
        if self.bias is not None:
 | 
			
		||||
            result += self.bias
 | 
			
		||||
 | 
			
		||||
        return result
 | 
			
		||||
        return result.to(x.dtype)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -27,13 +27,14 @@ class _BaseAutoModelClass:
 | 
			
		|||
                        *args,
 | 
			
		||||
                        **kwargs):
 | 
			
		||||
        load_in_4bit = kwargs.pop("load_in_4bit", False)
 | 
			
		||||
        if load_in_4bit:
 | 
			
		||||
            kwargs["low_cpu_mem_usage"] = True
 | 
			
		||||
        model = cls.HF_Model.from_pretrained(*args, **kwargs)
 | 
			
		||||
 | 
			
		||||
        if load_in_4bit:
 | 
			
		||||
            from .convert import ggml_convert_int4
 | 
			
		||||
            model = model.to("cpu", torch.float32)
 | 
			
		||||
            model = model.to("cpu")
 | 
			
		||||
            model = ggml_convert_int4(model)
 | 
			
		||||
 | 
			
		||||
        return model
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in a new issue