Reduce max_cache_pos to reduce Baichuan2-13B memory (#9694)
* optimize baichuan2 memory * fix * style * fp16 mask * disable fp16 * fix style * empty cache * revert empty cache
This commit is contained in:
parent
361781bcd0
commit
689889482c
2 changed files with 107 additions and 7 deletions
|
|
@ -399,6 +399,14 @@ def convert_forward(m, target_m, new_forward):
|
||||||
convert_forward(sub_m, target_m, new_forward)
|
convert_forward(sub_m, target_m, new_forward)
|
||||||
|
|
||||||
|
|
||||||
|
def replace_func(m, target_m, func_name, new_func):
|
||||||
|
for _, sub_m in m.named_children():
|
||||||
|
if isinstance(sub_m, target_m):
|
||||||
|
bound_method = new_func.__get__(sub_m, sub_m.__class__)
|
||||||
|
setattr(sub_m, func_name, bound_method)
|
||||||
|
replace_func(sub_m, target_m, func_name, new_func)
|
||||||
|
|
||||||
|
|
||||||
def _optimize_post(model, lightweight_bmm=False):
|
def _optimize_post(model, lightweight_bmm=False):
|
||||||
from packaging import version
|
from packaging import version
|
||||||
from bigdl.llm.transformers.models.llama import llama_attention_forward_4_31
|
from bigdl.llm.transformers.models.llama import llama_attention_forward_4_31
|
||||||
|
|
@ -569,6 +577,7 @@ def _optimize_post(model, lightweight_bmm=False):
|
||||||
from bigdl.llm.transformers.models.baichuan2 import baichuan_attention_forward_13b
|
from bigdl.llm.transformers.models.baichuan2 import baichuan_attention_forward_13b
|
||||||
from bigdl.llm.transformers.models.baichuan2 import baichuan_13b_rms_norm_forward
|
from bigdl.llm.transformers.models.baichuan2 import baichuan_13b_rms_norm_forward
|
||||||
from bigdl.llm.transformers.models.baichuan2 import baichuan_mlp_forward
|
from bigdl.llm.transformers.models.baichuan2 import baichuan_mlp_forward
|
||||||
|
from bigdl.llm.transformers.models.baichuan2 import baichuan_13b_get_alibi_mask
|
||||||
convert_forward(model,
|
convert_forward(model,
|
||||||
module.BaichuanAttention,
|
module.BaichuanAttention,
|
||||||
baichuan_attention_forward_13b
|
baichuan_attention_forward_13b
|
||||||
|
|
@ -580,6 +589,10 @@ def _optimize_post(model, lightweight_bmm=False):
|
||||||
convert_forward(model,
|
convert_forward(model,
|
||||||
module.MLP,
|
module.MLP,
|
||||||
baichuan_mlp_forward)
|
baichuan_mlp_forward)
|
||||||
|
replace_func(model,
|
||||||
|
module.BaichuanModel,
|
||||||
|
"get_alibi_mask",
|
||||||
|
baichuan_13b_get_alibi_mask)
|
||||||
elif model.config.model_type == "baichuan":
|
elif model.config.model_type == "baichuan":
|
||||||
# baichuan1
|
# baichuan1
|
||||||
if model.config.hidden_size == 4096:
|
if model.config.hidden_size == 4096:
|
||||||
|
|
|
||||||
|
|
@ -19,19 +19,15 @@
|
||||||
# https://huggingface.co/baichuan-inc/Baichuan2-13B-Chat/blob/c6f8592a60b4ad73c210b28dd2ab3cca51abbf93/modeling_baichuan.py
|
# https://huggingface.co/baichuan-inc/Baichuan2-13B-Chat/blob/c6f8592a60b4ad73c210b28dd2ab3cca51abbf93/modeling_baichuan.py
|
||||||
|
|
||||||
import math
|
import math
|
||||||
from typing import List, Optional, Tuple, Union
|
from typing import Optional, Tuple
|
||||||
import torch
|
import torch
|
||||||
import torch.utils.checkpoint
|
import torch.utils.checkpoint
|
||||||
from torch import nn
|
|
||||||
from torch.nn import functional as F
|
from torch.nn import functional as F
|
||||||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
|
||||||
from bigdl.llm.utils.common import invalidInputError
|
|
||||||
from bigdl.llm.ggml.quantize import ggml_tensor_qtype
|
from bigdl.llm.ggml.quantize import ggml_tensor_qtype
|
||||||
from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
|
from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
|
||||||
from bigdl.llm.transformers.models.utils import rotate_half, apply_rotary_pos_emb
|
from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb
|
||||||
from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
|
from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
|
||||||
from transformers.utils import logging, ContextManagers
|
from transformers.utils import logging
|
||||||
from bigdl.llm.transformers.models.llama import get_ipex_version
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
@ -301,3 +297,94 @@ def baichuan_attention_forward_13b(
|
||||||
attn_weights = None
|
attn_weights = None
|
||||||
|
|
||||||
return attn_output, attn_weights, past_key_value
|
return attn_output, attn_weights, past_key_value
|
||||||
|
|
||||||
|
|
||||||
|
def _get_interleave(n):
|
||||||
|
def _get_interleave_power_of_2(n):
|
||||||
|
start = 2 ** (-(2 ** -(math.log2(n) - 3)))
|
||||||
|
ratio = start
|
||||||
|
return [start * ratio**i for i in range(n)]
|
||||||
|
|
||||||
|
if math.log2(n).is_integer():
|
||||||
|
return _get_interleave_power_of_2(n)
|
||||||
|
else:
|
||||||
|
closest_power_of_2 = 2 ** math.floor(math.log2(n))
|
||||||
|
return (
|
||||||
|
_get_interleave_power_of_2(closest_power_of_2)
|
||||||
|
+ _get_interleave(2 * closest_power_of_2)[0::2][: n - closest_power_of_2]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _fill_with_neg_inf(t):
|
||||||
|
"""FP16-compatible function that fills a tensor with -inf."""
|
||||||
|
return t.float().fill_(float("-inf")).type_as(t)
|
||||||
|
|
||||||
|
|
||||||
|
def _buffered_future_mask(tensor, maxpos, alibi, attn_heads):
|
||||||
|
_future_mask = torch.triu(_fill_with_neg_inf(torch.zeros([maxpos, maxpos])), 1)
|
||||||
|
_future_mask = _future_mask.unsqueeze(0) + alibi
|
||||||
|
new_future_mask = _future_mask.to(tensor)
|
||||||
|
return new_future_mask[: tensor.shape[0] * attn_heads, :maxpos, :maxpos]
|
||||||
|
|
||||||
|
|
||||||
|
def baichuan_13b_gen_alibi_mask(tensor, n_head, max_pos):
|
||||||
|
# May use fp16 for alibi mask to further reduce memory
|
||||||
|
slopes = torch.Tensor(_get_interleave(n_head)) # .half()
|
||||||
|
position_point = torch.arange(max_pos) - max_pos + 1
|
||||||
|
position_point = position_point.unsqueeze(0).unsqueeze(0).expand(n_head, -1, -1)
|
||||||
|
diag = torch.diag(position_point[0])
|
||||||
|
position_point = position_point - diag.unsqueeze(0).unsqueeze(0).transpose(-1, -2)
|
||||||
|
alibi = slopes.unsqueeze(1).unsqueeze(1) * position_point
|
||||||
|
alibi = alibi.view(n_head, 1, max_pos)
|
||||||
|
alibi_mask = torch.triu(_fill_with_neg_inf(torch.zeros([max_pos, max_pos])), 1) # .half()
|
||||||
|
alibi_mask = alibi_mask.unsqueeze(0) + alibi
|
||||||
|
if tensor.device.type == "xpu":
|
||||||
|
alibi_mask = alibi_mask.to(tensor.device)
|
||||||
|
return alibi_mask
|
||||||
|
|
||||||
|
|
||||||
|
MASK_BLOCK_SIZE = 64
|
||||||
|
|
||||||
|
|
||||||
|
def baichuan_13b_get_alibi_mask(self, tensor, seq_length_with_past):
|
||||||
|
if self.training:
|
||||||
|
slopes = torch.Tensor(_get_interleave(self.n_head))
|
||||||
|
position_point = (
|
||||||
|
torch.arange(seq_length_with_past) - seq_length_with_past + 1
|
||||||
|
)
|
||||||
|
position_point = (
|
||||||
|
position_point.unsqueeze(0)
|
||||||
|
.unsqueeze(0)
|
||||||
|
.expand(self.n_head, seq_length_with_past, -1)
|
||||||
|
)
|
||||||
|
diag = torch.diag(position_point[0])
|
||||||
|
position_point = position_point - diag.unsqueeze(0).unsqueeze(0).transpose(
|
||||||
|
-1, -2
|
||||||
|
)
|
||||||
|
alibi = slopes.unsqueeze(1).unsqueeze(1) * position_point
|
||||||
|
mask = _buffered_future_mask(
|
||||||
|
tensor, seq_length_with_past, alibi, self.n_head
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
if self.first_run:
|
||||||
|
# Override the default max_cache_pos=4096 for memory considerations
|
||||||
|
self.max_cache_pos = seq_length_with_past + MASK_BLOCK_SIZE
|
||||||
|
self.first_run = False
|
||||||
|
self.register_buffer(
|
||||||
|
"future_mask",
|
||||||
|
baichuan_13b_gen_alibi_mask(tensor, self.n_head, self.max_cache_pos),
|
||||||
|
persistent=False,
|
||||||
|
)
|
||||||
|
if seq_length_with_past > self.max_cache_pos:
|
||||||
|
# When max_cache_pos is not enough for current sequence length,
|
||||||
|
# increase by MASK_BLOCK_SIZE and recalculate future_mask.
|
||||||
|
self.max_cache_pos = seq_length_with_past + MASK_BLOCK_SIZE
|
||||||
|
self.register_buffer(
|
||||||
|
"future_mask",
|
||||||
|
baichuan_13b_gen_alibi_mask(tensor, self.n_head, self.max_cache_pos),
|
||||||
|
persistent=False,
|
||||||
|
)
|
||||||
|
mask = self.future_mask[
|
||||||
|
: self.n_head, :seq_length_with_past, :seq_length_with_past
|
||||||
|
]
|
||||||
|
return mask
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue