Reduce max_cache_pos to reduce Baichuan2-13B memory (#9694)
* optimize baichuan2 memory
* fix
* style
* fp16 mask
* disable fp16
* fix style
* empty cache
* revert empty cache
This commit is contained in:

parent 361781bcd0
commit 689889482c

2 changed files with 107 additions and 7 deletions
First changed file — the conversion module that defines convert_forward and _optimize_post:

@@ -399,6 +399,14 @@ def convert_forward(m, target_m, new_forward):
         convert_forward(sub_m, target_m, new_forward)
 
 
+def replace_func(m, target_m, func_name, new_func):
+    for _, sub_m in m.named_children():
+        if isinstance(sub_m, target_m):
+            bound_method = new_func.__get__(sub_m, sub_m.__class__)
+            setattr(sub_m, func_name, bound_method)
+        replace_func(sub_m, target_m, func_name, new_func)
+
+
 def _optimize_post(model, lightweight_bmm=False):
     from packaging import version
     from bigdl.llm.transformers.models.llama import llama_attention_forward_4_31
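The new replace_func helper walks the module tree and rebinds an arbitrary function as an instance method on every submodule of the target class. A minimal, self-contained sketch of the same pattern (the Toy* classes and patched_get_alibi_mask below are made-up stand-ins for illustration, not part of the commit):

import torch.nn as nn


def replace_func(m, target_m, func_name, new_func):
    # Same logic as the helper added above: recurse over children and
    # rebind new_func as an instance method named func_name.
    for _, sub_m in m.named_children():
        if isinstance(sub_m, target_m):
            bound_method = new_func.__get__(sub_m, sub_m.__class__)
            setattr(sub_m, func_name, bound_method)
        replace_func(sub_m, target_m, func_name, new_func)


class ToyInner(nn.Module):  # hypothetical stand-in for BaichuanModel
    def get_alibi_mask(self, tensor, seq_length_with_past):
        return "original mask"


class ToyOuter(nn.Module):  # hypothetical stand-in for the causal-LM wrapper
    def __init__(self):
        super().__init__()
        self.model = ToyInner()


def patched_get_alibi_mask(self, tensor, seq_length_with_past):
    return "patched mask"


outer = ToyOuter()
replace_func(outer, ToyInner, "get_alibi_mask", patched_get_alibi_mask)
print(outer.model.get_alibi_mask(None, 0))  # -> "patched mask"

Because the function is bound with __get__, the patched method still receives the original submodule as self, so it can read attributes such as n_head, first_run and max_cache_pos.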
@@ -569,6 +577,7 @@ def _optimize_post(model, lightweight_bmm=False):
             from bigdl.llm.transformers.models.baichuan2 import baichuan_attention_forward_13b
             from bigdl.llm.transformers.models.baichuan2 import baichuan_13b_rms_norm_forward
             from bigdl.llm.transformers.models.baichuan2 import baichuan_mlp_forward
+            from bigdl.llm.transformers.models.baichuan2 import baichuan_13b_get_alibi_mask
             convert_forward(model,
                             module.BaichuanAttention,
                             baichuan_attention_forward_13b
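convert_forward itself is only partially visible in this diff (its recursive call appears as context in the first hunk). Assuming it follows the same bound-method pattern as replace_func but always targets the forward method, a rough sketch of what the call above does:

def convert_forward(m, target_m, new_forward):
    # Sketch under the assumption stated above: rebind new_forward as the
    # forward method of every submodule that is an instance of target_m.
    for _, sub_m in m.named_children():
        if isinstance(sub_m, target_m):
            bound_method = new_forward.__get__(sub_m, sub_m.__class__)
            setattr(sub_m, "forward", bound_method)
        convert_forward(sub_m, target_m, new_forward)

So the call above routes every BaichuanAttention.forward through baichuan_attention_forward_13b without modifying the checkpoint's own modeling code.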
@@ -580,6 +589,10 @@ def _optimize_post(model, lightweight_bmm=False):
             convert_forward(model,
                             module.MLP,
                             baichuan_mlp_forward)
+            replace_func(model,
+                         module.BaichuanModel,
+                         "get_alibi_mask",
+                         baichuan_13b_get_alibi_mask)
     elif model.config.model_type == "baichuan":
         # baichuan1
         if model.config.hidden_size == 4096:
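These hooks run when a Baichuan2-13B checkpoint is loaded through bigdl-llm's transformers-style API with optimization enabled. A rough usage sketch (the checkpoint id, prompt, and generation arguments are illustrative, not taken from this commit):

from transformers import AutoTokenizer
from bigdl.llm.transformers import AutoModelForCausalLM

model_path = "baichuan-inc/Baichuan2-13B-Chat"  # illustrative checkpoint id

# Loading in low-bit format runs bigdl-llm's optimization pass (_optimize_post),
# which installs baichuan_attention_forward_13b, baichuan_mlp_forward, the RMSNorm
# forward, and the patched get_alibi_mask shown in this commit.
model = AutoModelForCausalLM.from_pretrained(model_path,
                                             load_in_4bit=True,
                                             trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

inputs = tokenizer("What is AI?", return_tensors="pt")
output = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(output[0], skip_special_tokens=True))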
Second changed file — the Baichuan2 model module (bigdl.llm.transformers.models.baichuan2):

@@ -19,19 +19,15 @@
 # https://huggingface.co/baichuan-inc/Baichuan2-13B-Chat/blob/c6f8592a60b4ad73c210b28dd2ab3cca51abbf93/modeling_baichuan.py
 
 import math
-from typing import List, Optional, Tuple, Union
+from typing import Optional, Tuple
 import torch
 import torch.utils.checkpoint
-from torch import nn
 from torch.nn import functional as F
-from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
-from bigdl.llm.utils.common import invalidInputError
 from bigdl.llm.ggml.quantize import ggml_tensor_qtype
 from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
-from bigdl.llm.transformers.models.utils import rotate_half, apply_rotary_pos_emb
+from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb
 from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
-from transformers.utils import logging, ContextManagers
-from bigdl.llm.transformers.models.llama import get_ipex_version
+from transformers.utils import logging
 logger = logging.get_logger(__name__)
 
 try:
@@ -301,3 +297,94 @@ def baichuan_attention_forward_13b(
         attn_weights = None
 
     return attn_output, attn_weights, past_key_value
+
+
+def _get_interleave(n):
+    def _get_interleave_power_of_2(n):
+        start = 2 ** (-(2 ** -(math.log2(n) - 3)))
+        ratio = start
+        return [start * ratio**i for i in range(n)]
+
+    if math.log2(n).is_integer():
+        return _get_interleave_power_of_2(n)
+    else:
+        closest_power_of_2 = 2 ** math.floor(math.log2(n))
+        return (
+            _get_interleave_power_of_2(closest_power_of_2)
+            + _get_interleave(2 * closest_power_of_2)[0::2][: n - closest_power_of_2]
+        )
+
+
+def _fill_with_neg_inf(t):
+    """FP16-compatible function that fills a tensor with -inf."""
+    return t.float().fill_(float("-inf")).type_as(t)
+
+
+def _buffered_future_mask(tensor, maxpos, alibi, attn_heads):
+    _future_mask = torch.triu(_fill_with_neg_inf(torch.zeros([maxpos, maxpos])), 1)
+    _future_mask = _future_mask.unsqueeze(0) + alibi
+    new_future_mask = _future_mask.to(tensor)
+    return new_future_mask[: tensor.shape[0] * attn_heads, :maxpos, :maxpos]
+
+
+def baichuan_13b_gen_alibi_mask(tensor, n_head, max_pos):
+    # May use fp16 for alibi mask to further reduce memory
+    slopes = torch.Tensor(_get_interleave(n_head))  # .half()
+    position_point = torch.arange(max_pos) - max_pos + 1
+    position_point = position_point.unsqueeze(0).unsqueeze(0).expand(n_head, -1, -1)
+    diag = torch.diag(position_point[0])
+    position_point = position_point - diag.unsqueeze(0).unsqueeze(0).transpose(-1, -2)
+    alibi = slopes.unsqueeze(1).unsqueeze(1) * position_point
+    alibi = alibi.view(n_head, 1, max_pos)
+    alibi_mask = torch.triu(_fill_with_neg_inf(torch.zeros([max_pos, max_pos])), 1)  # .half()
+    alibi_mask = alibi_mask.unsqueeze(0) + alibi
+    if tensor.device.type == "xpu":
+        alibi_mask = alibi_mask.to(tensor.device)
+    return alibi_mask
+
+
+MASK_BLOCK_SIZE = 64
+
+
+def baichuan_13b_get_alibi_mask(self, tensor, seq_length_with_past):
+    if self.training:
+        slopes = torch.Tensor(_get_interleave(self.n_head))
+        position_point = (
+            torch.arange(seq_length_with_past) - seq_length_with_past + 1
+        )
+        position_point = (
+            position_point.unsqueeze(0)
+            .unsqueeze(0)
+            .expand(self.n_head, seq_length_with_past, -1)
+        )
+        diag = torch.diag(position_point[0])
+        position_point = position_point - diag.unsqueeze(0).unsqueeze(0).transpose(
+            -1, -2
+        )
+        alibi = slopes.unsqueeze(1).unsqueeze(1) * position_point
+        mask = _buffered_future_mask(
+            tensor, seq_length_with_past, alibi, self.n_head
+        )
+    else:
+        if self.first_run:
+            # Override the default max_cache_pos=4096 for memory considerations
+            self.max_cache_pos = seq_length_with_past + MASK_BLOCK_SIZE
+            self.first_run = False
+            self.register_buffer(
+                "future_mask",
+                baichuan_13b_gen_alibi_mask(tensor, self.n_head, self.max_cache_pos),
+                persistent=False,
+            )
+        if seq_length_with_past > self.max_cache_pos:
+            # When max_cache_pos is not enough for current sequence length,
+            # increase by MASK_BLOCK_SIZE and recalculate future_mask.
+            self.max_cache_pos = seq_length_with_past + MASK_BLOCK_SIZE
+            self.register_buffer(
+                "future_mask",
+                baichuan_13b_gen_alibi_mask(tensor, self.n_head, self.max_cache_pos),
+                persistent=False,
+            )
+        mask = self.future_mask[
+            : self.n_head, :seq_length_with_past, :seq_length_with_past
+        ]
+    return mask
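To see why overriding the default max_cache_pos matters, a back-of-the-envelope estimate of the future_mask buffer size (assuming n_head=40 for Baichuan2-13B and fp32 elements, since the .half() calls above are commented out; this helper is illustrative, not part of the commit):

def future_mask_bytes(n_head, max_pos, bytes_per_elem=4):
    # future_mask has shape [n_head, max_pos, max_pos]
    return n_head * max_pos * max_pos * bytes_per_elem


MASK_BLOCK_SIZE = 64
N_HEAD = 40  # assumed Baichuan2-13B attention head count

# Upstream default: max_cache_pos = 4096 regardless of the actual prompt length.
default_bytes = future_mask_bytes(N_HEAD, 4096)

# This commit: max_cache_pos tracks seq_length_with_past, growing in 64-step blocks.
seq_length_with_past = 512
patched_bytes = future_mask_bytes(N_HEAD, seq_length_with_past + MASK_BLOCK_SIZE)

print(f"default future_mask: {default_bytes / 1024**3:.2f} GiB")  # ~2.50 GiB
print(f"patched future_mask: {patched_bytes / 1024**2:.0f} MiB")  # ~51 MiB

Because the buffer is registered with persistent=False, the mask is simply rebuilt at the new size whenever the sequence outgrows the current max_cache_pos, so memory use scales with the actual sequence length rather than the 4096-position worst case.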