Reduce max_cache_pos to reduce Baichuan2-13B memory (#9694)

* optimize baichuan2 memory

* fix

* style

* fp16 mask

* disable fp16

* fix style

* empty cache

* revert empty cache
This commit is contained in:
Kai Huang 2023-12-26 19:51:25 +08:00 committed by GitHub
parent 361781bcd0
commit 689889482c
2 changed files with 107 additions and 7 deletions

View file

@ -399,6 +399,14 @@ def convert_forward(m, target_m, new_forward):
convert_forward(sub_m, target_m, new_forward)
def replace_func(m, target_m, func_name, new_func):
for _, sub_m in m.named_children():
if isinstance(sub_m, target_m):
bound_method = new_func.__get__(sub_m, sub_m.__class__)
setattr(sub_m, func_name, bound_method)
replace_func(sub_m, target_m, func_name, new_func)
def _optimize_post(model, lightweight_bmm=False):
from packaging import version
from bigdl.llm.transformers.models.llama import llama_attention_forward_4_31
@ -569,6 +577,7 @@ def _optimize_post(model, lightweight_bmm=False):
from bigdl.llm.transformers.models.baichuan2 import baichuan_attention_forward_13b
from bigdl.llm.transformers.models.baichuan2 import baichuan_13b_rms_norm_forward
from bigdl.llm.transformers.models.baichuan2 import baichuan_mlp_forward
from bigdl.llm.transformers.models.baichuan2 import baichuan_13b_get_alibi_mask
convert_forward(model,
module.BaichuanAttention,
baichuan_attention_forward_13b
@ -580,6 +589,10 @@ def _optimize_post(model, lightweight_bmm=False):
convert_forward(model,
module.MLP,
baichuan_mlp_forward)
replace_func(model,
module.BaichuanModel,
"get_alibi_mask",
baichuan_13b_get_alibi_mask)
elif model.config.model_type == "baichuan":
# baichuan1
if model.config.hidden_size == 4096:

View file

@ -19,19 +19,15 @@
# https://huggingface.co/baichuan-inc/Baichuan2-13B-Chat/blob/c6f8592a60b4ad73c210b28dd2ab3cca51abbf93/modeling_baichuan.py
import math
from typing import List, Optional, Tuple, Union
from typing import Optional, Tuple
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import functional as F
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from bigdl.llm.utils.common import invalidInputError
from bigdl.llm.ggml.quantize import ggml_tensor_qtype
from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
from bigdl.llm.transformers.models.utils import rotate_half, apply_rotary_pos_emb
from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb
from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
from transformers.utils import logging, ContextManagers
from bigdl.llm.transformers.models.llama import get_ipex_version
from transformers.utils import logging
logger = logging.get_logger(__name__)
try:
@ -301,3 +297,94 @@ def baichuan_attention_forward_13b(
attn_weights = None
return attn_output, attn_weights, past_key_value
def _get_interleave(n):
def _get_interleave_power_of_2(n):
start = 2 ** (-(2 ** -(math.log2(n) - 3)))
ratio = start
return [start * ratio**i for i in range(n)]
if math.log2(n).is_integer():
return _get_interleave_power_of_2(n)
else:
closest_power_of_2 = 2 ** math.floor(math.log2(n))
return (
_get_interleave_power_of_2(closest_power_of_2)
+ _get_interleave(2 * closest_power_of_2)[0::2][: n - closest_power_of_2]
)
def _fill_with_neg_inf(t):
"""FP16-compatible function that fills a tensor with -inf."""
return t.float().fill_(float("-inf")).type_as(t)
def _buffered_future_mask(tensor, maxpos, alibi, attn_heads):
_future_mask = torch.triu(_fill_with_neg_inf(torch.zeros([maxpos, maxpos])), 1)
_future_mask = _future_mask.unsqueeze(0) + alibi
new_future_mask = _future_mask.to(tensor)
return new_future_mask[: tensor.shape[0] * attn_heads, :maxpos, :maxpos]
def baichuan_13b_gen_alibi_mask(tensor, n_head, max_pos):
# May use fp16 for alibi mask to further reduce memory
slopes = torch.Tensor(_get_interleave(n_head)) # .half()
position_point = torch.arange(max_pos) - max_pos + 1
position_point = position_point.unsqueeze(0).unsqueeze(0).expand(n_head, -1, -1)
diag = torch.diag(position_point[0])
position_point = position_point - diag.unsqueeze(0).unsqueeze(0).transpose(-1, -2)
alibi = slopes.unsqueeze(1).unsqueeze(1) * position_point
alibi = alibi.view(n_head, 1, max_pos)
alibi_mask = torch.triu(_fill_with_neg_inf(torch.zeros([max_pos, max_pos])), 1) # .half()
alibi_mask = alibi_mask.unsqueeze(0) + alibi
if tensor.device.type == "xpu":
alibi_mask = alibi_mask.to(tensor.device)
return alibi_mask
MASK_BLOCK_SIZE = 64
def baichuan_13b_get_alibi_mask(self, tensor, seq_length_with_past):
if self.training:
slopes = torch.Tensor(_get_interleave(self.n_head))
position_point = (
torch.arange(seq_length_with_past) - seq_length_with_past + 1
)
position_point = (
position_point.unsqueeze(0)
.unsqueeze(0)
.expand(self.n_head, seq_length_with_past, -1)
)
diag = torch.diag(position_point[0])
position_point = position_point - diag.unsqueeze(0).unsqueeze(0).transpose(
-1, -2
)
alibi = slopes.unsqueeze(1).unsqueeze(1) * position_point
mask = _buffered_future_mask(
tensor, seq_length_with_past, alibi, self.n_head
)
else:
if self.first_run:
# Override the default max_cache_pos=4096 for memory considerations
self.max_cache_pos = seq_length_with_past + MASK_BLOCK_SIZE
self.first_run = False
self.register_buffer(
"future_mask",
baichuan_13b_gen_alibi_mask(tensor, self.n_head, self.max_cache_pos),
persistent=False,
)
if seq_length_with_past > self.max_cache_pos:
# When max_cache_pos is not enough for current sequence length,
# increase by MASK_BLOCK_SIZE and recalculate future_mask.
self.max_cache_pos = seq_length_with_past + MASK_BLOCK_SIZE
self.register_buffer(
"future_mask",
baichuan_13b_gen_alibi_mask(tensor, self.n_head, self.max_cache_pos),
persistent=False,
)
mask = self.future_mask[
: self.n_head, :seq_length_with_past, :seq_length_with_past
]
return mask