Reduce max_cache_pos to reduce Baichuan2-13B memory (#9694)
* optimize baichuan2 memory * fix * style * fp16 mask * disable fp16 * fix style * empty cache * revert empty cache
This commit is contained in:
		
							parent
							
								
									361781bcd0
								
							
						
					
					
						commit
						689889482c
					
				
					 2 changed files with 107 additions and 7 deletions
				
			
		| 
						 | 
				
			
			@ -399,6 +399,14 @@ def convert_forward(m, target_m, new_forward):
 | 
			
		|||
        convert_forward(sub_m, target_m, new_forward)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def replace_func(m, target_m, func_name, new_func):
 | 
			
		||||
    for _, sub_m in m.named_children():
 | 
			
		||||
        if isinstance(sub_m, target_m):
 | 
			
		||||
            bound_method = new_func.__get__(sub_m, sub_m.__class__)
 | 
			
		||||
            setattr(sub_m, func_name, bound_method)
 | 
			
		||||
        replace_func(sub_m, target_m, func_name, new_func)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _optimize_post(model, lightweight_bmm=False):
 | 
			
		||||
    from packaging import version
 | 
			
		||||
    from bigdl.llm.transformers.models.llama import llama_attention_forward_4_31
 | 
			
		||||
| 
						 | 
				
			
			@ -569,6 +577,7 @@ def _optimize_post(model, lightweight_bmm=False):
 | 
			
		|||
            from bigdl.llm.transformers.models.baichuan2 import baichuan_attention_forward_13b
 | 
			
		||||
            from bigdl.llm.transformers.models.baichuan2 import baichuan_13b_rms_norm_forward
 | 
			
		||||
            from bigdl.llm.transformers.models.baichuan2 import baichuan_mlp_forward
 | 
			
		||||
            from bigdl.llm.transformers.models.baichuan2 import baichuan_13b_get_alibi_mask
 | 
			
		||||
            convert_forward(model,
 | 
			
		||||
                            module.BaichuanAttention,
 | 
			
		||||
                            baichuan_attention_forward_13b
 | 
			
		||||
| 
						 | 
				
			
			@ -580,6 +589,10 @@ def _optimize_post(model, lightweight_bmm=False):
 | 
			
		|||
            convert_forward(model,
 | 
			
		||||
                            module.MLP,
 | 
			
		||||
                            baichuan_mlp_forward)
 | 
			
		||||
            replace_func(model,
 | 
			
		||||
                         module.BaichuanModel,
 | 
			
		||||
                         "get_alibi_mask",
 | 
			
		||||
                         baichuan_13b_get_alibi_mask)
 | 
			
		||||
    elif model.config.model_type == "baichuan":
 | 
			
		||||
        # baichuan1
 | 
			
		||||
        if model.config.hidden_size == 4096:
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -19,19 +19,15 @@
 | 
			
		|||
# https://huggingface.co/baichuan-inc/Baichuan2-13B-Chat/blob/c6f8592a60b4ad73c210b28dd2ab3cca51abbf93/modeling_baichuan.py
 | 
			
		||||
 | 
			
		||||
import math
 | 
			
		||||
from typing import List, Optional, Tuple, Union
 | 
			
		||||
from typing import Optional, Tuple
 | 
			
		||||
import torch
 | 
			
		||||
import torch.utils.checkpoint
 | 
			
		||||
from torch import nn
 | 
			
		||||
from torch.nn import functional as F
 | 
			
		||||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 | 
			
		||||
from bigdl.llm.utils.common import invalidInputError
 | 
			
		||||
from bigdl.llm.ggml.quantize import ggml_tensor_qtype
 | 
			
		||||
from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
 | 
			
		||||
from bigdl.llm.transformers.models.utils import rotate_half, apply_rotary_pos_emb
 | 
			
		||||
from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb
 | 
			
		||||
from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
 | 
			
		||||
from transformers.utils import logging, ContextManagers
 | 
			
		||||
from bigdl.llm.transformers.models.llama import get_ipex_version
 | 
			
		||||
from transformers.utils import logging
 | 
			
		||||
logger = logging.get_logger(__name__)
 | 
			
		||||
 | 
			
		||||
try:
 | 
			
		||||
| 
						 | 
				
			
			@ -301,3 +297,94 @@ def baichuan_attention_forward_13b(
 | 
			
		|||
        attn_weights = None
 | 
			
		||||
 | 
			
		||||
    return attn_output, attn_weights, past_key_value
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _get_interleave(n):
 | 
			
		||||
    def _get_interleave_power_of_2(n):
 | 
			
		||||
        start = 2 ** (-(2 ** -(math.log2(n) - 3)))
 | 
			
		||||
        ratio = start
 | 
			
		||||
        return [start * ratio**i for i in range(n)]
 | 
			
		||||
 | 
			
		||||
    if math.log2(n).is_integer():
 | 
			
		||||
        return _get_interleave_power_of_2(n)
 | 
			
		||||
    else:
 | 
			
		||||
        closest_power_of_2 = 2 ** math.floor(math.log2(n))
 | 
			
		||||
        return (
 | 
			
		||||
            _get_interleave_power_of_2(closest_power_of_2)
 | 
			
		||||
            + _get_interleave(2 * closest_power_of_2)[0::2][: n - closest_power_of_2]
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _fill_with_neg_inf(t):
 | 
			
		||||
    """FP16-compatible function that fills a tensor with -inf."""
 | 
			
		||||
    return t.float().fill_(float("-inf")).type_as(t)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _buffered_future_mask(tensor, maxpos, alibi, attn_heads):
 | 
			
		||||
    _future_mask = torch.triu(_fill_with_neg_inf(torch.zeros([maxpos, maxpos])), 1)
 | 
			
		||||
    _future_mask = _future_mask.unsqueeze(0) + alibi
 | 
			
		||||
    new_future_mask = _future_mask.to(tensor)
 | 
			
		||||
    return new_future_mask[: tensor.shape[0] * attn_heads, :maxpos, :maxpos]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def baichuan_13b_gen_alibi_mask(tensor, n_head, max_pos):
 | 
			
		||||
    # May use fp16 for alibi mask to further reduce memory
 | 
			
		||||
    slopes = torch.Tensor(_get_interleave(n_head))  # .half()
 | 
			
		||||
    position_point = torch.arange(max_pos) - max_pos + 1
 | 
			
		||||
    position_point = position_point.unsqueeze(0).unsqueeze(0).expand(n_head, -1, -1)
 | 
			
		||||
    diag = torch.diag(position_point[0])
 | 
			
		||||
    position_point = position_point - diag.unsqueeze(0).unsqueeze(0).transpose(-1, -2)
 | 
			
		||||
    alibi = slopes.unsqueeze(1).unsqueeze(1) * position_point
 | 
			
		||||
    alibi = alibi.view(n_head, 1, max_pos)
 | 
			
		||||
    alibi_mask = torch.triu(_fill_with_neg_inf(torch.zeros([max_pos, max_pos])), 1)  # .half()
 | 
			
		||||
    alibi_mask = alibi_mask.unsqueeze(0) + alibi
 | 
			
		||||
    if tensor.device.type == "xpu":
 | 
			
		||||
        alibi_mask = alibi_mask.to(tensor.device)
 | 
			
		||||
    return alibi_mask
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
MASK_BLOCK_SIZE = 64
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def baichuan_13b_get_alibi_mask(self, tensor, seq_length_with_past):
 | 
			
		||||
    if self.training:
 | 
			
		||||
        slopes = torch.Tensor(_get_interleave(self.n_head))
 | 
			
		||||
        position_point = (
 | 
			
		||||
            torch.arange(seq_length_with_past) - seq_length_with_past + 1
 | 
			
		||||
        )
 | 
			
		||||
        position_point = (
 | 
			
		||||
            position_point.unsqueeze(0)
 | 
			
		||||
            .unsqueeze(0)
 | 
			
		||||
            .expand(self.n_head, seq_length_with_past, -1)
 | 
			
		||||
        )
 | 
			
		||||
        diag = torch.diag(position_point[0])
 | 
			
		||||
        position_point = position_point - diag.unsqueeze(0).unsqueeze(0).transpose(
 | 
			
		||||
            -1, -2
 | 
			
		||||
        )
 | 
			
		||||
        alibi = slopes.unsqueeze(1).unsqueeze(1) * position_point
 | 
			
		||||
        mask = _buffered_future_mask(
 | 
			
		||||
            tensor, seq_length_with_past, alibi, self.n_head
 | 
			
		||||
        )
 | 
			
		||||
    else:
 | 
			
		||||
        if self.first_run:
 | 
			
		||||
            # Override the default max_cache_pos=4096 for memory considerations
 | 
			
		||||
            self.max_cache_pos = seq_length_with_past + MASK_BLOCK_SIZE
 | 
			
		||||
            self.first_run = False
 | 
			
		||||
            self.register_buffer(
 | 
			
		||||
                "future_mask",
 | 
			
		||||
                baichuan_13b_gen_alibi_mask(tensor, self.n_head, self.max_cache_pos),
 | 
			
		||||
                persistent=False,
 | 
			
		||||
            )
 | 
			
		||||
        if seq_length_with_past > self.max_cache_pos:
 | 
			
		||||
            # When max_cache_pos is not enough for current sequence length,
 | 
			
		||||
            # increase by MASK_BLOCK_SIZE and recalculate future_mask.
 | 
			
		||||
            self.max_cache_pos = seq_length_with_past + MASK_BLOCK_SIZE
 | 
			
		||||
            self.register_buffer(
 | 
			
		||||
                "future_mask",
 | 
			
		||||
                baichuan_13b_gen_alibi_mask(tensor, self.n_head, self.max_cache_pos),
 | 
			
		||||
                persistent=False,
 | 
			
		||||
            )
 | 
			
		||||
        mask = self.future_mask[
 | 
			
		||||
            : self.n_head, :seq_length_with_past, :seq_length_with_past
 | 
			
		||||
        ]
 | 
			
		||||
    return mask
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in a new issue