Reduce max_cache_pos to reduce Baichuan2-13B memory (#9694)

* optimize baichuan2 memory

* fix

* style

* fp16 mask

* disable fp16

* fix style

* empty cache

* revert empty cache
Kai Huang 2023-12-26 19:51:25 +08:00 committed by GitHub
parent 361781bcd0
commit 689889482c
2 changed files with 107 additions and 7 deletions
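
Rough sizing of the alibi future_mask buffer this change targets (a back-of-the-envelope sketch; the 40-head and fp32 figures below are assumptions about the stock Baichuan2-13B config, not values taken from this diff):

# Hypothetical sizing sketch (assumed: n_head = 40, fp32 mask entries).
n_head, dtype_bytes = 40, 4
full = n_head * 4096 * 4096 * dtype_bytes        # stock max_cache_pos = 4096
lazy = n_head * (1024 + 64) ** 2 * dtype_bytes   # 1024-token prompt + MASK_BLOCK_SIZE
print(full / 2**30, lazy / 2**30)                # ~2.5 GiB vs ~0.18 GiB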


@@ -399,6 +399,14 @@ def convert_forward(m, target_m, new_forward):
         convert_forward(sub_m, target_m, new_forward)
+
+
+def replace_func(m, target_m, func_name, new_func):
+    for _, sub_m in m.named_children():
+        if isinstance(sub_m, target_m):
+            bound_method = new_func.__get__(sub_m, sub_m.__class__)
+            setattr(sub_m, func_name, bound_method)
+        replace_func(sub_m, target_m, func_name, new_func)
 
 
 def _optimize_post(model, lightweight_bmm=False):
     from packaging import version
     from bigdl.llm.transformers.models.llama import llama_attention_forward_4_31
@@ -569,6 +577,7 @@ def _optimize_post(model, lightweight_bmm=False):
             from bigdl.llm.transformers.models.baichuan2 import baichuan_attention_forward_13b
             from bigdl.llm.transformers.models.baichuan2 import baichuan_13b_rms_norm_forward
             from bigdl.llm.transformers.models.baichuan2 import baichuan_mlp_forward
+            from bigdl.llm.transformers.models.baichuan2 import baichuan_13b_get_alibi_mask
             convert_forward(model,
                             module.BaichuanAttention,
                             baichuan_attention_forward_13b
@@ -580,6 +589,10 @@ def _optimize_post(model, lightweight_bmm=False):
             convert_forward(model,
                             module.MLP,
                             baichuan_mlp_forward)
+            replace_func(model,
+                         module.BaichuanModel,
+                         "get_alibi_mask",
+                         baichuan_13b_get_alibi_mask)
     elif model.config.model_type == "baichuan":
         # baichuan1
         if model.config.hidden_size == 4096:
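
The new replace_func hook rebinds an existing method (here get_alibi_mask) on every matching submodule instead of swapping forward. A minimal, self-contained sketch of that rebinding pattern; the Toy class and patched_get_mask below are hypothetical stand-ins, not BigDL code:

# A plain function becomes a bound method via the descriptor protocol,
# mirroring setattr(sub_m, func_name, new_func.__get__(sub_m, sub_m.__class__)).
class Toy:
    def get_mask(self):
        return "original"

def patched_get_mask(self):
    return "patched"

toy = Toy()
setattr(toy, "get_mask", patched_get_mask.__get__(toy, Toy))
print(toy.get_mask())  # -> "patched"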


@@ -19,19 +19,15 @@
 # https://huggingface.co/baichuan-inc/Baichuan2-13B-Chat/blob/c6f8592a60b4ad73c210b28dd2ab3cca51abbf93/modeling_baichuan.py
 import math
-from typing import List, Optional, Tuple, Union
+from typing import Optional, Tuple
 import torch
 import torch.utils.checkpoint
-from torch import nn
 from torch.nn import functional as F
-from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
-from bigdl.llm.utils.common import invalidInputError
 from bigdl.llm.ggml.quantize import ggml_tensor_qtype
 from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
-from bigdl.llm.transformers.models.utils import rotate_half, apply_rotary_pos_emb
+from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb
 from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
-from transformers.utils import logging, ContextManagers
-from bigdl.llm.transformers.models.llama import get_ipex_version
+from transformers.utils import logging
 
 logger = logging.get_logger(__name__)
 
 try:
@@ -301,3 +297,94 @@ def baichuan_attention_forward_13b(
         attn_weights = None
 
     return attn_output, attn_weights, past_key_value
+
+
+def _get_interleave(n):
+    def _get_interleave_power_of_2(n):
+        start = 2 ** (-(2 ** -(math.log2(n) - 3)))
+        ratio = start
+        return [start * ratio**i for i in range(n)]
+
+    if math.log2(n).is_integer():
+        return _get_interleave_power_of_2(n)
+    else:
+        closest_power_of_2 = 2 ** math.floor(math.log2(n))
+        return (
+            _get_interleave_power_of_2(closest_power_of_2)
+            + _get_interleave(2 * closest_power_of_2)[0::2][: n - closest_power_of_2]
+        )
+
+
+def _fill_with_neg_inf(t):
+    """FP16-compatible function that fills a tensor with -inf."""
+    return t.float().fill_(float("-inf")).type_as(t)
+
+
+def _buffered_future_mask(tensor, maxpos, alibi, attn_heads):
+    _future_mask = torch.triu(_fill_with_neg_inf(torch.zeros([maxpos, maxpos])), 1)
+    _future_mask = _future_mask.unsqueeze(0) + alibi
+    new_future_mask = _future_mask.to(tensor)
+    return new_future_mask[: tensor.shape[0] * attn_heads, :maxpos, :maxpos]
+
+
+def baichuan_13b_gen_alibi_mask(tensor, n_head, max_pos):
+    # May use fp16 for alibi mask to further reduce memory
+    slopes = torch.Tensor(_get_interleave(n_head))  # .half()
+    position_point = torch.arange(max_pos) - max_pos + 1
+    position_point = position_point.unsqueeze(0).unsqueeze(0).expand(n_head, -1, -1)
+    diag = torch.diag(position_point[0])
+    position_point = position_point - diag.unsqueeze(0).unsqueeze(0).transpose(-1, -2)
+    alibi = slopes.unsqueeze(1).unsqueeze(1) * position_point
+    alibi = alibi.view(n_head, 1, max_pos)
+    alibi_mask = torch.triu(_fill_with_neg_inf(torch.zeros([max_pos, max_pos])), 1)  # .half()
+    alibi_mask = alibi_mask.unsqueeze(0) + alibi
+    if tensor.device.type == "xpu":
+        alibi_mask = alibi_mask.to(tensor.device)
+    return alibi_mask
+
+
+MASK_BLOCK_SIZE = 64
+
+
+def baichuan_13b_get_alibi_mask(self, tensor, seq_length_with_past):
+    if self.training:
+        slopes = torch.Tensor(_get_interleave(self.n_head))
+        position_point = (
+            torch.arange(seq_length_with_past) - seq_length_with_past + 1
+        )
+        position_point = (
+            position_point.unsqueeze(0)
+            .unsqueeze(0)
+            .expand(self.n_head, seq_length_with_past, -1)
+        )
+        diag = torch.diag(position_point[0])
+        position_point = position_point - diag.unsqueeze(0).unsqueeze(0).transpose(
+            -1, -2
+        )
+        alibi = slopes.unsqueeze(1).unsqueeze(1) * position_point
+        mask = _buffered_future_mask(
+            tensor, seq_length_with_past, alibi, self.n_head
+        )
+    else:
+        if self.first_run:
+            # Override the default max_cache_pos=4096 for memory considerations
+            self.max_cache_pos = seq_length_with_past + MASK_BLOCK_SIZE
+            self.first_run = False
+            self.register_buffer(
+                "future_mask",
+                baichuan_13b_gen_alibi_mask(tensor, self.n_head, self.max_cache_pos),
+                persistent=False,
+            )
+        if seq_length_with_past > self.max_cache_pos:
+            # When max_cache_pos is not enough for current sequence length,
+            # increase by MASK_BLOCK_SIZE and recalculate future_mask.
+            self.max_cache_pos = seq_length_with_past + MASK_BLOCK_SIZE
+            self.register_buffer(
+                "future_mask",
+                baichuan_13b_gen_alibi_mask(tensor, self.n_head, self.max_cache_pos),
+                persistent=False,
+            )
+        mask = self.future_mask[
+            : self.n_head, :seq_length_with_past, :seq_length_with_past
+        ]
+    return mask
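
To make the lazy-growth policy concrete, here is a toy trace of how max_cache_pos evolves as the sequence grows; next_max_cache_pos is a simplified stand-in written for illustration, not a BigDL function:

# Simplified model of the sizing rule above: allocate seq_len + MASK_BLOCK_SIZE
# on first use, then regrow by another block only when the sequence outruns the
# current buffer, instead of always materializing a 4096 x 4096 mask per head.
def next_max_cache_pos(seq_len, max_cache_pos=None, block=64):
    if max_cache_pos is None or seq_len > max_cache_pos:
        max_cache_pos = seq_len + block  # would trigger a future_mask rebuild
    return max_cache_pos

cache = None
for seq_len in (16, 80, 81, 200):
    cache = next_max_cache_pos(seq_len, cache)
    print((seq_len, cache))  # -> (16, 80), (80, 80), (81, 145), (200, 264)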