diff --git a/python/llm/src/bigdl/llm/transformers/convert.py b/python/llm/src/bigdl/llm/transformers/convert.py
index 2646b074..2f2b502d 100644
--- a/python/llm/src/bigdl/llm/transformers/convert.py
+++ b/python/llm/src/bigdl/llm/transformers/convert.py
@@ -399,6 +399,14 @@ def convert_forward(m, target_m, new_forward):
         convert_forward(sub_m, target_m, new_forward)
 
 
+def replace_func(m, target_m, func_name, new_func):
+    for _, sub_m in m.named_children():
+        if isinstance(sub_m, target_m):
+            bound_method = new_func.__get__(sub_m, sub_m.__class__)
+            setattr(sub_m, func_name, bound_method)
+        replace_func(sub_m, target_m, func_name, new_func)
+
+
 def _optimize_post(model, lightweight_bmm=False):
     from packaging import version
     from bigdl.llm.transformers.models.llama import llama_attention_forward_4_31
@@ -569,6 +577,7 @@ def _optimize_post(model, lightweight_bmm=False):
             from bigdl.llm.transformers.models.baichuan2 import baichuan_attention_forward_13b
             from bigdl.llm.transformers.models.baichuan2 import baichuan_13b_rms_norm_forward
             from bigdl.llm.transformers.models.baichuan2 import baichuan_mlp_forward
+            from bigdl.llm.transformers.models.baichuan2 import baichuan_13b_get_alibi_mask
             convert_forward(model,
                             module.BaichuanAttention,
                             baichuan_attention_forward_13b
@@ -580,6 +589,10 @@ def _optimize_post(model, lightweight_bmm=False):
             convert_forward(model,
                             module.MLP,
                             baichuan_mlp_forward)
+            replace_func(model,
+                         module.BaichuanModel,
+                         "get_alibi_mask",
+                         baichuan_13b_get_alibi_mask)
     elif model.config.model_type == "baichuan":
         # baichuan1
         if model.config.hidden_size == 4096:
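
replace_func generalizes the existing convert_forward helper: instead of always swapping forward, it rebinds an arbitrary method name ("get_alibi_mask" here) on every submodule of the target class. The key step is new_func.__get__(sub_m, sub_m.__class__), which turns a plain function into a method bound to that instance, so Python passes self automatically on each call. A minimal standalone sketch of the binding trick (the toy Inner class and patched_greet function are illustrative, not part of the patch):

import torch.nn as nn

class Inner(nn.Module):
    def greet(self):
        return "original"

def patched_greet(self):
    # `self` is the Inner instance the function was bound to
    return "patched"

root = nn.Sequential(Inner())
# Mirror replace_func's per-child check (recursion omitted for brevity)
for _, sub_m in root.named_children():
    if isinstance(sub_m, Inner):
        setattr(sub_m, "greet", patched_greet.__get__(sub_m, sub_m.__class__))

assert root[0].greet() == "patched"

Because the bound method is set on the instance, it shadows the class-level method only for matched submodules; unmatched modules keep the original implementation.
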
"""FP16-compatible function that fills a tensor with -inf.""" + return t.float().fill_(float("-inf")).type_as(t) + + +def _buffered_future_mask(tensor, maxpos, alibi, attn_heads): + _future_mask = torch.triu(_fill_with_neg_inf(torch.zeros([maxpos, maxpos])), 1) + _future_mask = _future_mask.unsqueeze(0) + alibi + new_future_mask = _future_mask.to(tensor) + return new_future_mask[: tensor.shape[0] * attn_heads, :maxpos, :maxpos] + + +def baichuan_13b_gen_alibi_mask(tensor, n_head, max_pos): + # May use fp16 for alibi mask to further reduce memory + slopes = torch.Tensor(_get_interleave(n_head)) # .half() + position_point = torch.arange(max_pos) - max_pos + 1 + position_point = position_point.unsqueeze(0).unsqueeze(0).expand(n_head, -1, -1) + diag = torch.diag(position_point[0]) + position_point = position_point - diag.unsqueeze(0).unsqueeze(0).transpose(-1, -2) + alibi = slopes.unsqueeze(1).unsqueeze(1) * position_point + alibi = alibi.view(n_head, 1, max_pos) + alibi_mask = torch.triu(_fill_with_neg_inf(torch.zeros([max_pos, max_pos])), 1) # .half() + alibi_mask = alibi_mask.unsqueeze(0) + alibi + if tensor.device.type == "xpu": + alibi_mask = alibi_mask.to(tensor.device) + return alibi_mask + + +MASK_BLOCK_SIZE = 64 + + +def baichuan_13b_get_alibi_mask(self, tensor, seq_length_with_past): + if self.training: + slopes = torch.Tensor(_get_interleave(self.n_head)) + position_point = ( + torch.arange(seq_length_with_past) - seq_length_with_past + 1 + ) + position_point = ( + position_point.unsqueeze(0) + .unsqueeze(0) + .expand(self.n_head, seq_length_with_past, -1) + ) + diag = torch.diag(position_point[0]) + position_point = position_point - diag.unsqueeze(0).unsqueeze(0).transpose( + -1, -2 + ) + alibi = slopes.unsqueeze(1).unsqueeze(1) * position_point + mask = _buffered_future_mask( + tensor, seq_length_with_past, alibi, self.n_head + ) + else: + if self.first_run: + # Override the default max_cache_pos=4096 for memory considerations + self.max_cache_pos = seq_length_with_past + MASK_BLOCK_SIZE + self.first_run = False + self.register_buffer( + "future_mask", + baichuan_13b_gen_alibi_mask(tensor, self.n_head, self.max_cache_pos), + persistent=False, + ) + if seq_length_with_past > self.max_cache_pos: + # When max_cache_pos is not enough for current sequence length, + # increase by MASK_BLOCK_SIZE and recalculate future_mask. + self.max_cache_pos = seq_length_with_past + MASK_BLOCK_SIZE + self.register_buffer( + "future_mask", + baichuan_13b_gen_alibi_mask(tensor, self.n_head, self.max_cache_pos), + persistent=False, + ) + mask = self.future_mask[ + : self.n_head, :seq_length_with_past, :seq_length_with_past + ] + return mask