optimize NormHead for Baichuan2 (#9205)
* optimize NormHead for Baichuan2
* fix ut and change name
* rename functions
parent a3b664ed03
commit 6dad8d16df
1 changed file with 28 additions and 2 deletions

@@ -121,10 +121,36 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
    return model, has_been_replaced


def _optimize_pre(model):
    from transformers.modeling_utils import PreTrainedModel
    # All huggingface format models are inherited from `PreTrainedModel`
    if not isinstance(model, PreTrainedModel):
        logger.info("Only HuggingFace Transformers models are currently "
                    "supported for further optimizations")
        return model
    # process NormHead module in Baichuan2 7B and 13B
    if model.config.model_type == "baichuan" and model.config.vocab_size == 125696:
        # NormHead does normalization on the weights just once at inference time,
        # so we do it in advance and convert it to a Linear so that it can be replaced.
        # modeling_module_name = model.__class__.__module__
        # module = importlib.import_module(modeling_module_name)
        if hasattr(model, 'lm_head') and model.lm_head is not None:
            # do we need to check the class instance?
            vocab_size, hidden_size = model.lm_head.weight.shape
            norm_weight = nn.functional.normalize(model.lm_head.weight.data)
            model.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
            model.lm_head.weight.data = norm_weight
    return model


def ggml_convert_low_bit(model, qtype, optimize_model=True,
                         convert_shape_only=False, device="cpu",
                         modules_to_not_convert=None):
    modules_to_not_convert = [] if modules_to_not_convert is None else modules_to_not_convert

    if optimize_model:
        model = _optimize_pre(model)

    model, has_been_replaced = _replace_with_low_bit_linear(
        model, qtype, modules_to_not_convert,
        None, convert_shape_only,
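Note on the hunk above: Baichuan2's NormHead re-normalizes its weight matrix on every forward pass, but at inference time the weights never change, so normalizing once up front and storing the result in a plain nn.Linear produces the same logits while exposing a standard Linear module that the low-bit replacement pass can quantize. The snippet below is a minimal sketch of that equivalence; this NormHead class is a simplified stand-in for illustration, not the exact upstream Baichuan2 implementation.

import torch
import torch.nn as nn

class NormHead(nn.Module):
    # Simplified stand-in for Baichuan2's NormHead (inference path only):
    # the weight is L2-normalized along the hidden dimension on every forward.
    def __init__(self, hidden_size, vocab_size):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(vocab_size, hidden_size))

    def forward(self, hidden_states):
        norm_weight = nn.functional.normalize(self.weight)
        return nn.functional.linear(hidden_states, norm_weight)

hidden_size, vocab_size = 8, 16
head = NormHead(hidden_size, vocab_size)

# Fold the normalization into an ordinary Linear once,
# mirroring what _optimize_pre does to model.lm_head.
linear = nn.Linear(hidden_size, vocab_size, bias=False)
linear.weight.data = nn.functional.normalize(head.weight.data)

x = torch.randn(2, hidden_size)
assert torch.allclose(head(x), linear(x), atol=1e-6)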
@@ -143,7 +169,7 @@ def ggml_convert_low_bit(model, qtype, optimize_model=True,
        pass

    if optimize_model:
-        model = optimize(model)
+        model = _optimize_post(model)
    return model

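With this ordering, _optimize_pre runs before _replace_with_low_bit_linear (so the folded lm_head is already a plain Linear when the replacement pass walks the model) and the renamed _optimize_post runs afterwards. A hedged usage sketch, assuming the usual BigDL-LLM transformers-style entry point; the model id and flags here are illustrative, not part of this commit.

# Illustrative only: loading Baichuan2 through BigDL-LLM's transformers-style
# API is expected to route through ggml_convert_low_bit, so the NormHead
# folding in _optimize_pre applies before low-bit conversion.
from bigdl.llm.transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "baichuan-inc/Baichuan2-7B-Chat",   # assumed model id, for illustration
    load_in_4bit=True,                  # enables the low-bit conversion path
    trust_remote_code=True,             # Baichuan2 ships custom modeling code
)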
@@ -155,7 +181,7 @@ def convert_forward(m, target_m, new_forward):
        convert_forward(sub_m, target_m, new_forward)


-def optimize(model):
+def _optimize_post(model):
    from packaging import version
    from bigdl.llm.transformers.models.llama import llama_attention_forward_4_31
    from bigdl.llm.transformers.models.llama import llama_rms_norm_forward
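The hunk is truncated before the body of the renamed _optimize_post, but judging from the imports shown and the convert_forward helper above, it presumably rebinds the LLaMA attention and RMSNorm forwards when a new enough transformers is installed. The following is only a sketch under that assumption, not the actual committed body; the version gate and target classes are guesses suggested by the *_4_31 import name.

# Assumption-laden sketch; the real body is not shown in this hunk.
import transformers
from packaging import version
from bigdl.llm.transformers.models.llama import llama_attention_forward_4_31
from bigdl.llm.transformers.models.llama import llama_rms_norm_forward

def _optimize_post_sketch(model):
    # Guess: gate on transformers >= 4.31 (suggested by the import name).
    if version.parse(transformers.__version__) >= version.parse("4.31.0"):
        from transformers.models.llama.modeling_llama import LlamaAttention, LlamaRMSNorm
        # convert_forward (defined earlier in this file) rebinds `forward`
        # on every instance of the target class inside the model.
        convert_forward(model, LlamaAttention, llama_attention_forward_4_31)
        convert_forward(model, LlamaRMSNorm, llama_rms_norm_forward)
    return model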