Reduce max_cache_pos to reduce Baichuan2-13B memory (#9694)
* optimize baichuan2 memory
* fix
* style
* fp16 mask
* disable fp16
* fix style
* empty cache
* revert empty cache
parent 361781bcd0
commit 689889482c
2 changed files with 107 additions and 7 deletions
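
These patches are applied automatically when a Baichuan2-13B checkpoint is loaded through bigdl-llm's transformers-style API. A minimal usage sketch for context, assuming the standard BigDL-LLM loading flow (checkpoint id, prompt, and generation settings are illustrative, not part of this commit):

import torch
from transformers import AutoTokenizer
from bigdl.llm.transformers import AutoModelForCausalLM

model_path = "baichuan-inc/Baichuan2-13B-Chat"  # illustrative checkpoint id
# load_in_4bit=True quantizes the model; _optimize_post (patched below) then
# swaps in the BigDL attention/MLP forwards and the new get_alibi_mask.
model = AutoModelForCausalLM.from_pretrained(model_path,
                                             load_in_4bit=True,
                                             trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

with torch.inference_mode():
    input_ids = tokenizer("What is AI?", return_tensors="pt").input_ids
    output = model.generate(input_ids, max_new_tokens=32)
    print(tokenizer.decode(output[0], skip_special_tokens=True))
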
@@ -399,6 +399,14 @@ def convert_forward(m, target_m, new_forward):
        convert_forward(sub_m, target_m, new_forward)


def replace_func(m, target_m, func_name, new_func):
    for _, sub_m in m.named_children():
        if isinstance(sub_m, target_m):
            bound_method = new_func.__get__(sub_m, sub_m.__class__)
            setattr(sub_m, func_name, bound_method)
        replace_func(sub_m, target_m, func_name, new_func)


def _optimize_post(model, lightweight_bmm=False):
    from packaging import version
    from bigdl.llm.transformers.models.llama import llama_attention_forward_4_31
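
The new replace_func helper mirrors convert_forward, but rebinds an arbitrary method name instead of forward; the Baichuan2-13B branch below uses it to swap BaichuanModel.get_alibi_mask. A toy sketch of the rebinding, assuming replace_func is importable from bigdl-llm's convert module (Toy and new_greet are made-up names):

import torch.nn as nn
from bigdl.llm.transformers.convert import replace_func  # module path assumed

class Toy(nn.Module):
    def greet(self):
        return "original"

def new_greet(self):
    return "patched " + self.__class__.__name__

root = nn.Sequential(Toy(), nn.Linear(4, 4))
# Binds new_greet as an instance method on every Toy submodule, recursively.
replace_func(root, Toy, "greet", new_greet)
print(root[0].greet())  # -> "patched Toy"
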
@@ -569,6 +577,7 @@ def _optimize_post(model, lightweight_bmm=False):
            from bigdl.llm.transformers.models.baichuan2 import baichuan_attention_forward_13b
            from bigdl.llm.transformers.models.baichuan2 import baichuan_13b_rms_norm_forward
            from bigdl.llm.transformers.models.baichuan2 import baichuan_mlp_forward
            from bigdl.llm.transformers.models.baichuan2 import baichuan_13b_get_alibi_mask
            convert_forward(model,
                            module.BaichuanAttention,
                            baichuan_attention_forward_13b
@@ -580,6 +589,10 @@ def _optimize_post(model, lightweight_bmm=False):
            convert_forward(model,
                            module.MLP,
                            baichuan_mlp_forward)
            replace_func(model,
                         module.BaichuanModel,
                         "get_alibi_mask",
                         baichuan_13b_get_alibi_mask)
    elif model.config.model_type == "baichuan":
        # baichuan1
        if model.config.hidden_size == 4096:

@@ -19,19 +19,15 @@
# https://huggingface.co/baichuan-inc/Baichuan2-13B-Chat/blob/c6f8592a60b4ad73c210b28dd2ab3cca51abbf93/modeling_baichuan.py

import math
from typing import List, Optional, Tuple, Union
from typing import Optional, Tuple
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import functional as F
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from bigdl.llm.utils.common import invalidInputError
from bigdl.llm.ggml.quantize import ggml_tensor_qtype
from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
from bigdl.llm.transformers.models.utils import rotate_half, apply_rotary_pos_emb
from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb
from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
from transformers.utils import logging, ContextManagers
from bigdl.llm.transformers.models.llama import get_ipex_version
from transformers.utils import logging
logger = logging.get_logger(__name__)

try:
@@ -301,3 +297,94 @@ def baichuan_attention_forward_13b(
        attn_weights = None

    return attn_output, attn_weights, past_key_value


def _get_interleave(n):
    def _get_interleave_power_of_2(n):
        start = 2 ** (-(2 ** -(math.log2(n) - 3)))
        ratio = start
        return [start * ratio**i for i in range(n)]

    if math.log2(n).is_integer():
        return _get_interleave_power_of_2(n)
    else:
        closest_power_of_2 = 2 ** math.floor(math.log2(n))
        return (
            _get_interleave_power_of_2(closest_power_of_2)
            + _get_interleave(2 * closest_power_of_2)[0::2][: n - closest_power_of_2]
        )
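
As a side note, a quick sanity check of the interleaved ALiBi slopes, assuming _get_interleave above is in scope (40 is Baichuan2-13B's attention-head count):

slopes = _get_interleave(40)
assert len(slopes) == 40
# Power-of-two head counts reduce to a plain geometric series, e.g. for 8 heads:
assert _get_interleave(8) == [2 ** -(i + 1) for i in range(8)]
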
def _fill_with_neg_inf(t):
    """FP16-compatible function that fills a tensor with -inf."""
    return t.float().fill_(float("-inf")).type_as(t)


def _buffered_future_mask(tensor, maxpos, alibi, attn_heads):
    _future_mask = torch.triu(_fill_with_neg_inf(torch.zeros([maxpos, maxpos])), 1)
    _future_mask = _future_mask.unsqueeze(0) + alibi
    new_future_mask = _future_mask.to(tensor)
    return new_future_mask[: tensor.shape[0] * attn_heads, :maxpos, :maxpos]
def baichuan_13b_gen_alibi_mask(tensor, n_head, max_pos):
    # May use fp16 for alibi mask to further reduce memory
    slopes = torch.Tensor(_get_interleave(n_head))  # .half()
    position_point = torch.arange(max_pos) - max_pos + 1
    position_point = position_point.unsqueeze(0).unsqueeze(0).expand(n_head, -1, -1)
    diag = torch.diag(position_point[0])
    position_point = position_point - diag.unsqueeze(0).unsqueeze(0).transpose(-1, -2)
    alibi = slopes.unsqueeze(1).unsqueeze(1) * position_point
    alibi = alibi.view(n_head, 1, max_pos)
    alibi_mask = torch.triu(_fill_with_neg_inf(torch.zeros([max_pos, max_pos])), 1)  # .half()
    alibi_mask = alibi_mask.unsqueeze(0) + alibi
    if tensor.device.type == "xpu":
        alibi_mask = alibi_mask.to(tensor.device)
    return alibi_mask


MASK_BLOCK_SIZE = 64
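
For context on why max_cache_pos is shrunk: future_mask has shape [n_head, max_cache_pos, max_cache_pos], so its memory grows quadratically with max_cache_pos. A rough back-of-the-envelope sketch, assuming an fp32 mask and Baichuan2-13B's 40 heads (mask_bytes is a hypothetical helper, not part of the commit):

def mask_bytes(n_head, max_cache_pos, bytes_per_elem=4):
    # future_mask is [n_head, max_cache_pos, max_cache_pos]; 4 bytes/elem for fp32
    return n_head * max_cache_pos * max_cache_pos * bytes_per_elem

# Stock modeling_baichuan.py pre-builds the mask with max_cache_pos=4096:
print(mask_bytes(40, 4096) / 1024 ** 3)       # ~2.5 GiB
# With this patch, a 1024-token prompt only needs seq_len + MASK_BLOCK_SIZE:
print(mask_bytes(40, 1024 + 64) / 1024 ** 3)  # ~0.18 GiB
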
def baichuan_13b_get_alibi_mask(self, tensor, seq_length_with_past):
    if self.training:
        slopes = torch.Tensor(_get_interleave(self.n_head))
        position_point = (
            torch.arange(seq_length_with_past) - seq_length_with_past + 1
        )
        position_point = (
            position_point.unsqueeze(0)
            .unsqueeze(0)
            .expand(self.n_head, seq_length_with_past, -1)
        )
        diag = torch.diag(position_point[0])
        position_point = position_point - diag.unsqueeze(0).unsqueeze(0).transpose(
            -1, -2
        )
        alibi = slopes.unsqueeze(1).unsqueeze(1) * position_point
        mask = _buffered_future_mask(
            tensor, seq_length_with_past, alibi, self.n_head
        )
    else:
        if self.first_run:
            # Override the default max_cache_pos=4096 for memory considerations
            self.max_cache_pos = seq_length_with_past + MASK_BLOCK_SIZE
            self.first_run = False
            self.register_buffer(
                "future_mask",
                baichuan_13b_gen_alibi_mask(tensor, self.n_head, self.max_cache_pos),
                persistent=False,
            )
        if seq_length_with_past > self.max_cache_pos:
            # When max_cache_pos is not enough for current sequence length,
            # increase by MASK_BLOCK_SIZE and recalculate future_mask.
            self.max_cache_pos = seq_length_with_past + MASK_BLOCK_SIZE
            self.register_buffer(
                "future_mask",
                baichuan_13b_gen_alibi_mask(tensor, self.n_head, self.max_cache_pos),
                persistent=False,
            )
        mask = self.future_mask[
            : self.n_head, :seq_length_with_past, :seq_length_with_past
        ]
    return mask
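
A minimal sketch of how the patched get_alibi_mask sizes and regrows the cached mask at inference time; DummyModel is a made-up stand-in exposing only the attributes the function touches, and the sequence lengths are illustrative:

import torch
from bigdl.llm.transformers.models.baichuan2 import baichuan_13b_get_alibi_mask

class DummyModel(torch.nn.Module):
    # Made-up stand-in for BaichuanModel, just enough state for get_alibi_mask.
    def __init__(self, n_head=40):
        super().__init__()
        self.n_head = n_head
        self.first_run = True
        self.max_cache_pos = 4096  # stock default; overridden on the first call

m = DummyModel().eval()
hidden = torch.zeros(1, 1, 5120)  # stands in for the hidden-states tensor

mask = baichuan_13b_get_alibi_mask(m, hidden, 1000)
print(m.max_cache_pos, mask.shape)  # 1064 torch.Size([40, 1000, 1000])

# Decoding past 1064 tokens triggers a single re-allocation of future_mask.
mask = baichuan_13b_get_alibi_mask(m, hidden, 1100)
print(m.max_cache_pos, mask.shape)  # 1164 torch.Size([40, 1100, 1100])
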