diff --git a/python/llm/src/ipex_llm/transformers/convert.py b/python/llm/src/ipex_llm/transformers/convert.py index 3ad0bcd7..29e7c7cd 100644 --- a/python/llm/src/ipex_llm/transformers/convert.py +++ b/python/llm/src/ipex_llm/transformers/convert.py @@ -680,6 +680,11 @@ def _optimize_pre(model): if model.lm_head.weight.data.device != "meta": norm_weight = nn.functional.normalize(lm_head_weight_data) model.lm_head.weight.data = norm_weight + + # for baichuan2-7B + if model.config.hidden_size in [4096, 2048]: + from ipex_llm.transformers.models.baichuan import pre_compute_inv_freq + model.apply(pre_compute_inv_freq) # for yuan 2.0 if model.config.model_type == "yuan": from ipex_llm.transformers.models.yuan import merge_qk @@ -703,12 +708,6 @@ def _optimize_pre(model): model.apply(pre_compute_inv_freq) from ipex_llm.transformers.models.phi3 import split_mlp model.apply(split_mlp) - # for baichuan2 - if model.config.model_type == "baichuan" and model.config.vocab_size == 125696: - if model.config.hidden_size in [4096, 2048]: - # baichuan2-7B - from ipex_llm.transformers.models.baichuan2 import pre_compute_inv_freq - model.apply(pre_compute_inv_freq) # for qwen2 if model.config.model_type == "qwen2": from ipex_llm.transformers.models.qwen2 import merge_qkv @@ -1125,84 +1124,39 @@ def _optimize_post(model, lightweight_bmm=False): module.FalconAttention, falcon_attention_forward ) - - elif model.config.model_type == "baichuan" and model.config.vocab_size == 125696: - # baichuan2 - if model.config.hidden_size in [4096, 2048]: - # baichuan2-7B - modeling_module_name = model.__class__.__module__ - module = importlib.import_module(modeling_module_name) - from ipex_llm.transformers.models.baichuan2 import baichuan_attention_forward_7b - from ipex_llm.transformers.models.baichuan2 import baichuan_mlp_forward - convert_forward(model, - module.Attention, - baichuan_attention_forward_7b - ) - convert_forward(model, - module.RMSNorm, - llama_rms_norm_forward) - convert_forward(model, - module.MLP, - baichuan_mlp_forward) - elif model.config.hidden_size == 5120: - # baichuan2-13B - modeling_module_name = model.__class__.__module__ - module = importlib.import_module(modeling_module_name) - from ipex_llm.transformers.models.baichuan2 import baichuan_attention_forward_13b - from ipex_llm.transformers.models.baichuan2 import baichuan_13b_rms_norm_forward - from ipex_llm.transformers.models.baichuan2 import baichuan_mlp_forward - from ipex_llm.transformers.models.baichuan2 import baichuan_13b_get_alibi_mask - convert_forward(model, - module.BaichuanAttention, - baichuan_attention_forward_13b - ) - # baichuan2-13B's RMSNorm is a little different - convert_forward(model, - module.RMSNorm, - baichuan_13b_rms_norm_forward) - convert_forward(model, - module.MLP, - baichuan_mlp_forward) - if hasattr(model.model, 'get_alibi_mask_orig'): - # deepspeed rewrite "get_alibi_mask" to support baichuan - # https://github.com/microsoft/DeepSpeed/pull/4721 - replace_func(model, - module.BaichuanModel, - "get_alibi_mask_orig", - baichuan_13b_get_alibi_mask) - else: - replace_func(model, - module.BaichuanModel, - "get_alibi_mask", - baichuan_13b_get_alibi_mask) elif model.config.model_type == "baichuan": - # baichuan1 - if model.config.hidden_size == 4096: - # baichuan-7B - modeling_module_name = model.__class__.__module__ - module = importlib.import_module(modeling_module_name) + modeling_module_name = model.__class__.__module__ + module = importlib.import_module(modeling_module_name) + from ipex_llm.transformers.models.baichuan 
import baichuan_mlp_forward + convert_forward(model, module.MLP, baichuan_mlp_forward) + + if model.config.hidden_size in [4096, 2048]: + # baichuan-7B and baichuan2-7B from ipex_llm.transformers.models.baichuan import baichuan_attention_forward_7b - convert_forward(model, - module.Attention, - baichuan_attention_forward_7b - ) - convert_forward(model, - module.RMSNorm, - llama_rms_norm_forward) + convert_forward(model, module.Attention, baichuan_attention_forward_7b) + convert_forward(model, module.RMSNorm, llama_rms_norm_forward) elif model.config.hidden_size == 5120: - # baichuan-13B - modeling_module_name = model.__class__.__module__ - module = importlib.import_module(modeling_module_name) + # baichuan-13B and baichuan2-13B from ipex_llm.transformers.models.baichuan import baichuan_attention_forward_13b - from ipex_llm.transformers.models.baichuan2 import baichuan_13b_rms_norm_forward - convert_forward(model, - module.BaichuanAttention, - baichuan_attention_forward_13b - ) - # baichuan-13B's RMSNorm is a little different - convert_forward(model, - module.RMSNorm, - baichuan_13b_rms_norm_forward) + from ipex_llm.transformers.models.baichuan import baichuan_13b_rms_norm_forward + convert_forward(model, module.BaichuanAttention, baichuan_attention_forward_13b) + convert_forward(model, module.RMSNorm, baichuan_13b_rms_norm_forward) + + if model.config.vocab_size == 125696: + # baichaun2-13B + from ipex_llm.transformers.models.baichuan import baichuan_13b_get_alibi_mask + if hasattr(model.model, 'get_alibi_mask_orig'): + # deepspeed rewrite "get_alibi_mask" to support baichuan + # https://github.com/microsoft/DeepSpeed/pull/4721 + replace_func(model, + module.BaichuanModel, + "get_alibi_mask_orig", + baichuan_13b_get_alibi_mask) + else: + replace_func(model, + module.BaichuanModel, + "get_alibi_mask", + baichuan_13b_get_alibi_mask) elif model.config.model_type == "gpt_neox": from ipex_llm.transformers.models.gptneox import gptneox_attention_forward convert_forward(model, diff --git a/python/llm/src/ipex_llm/transformers/models/baichuan.py b/python/llm/src/ipex_llm/transformers/models/baichuan.py index 8bcdb637..c74e9754 100644 --- a/python/llm/src/ipex_llm/transformers/models/baichuan.py +++ b/python/llm/src/ipex_llm/transformers/models/baichuan.py @@ -14,30 +14,61 @@ # limitations under the License. 
# This file is adapted from -# https://huggingface.co/baichuan-inc/Baichuan-7B/blob/c1a5c7d5b7f50ecc51bb0e08150a9f12e5656756/modeling_baichuan.py +# https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/cb7fc748b78b7ea99772e4cf76db155729ce774e/modeling_baichuan.py # and -# https://huggingface.co/baichuan-inc/Baichuan-13B-Chat/blob/a4a558127068f2ce965aa56aeb826bf501a68970/modeling_baichuan.py - +# https://huggingface.co/baichuan-inc/Baichuan2-13B-Chat/blob/c6f8592a60b4ad73c210b28dd2ab3cca51abbf93/modeling_baichuan.py import math -from typing import List, Optional, Tuple, Union +from typing import Optional, Tuple import torch import torch.utils.checkpoint -from torch import nn -import torch.nn.functional as F -from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss -from ipex_llm.utils.common import invalidInputError -from ipex_llm.transformers.models.utils import use_flash_attention, use_sdp -from ipex_llm.transformers.models.utils import init_kv_cache, extend_kv_cache, \ - append_kv_cache, is_enough_kv_cache_room_4_31 -from ipex_llm.transformers.models.utils import init_fp8_kv_cache, append_fp8_kv_cache, \ - restore_fp8_kv_cache, use_quantize_kv_cache -from ipex_llm.transformers.models.utils import rotate_half, apply_rotary_pos_emb -from ipex_llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu +from torch.nn import functional as F +from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache +from ipex_llm.transformers.models.utils import update_past_key_value +from ipex_llm.transformers.models.utils import should_use_fuse_rope +from ipex_llm.transformers.models.utils import use_flash_attention, use_sdp, use_sdp_causal +from ipex_llm.transformers.models.utils import apply_rotary_pos_emb, SILU +from ipex_llm.transformers.models.utils import mlp_fusion_check +import warnings -import os -KV_CACHE_ALLOC_BLOCK_LENGTH = int(os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256)) +def pre_compute_inv_freq(module: torch.nn.Module): + if module.__class__.__name__ == "RotaryEmbedding": + inv_freq = module.inv_freq + del module.inv_freq + module.register_buffer("inv_freq", inv_freq, persistent=False) + + +def baichuan_13b_rms_norm_forward(self, hidden_states): + if hidden_states.device.type == "xpu" and not (self.training or hidden_states.requires_grad): + import xe_addons + x_2d = hidden_states.reshape(-1, hidden_states.size(-1)).contiguous() + output = xe_addons.rms_norm(self.weight, x_2d, self.epsilon) + return output.reshape(hidden_states.shape) + + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.epsilon) + return self.weight * hidden_states.to(input_dtype) + + +def baichuan_mlp_forward( + self, + x: torch.Tensor, +) -> torch.Tensor: + x_2d = x.view(-1, x.shape[-1]) + qtype = getattr(self.gate_proj, "qtype", None) + if mlp_fusion_check(x_2d, qtype, self.training) and not self.down_proj.enable_xetla: + import xe_linear + if not x_2d.is_contiguous(): + x_2d = x_2d.contiguous() + return self.down_proj(xe_linear.mlp_forward_xpu( + x_2d, self.gate_proj.weight.data, self.up_proj.weight.data, + x_2d.shape[0], x_2d.shape[1], self.gate_proj.out_len, + SILU, qtype + )) + return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) def baichuan_attention_forward_7b( @@ -48,269 +79,82 @@ def baichuan_attention_forward_7b( past_key_value: Optional[Tuple[torch.Tensor]] = None, 
output_attentions: bool = False, use_cache: bool = False, -) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - if use_quantize_kv_cache(self.W_pack, hidden_states): - forward_function = baichuan_attention_forward_7b_quantized - else: - forward_function = baichuan_attention_forward_7b_origin - return forward_function( - self=self, - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache - ) - - -def baichuan_attention_forward_7b_quantized( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: bool = False, - use_cache: bool = False, -) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: +): bsz, q_len, _ = hidden_states.size() device = hidden_states.device - proj = self.W_pack(hidden_states) - proj = proj.unflatten(-1, (3, self.hidden_size)).unsqueeze(0).transpose(0, -2).squeeze(-2) - # batch_size x source_len x hidden_size - query_states = proj[0].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - # batch_size x target_len x head_size - key_states = proj[1].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - # batch_size x source_len x hidden_size - value_states = proj[2].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + qkv = self.W_pack(hidden_states) + qkv = qkv.view(bsz, q_len, self.num_heads * 3, self.head_dim) + qkv = qkv.transpose(1, 2) + query_states, key_states, value_states = qkv.split([self.num_heads, + self.num_heads, + self.num_heads], dim=1) - kv_seq_len = key_states.shape[-2] + kv_seq_len = key_states.shape[2] if past_key_value is not None: - kv_seq_len += past_key_value[0].shape[-2] - if query_states.device.type == "xpu" and not (self.training and query_states.requires_grad): - query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states, - key_states, - position_ids, - "baichuan") + kv_seq_len += past_key_value[0].shape[2] + + # IPEX-LLM OPT: fuse rope + if should_use_fuse_rope(hidden_states, position_ids, self.training): + import xe_addons + xe_addons.rotary_half_inplaced(self.rotary_emb.inv_freq, position_ids, + query_states, key_states) else: cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids, "baichuan") - # [bsz, nh, t, hd] + query_states = query_states.to(hidden_states.dtype) + key_states = key_states.to(hidden_states.dtype) - if past_key_value is None: - attn_weights = torch.matmul(query_states, - key_states.transpose(2, 3)) / math.sqrt(self.head_dim) - if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): - invalidInputError( - False, - f"Attention weights should be of size " - f"{(bsz, self.num_heads, q_len, kv_seq_len)}" - f", but is {attn_weights.size()}" - ) - - if attention_mask is not None: - invalidInputError( - attention_mask.size() == (bsz, 1, q_len, kv_seq_len), - f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, " - f"but is {attention_mask.size()}" - ) - attn_weights = attn_weights + attention_mask - attn_weights = torch.max(attn_weights, - torch.tensor(torch.finfo(attn_weights.dtype).min)) - - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, - 
dtype=torch.float32).to(query_states.dtype) - attn_output = torch.matmul(attn_weights, value_states) - kv_seq_len = key_states.shape[-2] - if use_cache: - k_cache, v_cache = init_fp8_kv_cache( - bsz, self.num_heads, kv_seq_len, self.head_dim, - device=device - ) - key_states, value_states = append_fp8_kv_cache(k_cache, v_cache, key_states, - value_states) - past_key_value = (key_states, value_states) - else: - k_cache, v_cache = past_key_value - key_states, value_states = append_fp8_kv_cache(k_cache, v_cache, - key_states, value_states) - kv_seq_len = key_states.shape[-2] - past_key_value = (key_states, value_states) - if query_states.size(2) != 1 or query_states.device.type != 'xpu': - key_states, value_states = restore_fp8_kv_cache(key_states, value_states, - query_states.dtype) - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) - attn_weights = attn_weights / math.sqrt(self.head_dim) - - if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): - invalidInputError( - False, - f"Attention weights should be of size " - f"{(bsz, self.num_heads, q_len, kv_seq_len)}" - f", but is {attn_weights.size()}" - ) - - if attention_mask is not None: - invalidInputError( - attention_mask.size() == (bsz, 1, q_len, kv_seq_len), - f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, " - f"but is {attention_mask.size()}" - ) - attn_weights = attn_weights + attention_mask - attn_weights = torch.max(attn_weights, - torch.tensor(torch.finfo(attn_weights.dtype).min)) - - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, - dtype=torch.float32).to(query_states.dtype) - attn_output = torch.matmul(attn_weights, value_states) - else: - import xe_addons - attn_output = xe_addons.sdp_fp8(query_states, key_states, value_states, - attention_mask) - attn_weights = None - - invalidInputError( - attn_output.size() == (bsz, self.num_heads, q_len, self.head_dim), - f"`attn_output` should be of size " - f"{(bsz, self.num_heads, q_len, self.head_dim)}," - f"but is {attn_output.size()}" + # IPEX-LLM OPT: kv cache and quantize kv + use_quantize_kv = use_quantize_kv_cache(self.W_pack, hidden_states) + key_states, value_states = update_past_key_value( + past_key_value, key_states, value_states, + kv_seq_len, use_quantize_kv, device ) - - attn_output = attn_output.transpose(1, 2) - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) - - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output.to(hidden_states.dtype), attn_weights, past_key_value - - -def baichuan_attention_forward_7b_origin( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: bool = False, - use_cache: bool = False, -) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - bsz, q_len, _ = hidden_states.size() - device = hidden_states.device - - proj = self.W_pack(hidden_states) - proj = proj.unflatten(-1, (3, self.hidden_size)).unsqueeze(0).transpose(0, -2).squeeze(-2) - # batch_size x source_len x hidden_size - query_states = proj[0].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - # batch_size x target_len x head_size - key_states = proj[1].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - # batch_size x source_len x hidden_size - value_states = proj[2].view(bsz, q_len, self.num_heads, 
self.head_dim).transpose(1, 2) - - kv_seq_len = key_states.shape[-2] - enough_kv_room = True - if past_key_value is not None: - enough_kv_room = is_enough_kv_cache_room_4_31(past_key_value, seq_len=kv_seq_len) - kv_seq_len += past_key_value[0].shape[-2] - if query_states.device.type == "xpu" and not (self.training and query_states.requires_grad): - query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states, - key_states, - position_ids, - "baichuan") - else: - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, - cos, sin, position_ids, "baichuan") - # [bsz, nh, t, hd] - - # if past_key_value is not None: - # # reuse k, v, self_attention - # key_states = torch.cat([past_key_value[0], key_states], dim=2) - # value_states = torch.cat([past_key_value[1], value_states], dim=2) - if past_key_value is not None: - # reuse k, v, self_attention - cache_k = past_key_value[0] - cache_v = past_key_value[1] - if not enough_kv_room: - # allocate new - new_cache_k, new_cache_v = extend_kv_cache(bsz, - self.num_heads, - self.head_dim, - cache_k.size(2), - kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH, - dtype=cache_k.dtype, - device=device) - new_cache_k[:] = cache_k - new_cache_v[:] = cache_v - cache_k = new_cache_k - cache_v = new_cache_v - - key_states, value_states = append_kv_cache(cache_k, cache_v, key_states, value_states) - - elif use_cache: - max_cache_length = kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH - new_key_states, new_value_states = init_kv_cache(bsz, - self.num_heads, - self.head_dim, - kv_seq_len, - max_cache_length, - dtype=key_states.dtype, - device=device) - new_key_states[:] = key_states - new_value_states[:] = value_states - key_states = new_key_states - value_states = new_value_states - past_key_value = (key_states, value_states) if use_cache else None + if self.training: + warnings.warn("xops is not supported on Intel GPU, so just use normal implementation") + + # IPEX-LLM OPT: sdp + attn_weights = None if not self.training and not hidden_states.requires_grad and \ use_flash_attention(query_states, key_states, attention_mask): - attn_output = F.scaled_dot_product_attention(query_states.to(device, dtype=torch.float16), - key_states.to(device, dtype=torch.float16), - value_states.to(device, dtype=torch.float16), - is_causal=True) - attn_weights = None - elif not self.training and not hidden_states.requires_grad and \ - use_sdp(q_len, key_states.shape[2], self.head_dim, query_states): + attn_output = F.scaled_dot_product_attention(query_states.to(dtype=torch.float16), + key_states.to(dtype=torch.float16), + value_states.to(dtype=torch.float16), + is_causal=True).to(hidden_states.dtype) + elif use_sdp(q_len, kv_seq_len, self.head_dim, query_states): import xe_addons - attn_output = xe_addons.sdp(query_states, key_states, value_states, - attention_mask) - attn_output = attn_output.view(query_states.shape) - attn_weights = None + if use_quantize_kv: + attn_output = xe_addons.sdp_fp8(query_states, key_states, value_states, + attention_mask) + else: + attn_output = xe_addons.sdp(query_states, key_states, value_states, + attention_mask) + elif use_sdp_causal(q_len, kv_seq_len, self.head_dim, query_states, self.training): + import xe_addons + if use_quantize_kv: + attn_output = xe_addons.sdp_fp8_causal(query_states, key_states, + value_states, attention_mask) + else: + attn_output = xe_addons.sdp_causal(query_states, key_states, + value_states, attention_mask) else: + if use_quantize_kv: + 
key_states, value_states = restore_fp8_kv_cache(key_states, value_states, + query_states.dtype) attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) - - if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): - invalidInputError(False, - f"Attention weights should be of size " - f"{(bsz, self.num_heads, q_len, kv_seq_len)}" - f", but is {attn_weights.size()}") - if attention_mask is not None: - invalidInputError(attention_mask.size() == (bsz, 1, q_len, kv_seq_len), - f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, " - f"but is {attention_mask.size()}") attn_weights = attn_weights + attention_mask - attn_weights = torch.max(attn_weights, - torch.tensor(torch.finfo(attn_weights.dtype).min)) - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, - dtype=torch.float32).to(query_states.dtype) + attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, + dtype=torch.float32).to(value_states.dtype) attn_output = torch.matmul(attn_weights, value_states) - invalidInputError(attn_output.size() == (bsz, self.num_heads, q_len, self.head_dim), - f"`attn_output` should be of size " - f"{(bsz, self.num_heads, q_len, self.head_dim)}," - f"but is {attn_output.size()}") - - attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) attn_output = self.o_proj(attn_output) @@ -318,7 +162,7 @@ def baichuan_attention_forward_7b_origin( if not output_attentions: attn_weights = None - return attn_output.to(hidden_states.dtype), attn_weights, past_key_value + return attn_output, attn_weights, past_key_value def baichuan_attention_forward_13b( @@ -329,101 +173,57 @@ def baichuan_attention_forward_13b( output_attentions: bool = False, use_cache: bool = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - if use_quantize_kv_cache(self.W_pack, hidden_states): - forward_function = baichuan_attention_forward_13b_quantized - else: - forward_function = baichuan_attention_forward_13b_origin - return forward_function( - self=self, - hidden_states=hidden_states, - attention_mask=attention_mask, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache - ) - - -def baichuan_attention_forward_13b_quantized( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: bool = False, - use_cache: bool = False, -) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - bsz, q_len, _ = hidden_states.size() device = hidden_states.device - proj = self.W_pack(hidden_states) - proj = proj.unflatten(-1, (3, self.hidden_size)).unsqueeze(0).transpose(0, -2).squeeze(-2) - query_states = proj[0].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = proj[1].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - value_states = proj[2].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + qkv = self.W_pack(hidden_states) + qkv = qkv.view(bsz, q_len, self.num_heads * 3, self.head_dim) + qkv = qkv.transpose(1, 2) + query_states, key_states, value_states = qkv.split([self.num_heads, + self.num_heads, + self.num_heads], dim=1) - kv_seq_len = key_states.shape[-2] + kv_seq_len = key_states.shape[2] if past_key_value is not None: - kv_seq_len += past_key_value[0].shape[-2] + kv_seq_len += 
past_key_value[0].shape[2] - if past_key_value is None: + # IPEX-LLM OPT: kv cache and quantize kv + use_quantize_kv = use_quantize_kv_cache(self.W_pack, hidden_states) + key_states, value_states = update_past_key_value( + past_key_value, key_states, value_states, + kv_seq_len, use_quantize_kv, device + ) + past_key_value = (key_states, value_states) if use_cache else None + + if self.training: + warnings.warn("xops is not supported on Intel GPU, so just use normal implementation") + + if attention_mask is not None: + if len(attention_mask.size()) == 4: + attention_mask = attention_mask[:, :, -q_len:, :] + else: + attention_mask = attention_mask[None, :, -q_len:, :] + + if use_sdp(q_len, kv_seq_len, self.head_dim, query_states): + import xe_addons + if use_quantize_kv: + attn_output = xe_addons.sdp_fp8(query_states, key_states, value_states, + attention_mask) + else: + attn_output = xe_addons.sdp(query_states, key_states, value_states, + attention_mask) + attn_weights = None + else: + if use_quantize_kv: + key_states, value_states = restore_fp8_kv_cache(key_states, value_states, + query_states.dtype) attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) if attention_mask is not None: - if q_len == 1: # inference with cache - if len(attention_mask.size()) == 4: - attention_mask = attention_mask[:, :, -1:, :] - else: - attention_mask = attention_mask[:, -1:, :] attn_weights = attn_weights + attention_mask - attn_weights = torch.max(attn_weights, - torch.tensor(torch.finfo(attn_weights.dtype).min)) - + attn_weights = attn_weights.to(query_states.dtype) attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1) - - attn_output = torch.matmul(attn_weights, value_states) - kv_seq_len = key_states.shape[-2] - if use_cache: - k_cache, v_cache = init_fp8_kv_cache( - bsz, self.num_heads, kv_seq_len, self.head_dim, - device=device - ) - key_states, value_states = append_fp8_kv_cache(k_cache, v_cache, - key_states, value_states) - past_key_value = (key_states, value_states) - else: - k_cache, v_cache = past_key_value - key_states, value_states = append_fp8_kv_cache(k_cache, v_cache, - key_states, value_states) - kv_seq_len = key_states.shape[-2] - past_key_value = (key_states, value_states) - if query_states.size(2) != 1 or query_states.device.type != 'xpu': - key_states, value_states = restore_fp8_kv_cache(key_states, value_states, - query_states.dtype) - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) - else: - import xe_addons - attn_weights = xe_addons.query_key_fp8_matmul(query_states, key_states) - - attn_weights = attn_weights / math.sqrt(self.head_dim) - - if attention_mask is not None: - if q_len == 1: # inference with cache - if len(attention_mask.size()) == 4: - attention_mask = attention_mask[:, :, -1:, :] - else: - attention_mask = attention_mask[:, -1:, :] - attn_weights = attn_weights + attention_mask - attn_weights = torch.max(attn_weights, - torch.tensor(torch.finfo(attn_weights.dtype).min)) - - attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1) - if query_states.size(2) != 1 or query_states.device.type != 'xpu': - attn_output = torch.matmul(attn_weights, value_states) - else: - import xe_addons - attn_output = xe_addons.attn_value_fp8_matmul(attn_weights, - value_states) - + attn_output = torch.matmul(attn_weights.to(dtype=value_states.dtype), value_states) attn_output = attn_output.transpose(1, 2) attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) attn_output = self.o_proj(attn_output) 
@@ -434,90 +234,92 @@ def baichuan_attention_forward_13b_quantized( return attn_output, attn_weights, past_key_value -def baichuan_attention_forward_13b_origin( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: bool = False, - use_cache: bool = False, -) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: +def _get_interleave(n): + def _get_interleave_power_of_2(n): + start = 2 ** (-(2 ** -(math.log2(n) - 3))) + ratio = start + return [start * ratio**i for i in range(n)] - bsz, q_len, _ = hidden_states.size() - device = hidden_states.device + if math.log2(n).is_integer(): + return _get_interleave_power_of_2(n) + else: + closest_power_of_2 = 2 ** math.floor(math.log2(n)) + return ( + _get_interleave_power_of_2(closest_power_of_2) + + _get_interleave(2 * closest_power_of_2)[0::2][: n - closest_power_of_2] + ) - proj = self.W_pack(hidden_states) - proj = proj.unflatten(-1, (3, self.hidden_size)).unsqueeze(0).transpose(0, -2).squeeze(-2) - query_states = proj[0].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = proj[1].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - value_states = proj[2].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - kv_seq_len = key_states.shape[-2] - enough_kv_room = True - if past_key_value is not None: - enough_kv_room = is_enough_kv_cache_room_4_31(past_key_value, seq_len=kv_seq_len) - kv_seq_len += past_key_value[0].shape[-2] +def _fill_with_neg_inf(t): + """FP16-compatible function that fills a tensor with -inf.""" + return t.float().fill_(float("-inf")).type_as(t) - # if past_key_value is not None: - # # reuse k, v, self_attention - # key_states = torch.cat([past_key_value[0], key_states], dim=2) - # value_states = torch.cat([past_key_value[1], value_states], dim=2) - if past_key_value is not None: - # reuse k, v, self_attention - cache_k = past_key_value[0] - cache_v = past_key_value[1] - if not enough_kv_room: - # allocate new - new_cache_k, new_cache_v = extend_kv_cache(bsz, - self.num_heads, - self.head_dim, - cache_k.size(2), - kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH, - dtype=cache_k.dtype, - device=device) - new_cache_k[:] = cache_k - new_cache_v[:] = cache_v - cache_k = new_cache_k - cache_v = new_cache_v - key_states, value_states = append_kv_cache(cache_k, cache_v, key_states, value_states) +def _buffered_future_mask(tensor, maxpos, alibi, attn_heads): + _future_mask = torch.triu(_fill_with_neg_inf(torch.zeros([maxpos, maxpos])), 1) + _future_mask = _future_mask.unsqueeze(0) + alibi + new_future_mask = _future_mask.to(tensor) + return new_future_mask[: tensor.shape[0] * attn_heads, :maxpos, :maxpos] - elif use_cache: - max_cache_length = kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH - new_key_states, new_value_states = init_kv_cache(bsz, - self.num_heads, - self.head_dim, - kv_seq_len, - max_cache_length, - dtype=key_states.dtype, - device=device) - new_key_states[:] = key_states - new_value_states[:] = value_states - key_states = new_key_states - value_states = new_value_states - past_key_value = (key_states, value_states) if use_cache else None +def baichuan_13b_gen_alibi_mask(tensor, n_head, max_pos): + slopes = torch.Tensor(_get_interleave(n_head)).to(tensor.dtype) + position_point = torch.arange(max_pos) - max_pos + 1 + position_point = position_point.unsqueeze(0).unsqueeze(0).expand(n_head, -1, -1) + diag = torch.diag(position_point[0]) + position_point 
= position_point - diag.unsqueeze(0).unsqueeze(0).transpose(-1, -2) + alibi = slopes.unsqueeze(1).unsqueeze(1) * position_point + alibi = alibi.view(n_head, 1, max_pos) + alibi_mask = torch.triu( + _fill_with_neg_inf(torch.zeros([max_pos, max_pos])), 1).to(tensor.dtype) + alibi_mask = alibi_mask.unsqueeze(0) + alibi + if tensor.device.type == "xpu": + alibi_mask = alibi_mask.to(tensor.device) + return alibi_mask - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) - if attention_mask is not None: - if q_len == 1: # inference with cache - if len(attention_mask.size()) == 4: - attention_mask = attention_mask[:, :, -1:, :] - else: - attention_mask = attention_mask[:, -1:, :] - attn_weights = attn_weights + attention_mask - attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)) +MASK_BLOCK_SIZE = 512 - attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1) - attn_output = torch.matmul(attn_weights, value_states) - - attn_output = attn_output.transpose(1, 2) - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output.to(hidden_states.dtype), attn_weights, past_key_value +def baichuan_13b_get_alibi_mask(self, tensor, seq_length_with_past): + if self.training: + slopes = torch.Tensor(_get_interleave(self.n_head)) + position_point = ( + torch.arange(seq_length_with_past) - seq_length_with_past + 1 + ) + position_point = ( + position_point.unsqueeze(0) + .unsqueeze(0) + .expand(self.n_head, seq_length_with_past, -1) + ) + diag = torch.diag(position_point[0]) + position_point = position_point - diag.unsqueeze(0).unsqueeze(0).transpose( + -1, -2 + ) + alibi = slopes.unsqueeze(1).unsqueeze(1) * position_point + mask = _buffered_future_mask( + tensor, seq_length_with_past, alibi, self.n_head + ) + else: + if self.first_run: + # Override the default max_cache_pos=4096 for memory considerations + self.max_cache_pos = seq_length_with_past + MASK_BLOCK_SIZE + self.first_run = False + self.register_buffer( + "future_mask", + baichuan_13b_gen_alibi_mask(tensor, self.n_head, self.max_cache_pos), + persistent=False, + ) + if seq_length_with_past > self.max_cache_pos: + # When max_cache_pos is not enough for current sequence length, + # increase by MASK_BLOCK_SIZE and recalculate future_mask. + self.max_cache_pos = seq_length_with_past + MASK_BLOCK_SIZE + self.register_buffer( + "future_mask", + baichuan_13b_gen_alibi_mask(tensor, self.n_head, self.max_cache_pos), + persistent=False, + ) + mask = self.future_mask[ + : self.n_head, :seq_length_with_past, :seq_length_with_past + ] + return mask diff --git a/python/llm/src/ipex_llm/transformers/models/baichuan2.py b/python/llm/src/ipex_llm/transformers/models/baichuan2.py deleted file mode 100644 index 4b450b5f..00000000 --- a/python/llm/src/ipex_llm/transformers/models/baichuan2.py +++ /dev/null @@ -1,325 +0,0 @@ -# -# Copyright 2016 The BigDL Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. - -# This file is adapted from -# https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/cb7fc748b78b7ea99772e4cf76db155729ce774e/modeling_baichuan.py -# and -# https://huggingface.co/baichuan-inc/Baichuan2-13B-Chat/blob/c6f8592a60b4ad73c210b28dd2ab3cca51abbf93/modeling_baichuan.py - -import math -from typing import Optional, Tuple -import torch -import torch.utils.checkpoint -from torch.nn import functional as F -from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache -from ipex_llm.transformers.models.utils import update_past_key_value -from ipex_llm.transformers.models.utils import should_use_fuse_rope -from ipex_llm.transformers.models.utils import use_flash_attention, use_sdp, use_sdp_causal -from ipex_llm.transformers.models.utils import apply_rotary_pos_emb, SILU -from ipex_llm.transformers.models.utils import mlp_fusion_check -import warnings - - -def pre_compute_inv_freq(module: torch.nn.Module): - if module.__class__.__name__ == "RotaryEmbedding": - inv_freq = module.inv_freq - del module.inv_freq - module.register_buffer("inv_freq", inv_freq, persistent=False) - - -def baichuan_13b_rms_norm_forward(self, hidden_states): - if hidden_states.device.type == "xpu" and not (self.training or hidden_states.requires_grad): - import xe_addons - x_2d = hidden_states.reshape(-1, hidden_states.size(-1)).contiguous() - output = xe_addons.rms_norm(self.weight, x_2d, self.epsilon) - return output.reshape(hidden_states.shape) - - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.epsilon) - return self.weight * hidden_states.to(input_dtype) - - -def baichuan_mlp_forward( - self, - x: torch.Tensor, -) -> torch.Tensor: - x_2d = x.view(-1, x.shape[-1]) - qtype = getattr(self.gate_proj, "qtype", None) - if mlp_fusion_check(x_2d, qtype, self.training) and not self.down_proj.enable_xetla: - import xe_linear - if not x_2d.is_contiguous(): - x_2d = x_2d.contiguous() - return self.down_proj(xe_linear.mlp_forward_xpu( - x_2d, self.gate_proj.weight.data, self.up_proj.weight.data, - x_2d.shape[0], x_2d.shape[1], self.gate_proj.out_len, - SILU, qtype - )) - return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) - - -def baichuan_attention_forward_7b( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: bool = False, - use_cache: bool = False, -): - bsz, q_len, _ = hidden_states.size() - device = hidden_states.device - - qkv = self.W_pack(hidden_states) - qkv = qkv.view(bsz, q_len, self.num_heads * 3, self.head_dim) - qkv = qkv.transpose(1, 2) - query_states, key_states, value_states = qkv.split([self.num_heads, - self.num_heads, - self.num_heads], dim=1) - - kv_seq_len = key_states.shape[2] - if past_key_value is not None: - kv_seq_len += past_key_value[0].shape[2] - - # IPEX-LLM OPT: fuse rope - if should_use_fuse_rope(hidden_states, position_ids, self.training): - import xe_addons - xe_addons.rotary_half_inplaced(self.rotary_emb.inv_freq, position_ids, - query_states, key_states) - else: - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, - cos, sin, 
position_ids, "baichuan") - query_states = query_states.to(hidden_states.dtype) - key_states = key_states.to(hidden_states.dtype) - - # IPEX-LLM OPT: kv cache and quantize kv - use_quantize_kv = use_quantize_kv_cache(self.W_pack, hidden_states) - key_states, value_states = update_past_key_value( - past_key_value, key_states, value_states, - kv_seq_len, use_quantize_kv, device - ) - past_key_value = (key_states, value_states) if use_cache else None - - if self.training: - warnings.warn("xops is not supported on Intel GPU, so just use normal implementation") - - # IPEX-LLM OPT: sdp - attn_weights = None - if not self.training and not hidden_states.requires_grad and \ - use_flash_attention(query_states, key_states, attention_mask): - attn_output = F.scaled_dot_product_attention(query_states.to(dtype=torch.float16), - key_states.to(dtype=torch.float16), - value_states.to(dtype=torch.float16), - is_causal=True).to(hidden_states.dtype) - elif use_sdp(q_len, kv_seq_len, self.head_dim, query_states): - import xe_addons - if use_quantize_kv: - attn_output = xe_addons.sdp_fp8(query_states, key_states, value_states, - attention_mask) - else: - attn_output = xe_addons.sdp(query_states, key_states, value_states, - attention_mask) - elif use_sdp_causal(q_len, kv_seq_len, self.head_dim, query_states, self.training): - import xe_addons - if use_quantize_kv: - attn_output = xe_addons.sdp_fp8_causal(query_states, key_states, - value_states, attention_mask) - else: - attn_output = xe_addons.sdp_causal(query_states, key_states, - value_states, attention_mask) - else: - if use_quantize_kv: - key_states, value_states = restore_fp8_kv_cache(key_states, value_states, - query_states.dtype) - attn_weights = torch.matmul(query_states, - key_states.transpose(2, 3)) / math.sqrt(self.head_dim) - if attention_mask is not None: - attn_weights = attn_weights + attention_mask - # upcast attention to fp32 - attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, - dtype=torch.float32).to(value_states.dtype) - attn_output = torch.matmul(attn_weights, value_states) - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) - - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -def baichuan_attention_forward_13b( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: bool = False, - use_cache: bool = False, -) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - bsz, q_len, _ = hidden_states.size() - device = hidden_states.device - - qkv = self.W_pack(hidden_states) - qkv = qkv.view(bsz, q_len, self.num_heads * 3, self.head_dim) - qkv = qkv.transpose(1, 2) - query_states, key_states, value_states = qkv.split([self.num_heads, - self.num_heads, - self.num_heads], dim=1) - - kv_seq_len = key_states.shape[2] - if past_key_value is not None: - kv_seq_len += past_key_value[0].shape[2] - - # IPEX-LLM OPT: kv cache and quantize kv - use_quantize_kv = use_quantize_kv_cache(self.W_pack, hidden_states) - key_states, value_states = update_past_key_value( - past_key_value, key_states, value_states, - kv_seq_len, use_quantize_kv, device - ) - past_key_value = (key_states, value_states) if use_cache else None - - if self.training: - warnings.warn("xops is not supported on Intel GPU, so just use normal implementation") - - if 
attention_mask is not None: - if len(attention_mask.size()) == 4: - attention_mask = attention_mask[:, :, -q_len:, :] - else: - attention_mask = attention_mask[:, None, -q_len:, :] - - if use_sdp(q_len, kv_seq_len, self.head_dim, query_states): - import xe_addons - if use_quantize_kv: - attn_output = xe_addons.sdp_fp8(query_states, key_states, value_states, - attention_mask) - else: - attn_output = xe_addons.sdp(query_states, key_states, value_states, - attention_mask) - attn_weights = None - else: - if use_quantize_kv: - key_states, value_states = restore_fp8_kv_cache(key_states, value_states, - query_states.dtype) - attn_weights = torch.matmul(query_states, - key_states.transpose(2, 3)) / math.sqrt(self.head_dim) - if attention_mask is not None: - attn_weights = attn_weights + attention_mask - attn_weights = attn_weights.to(query_states.dtype) - attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1) - attn_output = torch.matmul(attn_weights.to(dtype=value_states.dtype), value_states) - attn_output = attn_output.transpose(1, 2) - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -def _get_interleave(n): - def _get_interleave_power_of_2(n): - start = 2 ** (-(2 ** -(math.log2(n) - 3))) - ratio = start - return [start * ratio**i for i in range(n)] - - if math.log2(n).is_integer(): - return _get_interleave_power_of_2(n) - else: - closest_power_of_2 = 2 ** math.floor(math.log2(n)) - return ( - _get_interleave_power_of_2(closest_power_of_2) - + _get_interleave(2 * closest_power_of_2)[0::2][: n - closest_power_of_2] - ) - - -def _fill_with_neg_inf(t): - """FP16-compatible function that fills a tensor with -inf.""" - return t.float().fill_(float("-inf")).type_as(t) - - -def _buffered_future_mask(tensor, maxpos, alibi, attn_heads): - _future_mask = torch.triu(_fill_with_neg_inf(torch.zeros([maxpos, maxpos])), 1) - _future_mask = _future_mask.unsqueeze(0) + alibi - new_future_mask = _future_mask.to(tensor) - return new_future_mask[: tensor.shape[0] * attn_heads, :maxpos, :maxpos] - - -def baichuan_13b_gen_alibi_mask(tensor, n_head, max_pos): - slopes = torch.Tensor(_get_interleave(n_head)).to(tensor.dtype) - position_point = torch.arange(max_pos) - max_pos + 1 - position_point = position_point.unsqueeze(0).unsqueeze(0).expand(n_head, -1, -1) - diag = torch.diag(position_point[0]) - position_point = position_point - diag.unsqueeze(0).unsqueeze(0).transpose(-1, -2) - alibi = slopes.unsqueeze(1).unsqueeze(1) * position_point - alibi = alibi.view(n_head, 1, max_pos) - alibi_mask = torch.triu( - _fill_with_neg_inf(torch.zeros([max_pos, max_pos])), 1).to(tensor.dtype) - alibi_mask = alibi_mask.unsqueeze(0) + alibi - if tensor.device.type == "xpu": - alibi_mask = alibi_mask.to(tensor.device) - return alibi_mask - - -MASK_BLOCK_SIZE = 512 - - -def baichuan_13b_get_alibi_mask(self, tensor, seq_length_with_past): - if self.training: - slopes = torch.Tensor(_get_interleave(self.n_head)) - position_point = ( - torch.arange(seq_length_with_past) - seq_length_with_past + 1 - ) - position_point = ( - position_point.unsqueeze(0) - .unsqueeze(0) - .expand(self.n_head, seq_length_with_past, -1) - ) - diag = torch.diag(position_point[0]) - position_point = position_point - diag.unsqueeze(0).unsqueeze(0).transpose( - -1, -2 - ) - alibi = slopes.unsqueeze(1).unsqueeze(1) * position_point - mask = _buffered_future_mask( - tensor, 
seq_length_with_past, alibi, self.n_head - ) - else: - if self.first_run: - # Override the default max_cache_pos=4096 for memory considerations - self.max_cache_pos = seq_length_with_past + MASK_BLOCK_SIZE - self.first_run = False - self.register_buffer( - "future_mask", - baichuan_13b_gen_alibi_mask(tensor, self.n_head, self.max_cache_pos), - persistent=False, - ) - if seq_length_with_past > self.max_cache_pos: - # When max_cache_pos is not enough for current sequence length, - # increase by MASK_BLOCK_SIZE and recalculate future_mask. - self.max_cache_pos = seq_length_with_past + MASK_BLOCK_SIZE - self.register_buffer( - "future_mask", - baichuan_13b_gen_alibi_mask(tensor, self.n_head, self.max_cache_pos), - persistent=False, - ) - mask = self.future_mask[ - : self.n_head, :seq_length_with_past, :seq_length_with_past - ] - return mask
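
Usage note: after this change, Baichuan and Baichuan2 (7B and 13B) share the single implementation in ipex_llm/transformers/models/baichuan.py, dispatched on hidden_size (4096/2048 for 7B, 5120 for 13B) and vocab_size (125696 for Baichuan2), and models/baichuan2.py is removed. The sketch below shows one way this converted path is typically exercised, assuming the standard IPEX-LLM from_pretrained API and an Intel GPU (xpu) runtime; the checkpoint id, prompt, and generation settings are illustrative only, not part of this patch.

import torch
from transformers import AutoTokenizer
from ipex_llm.transformers import AutoModelForCausalLM

# Illustrative checkpoint: any Baichuan / Baichuan2 model with hidden_size
# 4096/2048 (7B) or 5120 (13B) is routed through the forwards added above.
model_path = "baichuan-inc/Baichuan2-7B-Chat"

tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

# Loading with load_in_4bit=True applies the low-bit conversion in convert.py
# (_optimize_pre / _optimize_post), which installs the Baichuan attention,
# MLP, and RMSNorm forwards from ipex_llm/transformers/models/baichuan.py.
model = AutoModelForCausalLM.from_pretrained(model_path,
                                             load_in_4bit=True,
                                             trust_remote_code=True)
model = model.to("xpu")

with torch.inference_mode():
    input_ids = tokenizer("What is AI?", return_tensors="pt").input_ids.to("xpu")
    output = model.generate(input_ids, max_new_tokens=32)
    print(tokenizer.decode(output[0], skip_special_tokens=True))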