Optimize transformer int4 loading memory (#8400)
* Optimize transformer int4 loading memory
* move cast to convert
* default setting low_cpu_mem_usage
parent 2da21163f8
commit 449aea7ffc

3 changed files with 8 additions and 3 deletions
@@ -71,6 +71,8 @@ def _replace_with_int4_linear(model, modules_to_not_convert=None, current_key_na
                     # Force requires grad to False to avoid unexpected errors
                     model._modules[name].requires_grad_(False)
 
+                    module.weight = None
+
         # Remove the last key for recursion
         if len(list(module.children())) > 0:
             _, has_been_replaced = _replace_with_int4_linear(
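Note: setting `module.weight = None` drops the last reference to a layer's original full-precision weight as soon as its int4 replacement is built, so each fp32/fp16 tensor can be garbage-collected during conversion instead of every original weight staying alive until the whole model is converted. A minimal sketch of the idea (the helper names here are hypothetical, not from this repo):

    import torch.nn as nn

    def replace_linears_inplace(model: nn.Module, make_int4):
        # Hypothetical sketch: swap every nn.Linear for a quantized module
        # and immediately drop the old weight, so peak memory is roughly one
        # extra layer rather than a whole duplicated set of Linear weights.
        for name, module in model.named_children():
            if isinstance(module, nn.Linear):
                model._modules[name] = make_int4(module)
                # Force requires_grad to False to avoid unexpected errors
                model._modules[name].requires_grad_(False)
                module.weight = None  # release the original weight tensor
            else:
                replace_linears_inplace(module, make_int4)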
@@ -93,4 +95,6 @@ def ggml_convert_int4(model):
             "instead of Linear layers. Please double check your model architecture, or submit "
             "an issue on github if you think this is a bug."
         )
+    else:
+        model.to(torch.float32)
     return model
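Note: this is the "move cast to convert" item from the commit message. The float32 cast now runs only after the Linear layers have been replaced, so it touches just the leftovers (embeddings, norms, biases) rather than the full set of fp16 Linear weights. Back-of-envelope numbers for a hypothetical 7B-parameter checkpoint:

    # Rough peak-memory arithmetic; 7e9 parameters is an assumed example size.
    params = 7e9
    print(f"fp16 load:       {params * 2 / 2**30:5.1f} GiB")    # ~13.0 GiB
    print(f"fp32 cast first: {params * 4 / 2**30:5.1f} GiB")    # ~26.1 GiB (old peak)
    print(f"int4 weights:    {params * 0.5 / 2**30:5.1f} GiB")  # ~3.3 GiB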
@@ -197,4 +197,4 @@ class LinearInt4(nn.Linear):
         if self.bias is not None:
             result += self.bias
 
-        return result
+        return result.to(x.dtype)
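Note: `result.to(x.dtype)` makes the int4 layer dtype-transparent. Now that `from_pretrained` no longer force-casts the whole model to float32 (see the next hunk), activations may arrive in fp16 or bf16, and the layer hands its output back in whatever dtype it received. A minimal sketch of the pattern (the class name is hypothetical):

    import torch
    import torch.nn as nn

    class DtypePreservingLinear(nn.Linear):
        # Stand-in for LinearInt4: compute internally in float32, but return
        # the result in the caller's dtype so a half-precision model keeps
        # flowing in half precision around the quantized layer.
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            result = super().forward(x.float())
            return result.to(x.dtype)

    layer = DtypePreservingLinear(8, 8)
    out = layer(torch.randn(2, 8, dtype=torch.bfloat16))
    assert out.dtype == torch.bfloat16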
@@ -27,13 +27,14 @@ class _BaseAutoModelClass:
                         *args,
                         **kwargs):
         load_in_4bit = kwargs.pop("load_in_4bit", False)
+        if load_in_4bit:
+            kwargs["low_cpu_mem_usage"] = True
         model = cls.HF_Model.from_pretrained(*args, **kwargs)
 
         if load_in_4bit:
             from .convert import ggml_convert_int4
-            model = model.to("cpu", torch.float32)
+            model = model.to("cpu")
             model = ggml_convert_int4(model)
 
         return model
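Note: `low_cpu_mem_usage=True` tells transformers' `from_pretrained` to skip materializing a second, randomly initialized copy of the weights and load checkpoint tensors into place instead, so forcing it on for 4-bit loads lowers peak host RAM; the `torch.float32` argument is dropped from `model.to` because the cast now happens inside `ggml_convert_int4` (see above). Roughly what the wrapper now arranges, expressed as a plain transformers call (the checkpoint name is a placeholder):

    from transformers import AutoModelForCausalLM

    # Equivalent of what the patched wrapper passes along for 4-bit loads;
    # "some-org/some-model" stands in for any causal-LM checkpoint.
    model = AutoModelForCausalLM.from_pretrained(
        "some-org/some-model",
        low_cpu_mem_usage=True,  # forced on by the wrapper when load_in_4bit=True
    )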