Refactor baichuan1 7B and 13B (#11258)

This commit is contained in:
Yishuo Wang 2024-06-07 14:29:20 +08:00 committed by GitHub
parent 1aa9c9597a
commit ea0d03fd28
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 253 additions and 822 deletions
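Note: the converters touched by this commit run automatically when a Baichuan checkpoint is loaded through ipex-llm. A minimal, illustrative usage sketch (the model id, dtype and device choices below are assumptions, not part of this commit):

# Illustrative only: load Baichuan through ipex-llm so that the optimization
# hooks modified in this commit (_optimize_pre/_optimize_post) are applied.
import torch
from transformers import AutoTokenizer
from ipex_llm.transformers import AutoModelForCausalLM

model_path = "baichuan-inc/Baichuan-13B-Chat"   # assumption: any Baichuan 1/2 checkpoint
model = AutoModelForCausalLM.from_pretrained(model_path,
                                             load_in_4bit=True,      # low-bit weights
                                             optimize_model=True,    # run the converters (default)
                                             trust_remote_code=True)
model = model.to("xpu")                          # Intel GPU; "cpu" also works
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

with torch.inference_mode():
    inputs = tokenizer("What is AI?", return_tensors="pt").to(model.device)
    output = model.generate(inputs.input_ids, max_new_tokens=32)
    print(tokenizer.decode(output[0], skip_special_tokens=True))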

ipex_llm/transformers/convert.py

@@ -680,6 +680,11 @@ def _optimize_pre(model):
             if model.lm_head.weight.data.device != "meta":
                 norm_weight = nn.functional.normalize(lm_head_weight_data)
                 model.lm_head.weight.data = norm_weight
+
+        # for baichuan2-7B
+        if model.config.hidden_size in [4096, 2048]:
+            from ipex_llm.transformers.models.baichuan import pre_compute_inv_freq
+            model.apply(pre_compute_inv_freq)
     # for yuan 2.0
     if model.config.model_type == "yuan":
         from ipex_llm.transformers.models.yuan import merge_qk
@@ -703,12 +708,6 @@ def _optimize_pre(model):
         model.apply(pre_compute_inv_freq)
         from ipex_llm.transformers.models.phi3 import split_mlp
         model.apply(split_mlp)
-    # for baichuan2
-    if model.config.model_type == "baichuan" and model.config.vocab_size == 125696:
-        if model.config.hidden_size in [4096, 2048]:
-            # baichuan2-7B
-            from ipex_llm.transformers.models.baichuan2 import pre_compute_inv_freq
-            model.apply(pre_compute_inv_freq)
     # for qwen2
     if model.config.model_type == "qwen2":
         from ipex_llm.transformers.models.qwen2 import merge_qkv
@@ -1125,84 +1124,39 @@ def _optimize_post(model, lightweight_bmm=False):
                         module.FalconAttention,
                         falcon_attention_forward
                         )
-    elif model.config.model_type == "baichuan" and model.config.vocab_size == 125696:
-        # baichuan2
-        if model.config.hidden_size in [4096, 2048]:
-            # baichuan2-7B
-            modeling_module_name = model.__class__.__module__
-            module = importlib.import_module(modeling_module_name)
-            from ipex_llm.transformers.models.baichuan2 import baichuan_attention_forward_7b
-            from ipex_llm.transformers.models.baichuan2 import baichuan_mlp_forward
-            convert_forward(model,
-                            module.Attention,
-                            baichuan_attention_forward_7b
-                            )
-            convert_forward(model,
-                            module.RMSNorm,
-                            llama_rms_norm_forward)
-            convert_forward(model,
-                            module.MLP,
-                            baichuan_mlp_forward)
-        elif model.config.hidden_size == 5120:
-            # baichuan2-13B
-            modeling_module_name = model.__class__.__module__
-            module = importlib.import_module(modeling_module_name)
-            from ipex_llm.transformers.models.baichuan2 import baichuan_attention_forward_13b
-            from ipex_llm.transformers.models.baichuan2 import baichuan_13b_rms_norm_forward
-            from ipex_llm.transformers.models.baichuan2 import baichuan_mlp_forward
-            from ipex_llm.transformers.models.baichuan2 import baichuan_13b_get_alibi_mask
-            convert_forward(model,
-                            module.BaichuanAttention,
-                            baichuan_attention_forward_13b
-                            )
-            # baichuan2-13B's RMSNorm is a little different
-            convert_forward(model,
-                            module.RMSNorm,
-                            baichuan_13b_rms_norm_forward)
-            convert_forward(model,
-                            module.MLP,
-                            baichuan_mlp_forward)
-            if hasattr(model.model, 'get_alibi_mask_orig'):
-                # deepspeed rewrite "get_alibi_mask" to support baichuan
-                # https://github.com/microsoft/DeepSpeed/pull/4721
-                replace_func(model,
-                             module.BaichuanModel,
-                             "get_alibi_mask_orig",
-                             baichuan_13b_get_alibi_mask)
-            else:
-                replace_func(model,
-                             module.BaichuanModel,
-                             "get_alibi_mask",
-                             baichuan_13b_get_alibi_mask)
     elif model.config.model_type == "baichuan":
-        # baichuan1
-        if model.config.hidden_size == 4096:
-            # baichuan-7B
-            modeling_module_name = model.__class__.__module__
-            module = importlib.import_module(modeling_module_name)
-            from ipex_llm.transformers.models.baichuan import baichuan_attention_forward_7b
-            convert_forward(model,
-                            module.Attention,
-                            baichuan_attention_forward_7b
-                            )
-            convert_forward(model,
-                            module.RMSNorm,
-                            llama_rms_norm_forward)
-        elif model.config.hidden_size == 5120:
-            # baichuan-13B
-            modeling_module_name = model.__class__.__module__
-            module = importlib.import_module(modeling_module_name)
-            from ipex_llm.transformers.models.baichuan import baichuan_attention_forward_13b
-            from ipex_llm.transformers.models.baichuan2 import baichuan_13b_rms_norm_forward
-            convert_forward(model,
-                            module.BaichuanAttention,
-                            baichuan_attention_forward_13b
-                            )
-            # baichuan-13B's RMSNorm is a little different
-            convert_forward(model,
-                            module.RMSNorm,
-                            baichuan_13b_rms_norm_forward)
+        modeling_module_name = model.__class__.__module__
+        module = importlib.import_module(modeling_module_name)
+        from ipex_llm.transformers.models.baichuan import baichuan_mlp_forward
+        convert_forward(model, module.MLP, baichuan_mlp_forward)
+
+        if model.config.hidden_size in [4096, 2048]:
+            # baichuan-7B and baichuan2-7B
+            from ipex_llm.transformers.models.baichuan import baichuan_attention_forward_7b
+            convert_forward(model, module.Attention, baichuan_attention_forward_7b)
+            convert_forward(model, module.RMSNorm, llama_rms_norm_forward)
+        elif model.config.hidden_size == 5120:
+            # baichuan-13B and baichuan2-13B
+            from ipex_llm.transformers.models.baichuan import baichuan_attention_forward_13b
+            from ipex_llm.transformers.models.baichuan import baichuan_13b_rms_norm_forward
+            convert_forward(model, module.BaichuanAttention, baichuan_attention_forward_13b)
+            convert_forward(model, module.RMSNorm, baichuan_13b_rms_norm_forward)
+
+            if model.config.vocab_size == 125696:
+                # baichaun2-13B
+                from ipex_llm.transformers.models.baichuan import baichuan_13b_get_alibi_mask
+                if hasattr(model.model, 'get_alibi_mask_orig'):
+                    # deepspeed rewrite "get_alibi_mask" to support baichuan
+                    # https://github.com/microsoft/DeepSpeed/pull/4721
+                    replace_func(model,
+                                 module.BaichuanModel,
+                                 "get_alibi_mask_orig",
+                                 baichuan_13b_get_alibi_mask)
+                else:
+                    replace_func(model,
+                                 module.BaichuanModel,
+                                 "get_alibi_mask",
+                                 baichuan_13b_get_alibi_mask)
     elif model.config.model_type == "gpt_neox":
         from ipex_llm.transformers.models.gptneox import gptneox_attention_forward
         convert_forward(model,

ipex_llm/transformers/models/baichuan.py

@@ -14,30 +14,61 @@
 # limitations under the License.
 # This file is adapted from
-# https://huggingface.co/baichuan-inc/Baichuan-7B/blob/c1a5c7d5b7f50ecc51bb0e08150a9f12e5656756/modeling_baichuan.py
+# https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/cb7fc748b78b7ea99772e4cf76db155729ce774e/modeling_baichuan.py
 # and
-# https://huggingface.co/baichuan-inc/Baichuan-13B-Chat/blob/a4a558127068f2ce965aa56aeb826bf501a68970/modeling_baichuan.py
+# https://huggingface.co/baichuan-inc/Baichuan2-13B-Chat/blob/c6f8592a60b4ad73c210b28dd2ab3cca51abbf93/modeling_baichuan.py
 
 import math
-from typing import List, Optional, Tuple, Union
+from typing import Optional, Tuple
+
 import torch
 import torch.utils.checkpoint
-from torch import nn
-import torch.nn.functional as F
-from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
-from ipex_llm.utils.common import invalidInputError
-from ipex_llm.transformers.models.utils import use_flash_attention, use_sdp
-from ipex_llm.transformers.models.utils import init_kv_cache, extend_kv_cache, \
-    append_kv_cache, is_enough_kv_cache_room_4_31
-from ipex_llm.transformers.models.utils import init_fp8_kv_cache, append_fp8_kv_cache, \
-    restore_fp8_kv_cache, use_quantize_kv_cache
-from ipex_llm.transformers.models.utils import rotate_half, apply_rotary_pos_emb
-from ipex_llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
-import os
-
-KV_CACHE_ALLOC_BLOCK_LENGTH = int(os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256))
+from torch.nn import functional as F
+from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache
+from ipex_llm.transformers.models.utils import update_past_key_value
+from ipex_llm.transformers.models.utils import should_use_fuse_rope
+from ipex_llm.transformers.models.utils import use_flash_attention, use_sdp, use_sdp_causal
+from ipex_llm.transformers.models.utils import apply_rotary_pos_emb, SILU
+from ipex_llm.transformers.models.utils import mlp_fusion_check
+import warnings
+
+
+def pre_compute_inv_freq(module: torch.nn.Module):
+    if module.__class__.__name__ == "RotaryEmbedding":
+        inv_freq = module.inv_freq
+        del module.inv_freq
+        module.register_buffer("inv_freq", inv_freq, persistent=False)
+
+
+def baichuan_13b_rms_norm_forward(self, hidden_states):
+    if hidden_states.device.type == "xpu" and not (self.training or hidden_states.requires_grad):
+        import xe_addons
+        x_2d = hidden_states.reshape(-1, hidden_states.size(-1)).contiguous()
+        output = xe_addons.rms_norm(self.weight, x_2d, self.epsilon)
+        return output.reshape(hidden_states.shape)
+
+    input_dtype = hidden_states.dtype
+    hidden_states = hidden_states.to(torch.float32)
+    variance = hidden_states.pow(2).mean(-1, keepdim=True)
+    hidden_states = hidden_states * torch.rsqrt(variance + self.epsilon)
+    return self.weight * hidden_states.to(input_dtype)
+
+
+def baichuan_mlp_forward(
+    self,
+    x: torch.Tensor,
+) -> torch.Tensor:
+    x_2d = x.view(-1, x.shape[-1])
+    qtype = getattr(self.gate_proj, "qtype", None)
+    if mlp_fusion_check(x_2d, qtype, self.training) and not self.down_proj.enable_xetla:
+        import xe_linear
+        if not x_2d.is_contiguous():
+            x_2d = x_2d.contiguous()
+        return self.down_proj(xe_linear.mlp_forward_xpu(
+            x_2d, self.gate_proj.weight.data, self.up_proj.weight.data,
+            x_2d.shape[0], x_2d.shape[1], self.gate_proj.out_len,
+            SILU, qtype
+        ))
+    return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
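Note: baichuan_mlp_forward above either calls the fused xe_linear.mlp_forward_xpu kernel or falls back to the eager path on its last line. For reference, the computation both paths are expected to produce is the usual LLaMA-style SwiGLU MLP (a hedged restatement for illustration, assuming act_fn is SiLU; names mirror the Baichuan modules):

# Reference SwiGLU MLP, equivalent to the fallback branch of baichuan_mlp_forward:
# down_proj(silu(gate_proj(x)) * up_proj(x)).
import torch
import torch.nn.functional as F

def swiglu_mlp_reference(x, gate_proj, up_proj, down_proj):
    return down_proj(F.silu(gate_proj(x)) * up_proj(x))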
def baichuan_attention_forward_7b( def baichuan_attention_forward_7b(
@@ -48,269 +79,82 @@ def baichuan_attention_forward_7b(
     past_key_value: Optional[Tuple[torch.Tensor]] = None,
     output_attentions: bool = False,
     use_cache: bool = False,
-) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-    if use_quantize_kv_cache(self.W_pack, hidden_states):
-        forward_function = baichuan_attention_forward_7b_quantized
-    else:
-        forward_function = baichuan_attention_forward_7b_origin
-    return forward_function(
-        self=self,
-        hidden_states=hidden_states,
-        attention_mask=attention_mask,
-        position_ids=position_ids,
-        past_key_value=past_key_value,
-        output_attentions=output_attentions,
-        use_cache=use_cache
-    )
-
-
-def baichuan_attention_forward_7b_quantized(
-    self,
-    hidden_states: torch.Tensor,
-    attention_mask: Optional[torch.Tensor] = None,
-    position_ids: Optional[torch.LongTensor] = None,
-    past_key_value: Optional[Tuple[torch.Tensor]] = None,
-    output_attentions: bool = False,
-    use_cache: bool = False,
-) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-    bsz, q_len, _ = hidden_states.size()
-    device = hidden_states.device
-
-    proj = self.W_pack(hidden_states)
-    proj = proj.unflatten(-1, (3, self.hidden_size)).unsqueeze(0).transpose(0, -2).squeeze(-2)
-    # batch_size x source_len x hidden_size
-    query_states = proj[0].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-    # batch_size x target_len x head_size
-    key_states = proj[1].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-    # batch_size x source_len x hidden_size
-    value_states = proj[2].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-
-    kv_seq_len = key_states.shape[-2]
-    if past_key_value is not None:
-        kv_seq_len += past_key_value[0].shape[-2]
-
-    if query_states.device.type == "xpu" and not (self.training and query_states.requires_grad):
-        query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
-                                                                     key_states,
-                                                                     position_ids,
-                                                                     "baichuan")
-    else:
-        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
-        query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
-                                                        cos, sin, position_ids, "baichuan")
-    # [bsz, nh, t, hd]
-
-    if past_key_value is None:
-        attn_weights = torch.matmul(query_states,
-                                    key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
-
-        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
-            invalidInputError(
-                False,
-                f"Attention weights should be of size "
-                f"{(bsz, self.num_heads, q_len, kv_seq_len)}"
-                f", but is {attn_weights.size()}"
-            )
-
-        if attention_mask is not None:
-            invalidInputError(
-                attention_mask.size() == (bsz, 1, q_len, kv_seq_len),
-                f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, "
-                f"but is {attention_mask.size()}"
-            )
-            attn_weights = attn_weights + attention_mask
-            attn_weights = torch.max(attn_weights,
-                                     torch.tensor(torch.finfo(attn_weights.dtype).min))
-
-        # upcast attention to fp32
-        attn_weights = nn.functional.softmax(attn_weights, dim=-1,
-                                             dtype=torch.float32).to(query_states.dtype)
-        attn_output = torch.matmul(attn_weights, value_states)
-
-        kv_seq_len = key_states.shape[-2]
-        if use_cache:
-            k_cache, v_cache = init_fp8_kv_cache(
-                bsz, self.num_heads, kv_seq_len, self.head_dim,
-                device=device
-            )
-            key_states, value_states = append_fp8_kv_cache(k_cache, v_cache, key_states,
-                                                           value_states)
-            past_key_value = (key_states, value_states)
-    else:
-        k_cache, v_cache = past_key_value
-        key_states, value_states = append_fp8_kv_cache(k_cache, v_cache,
-                                                       key_states, value_states)
-        kv_seq_len = key_states.shape[-2]
-        past_key_value = (key_states, value_states)
-
-        if query_states.size(2) != 1 or query_states.device.type != 'xpu':
-            key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
-                                                            query_states.dtype)
-            attn_weights = torch.matmul(query_states, key_states.transpose(2, 3))
-            attn_weights = attn_weights / math.sqrt(self.head_dim)
-
-            if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
-                invalidInputError(
-                    False,
-                    f"Attention weights should be of size "
-                    f"{(bsz, self.num_heads, q_len, kv_seq_len)}"
-                    f", but is {attn_weights.size()}"
-                )
-
-            if attention_mask is not None:
-                invalidInputError(
-                    attention_mask.size() == (bsz, 1, q_len, kv_seq_len),
-                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, "
-                    f"but is {attention_mask.size()}"
-                )
-                attn_weights = attn_weights + attention_mask
-                attn_weights = torch.max(attn_weights,
-                                         torch.tensor(torch.finfo(attn_weights.dtype).min))
-
-            # upcast attention to fp32
-            attn_weights = nn.functional.softmax(attn_weights, dim=-1,
-                                                 dtype=torch.float32).to(query_states.dtype)
-            attn_output = torch.matmul(attn_weights, value_states)
-        else:
-            import xe_addons
-            attn_output = xe_addons.sdp_fp8(query_states, key_states, value_states,
-                                            attention_mask)
-            attn_weights = None
-
-    invalidInputError(
-        attn_output.size() == (bsz, self.num_heads, q_len, self.head_dim),
-        f"`attn_output` should be of size "
-        f"{(bsz, self.num_heads, q_len, self.head_dim)},"
-        f"but is {attn_output.size()}"
-    )
-
-    attn_output = attn_output.transpose(1, 2)
-    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
-    attn_output = self.o_proj(attn_output)
-
-    if not output_attentions:
-        attn_weights = None
-
-    return attn_output.to(hidden_states.dtype), attn_weights, past_key_value
-
-
-def baichuan_attention_forward_7b_origin(
-    self,
-    hidden_states: torch.Tensor,
-    attention_mask: Optional[torch.Tensor] = None,
-    position_ids: Optional[torch.LongTensor] = None,
-    past_key_value: Optional[Tuple[torch.Tensor]] = None,
-    output_attentions: bool = False,
-    use_cache: bool = False,
-) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-    bsz, q_len, _ = hidden_states.size()
-    device = hidden_states.device
-
-    proj = self.W_pack(hidden_states)
-    proj = proj.unflatten(-1, (3, self.hidden_size)).unsqueeze(0).transpose(0, -2).squeeze(-2)
-    # batch_size x source_len x hidden_size
-    query_states = proj[0].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-    # batch_size x target_len x head_size
-    key_states = proj[1].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-    # batch_size x source_len x hidden_size
-    value_states = proj[2].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-
-    kv_seq_len = key_states.shape[-2]
-    enough_kv_room = True
-    if past_key_value is not None:
-        enough_kv_room = is_enough_kv_cache_room_4_31(past_key_value, seq_len=kv_seq_len)
-        kv_seq_len += past_key_value[0].shape[-2]
-
-    if query_states.device.type == "xpu" and not (self.training and query_states.requires_grad):
-        query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
-                                                                     key_states,
-                                                                     position_ids,
-                                                                     "baichuan")
-    else:
-        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
-        query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
-                                                        cos, sin, position_ids, "baichuan")
-    # [bsz, nh, t, hd]
-    # if past_key_value is not None:
-    #     # reuse k, v, self_attention
-    #     key_states = torch.cat([past_key_value[0], key_states], dim=2)
-    #     value_states = torch.cat([past_key_value[1], value_states], dim=2)
-    if past_key_value is not None:
-        # reuse k, v, self_attention
-        cache_k = past_key_value[0]
-        cache_v = past_key_value[1]
-        if not enough_kv_room:
-            # allocate new
-            new_cache_k, new_cache_v = extend_kv_cache(bsz,
-                                                       self.num_heads,
-                                                       self.head_dim,
-                                                       cache_k.size(2),
-                                                       kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH,
-                                                       dtype=cache_k.dtype,
-                                                       device=device)
-            new_cache_k[:] = cache_k
-            new_cache_v[:] = cache_v
-            cache_k = new_cache_k
-            cache_v = new_cache_v
-        key_states, value_states = append_kv_cache(cache_k, cache_v, key_states, value_states)
-    elif use_cache:
-        max_cache_length = kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH
-        new_key_states, new_value_states = init_kv_cache(bsz,
-                                                         self.num_heads,
-                                                         self.head_dim,
-                                                         kv_seq_len,
-                                                         max_cache_length,
-                                                         dtype=key_states.dtype,
-                                                         device=device)
-        new_key_states[:] = key_states
-        new_value_states[:] = value_states
-        key_states = new_key_states
-        value_states = new_value_states
-
-    past_key_value = (key_states, value_states) if use_cache else None
-
-    if not self.training and not hidden_states.requires_grad and \
-            use_flash_attention(query_states, key_states, attention_mask):
-        attn_output = F.scaled_dot_product_attention(query_states.to(device, dtype=torch.float16),
-                                                     key_states.to(device, dtype=torch.float16),
-                                                     value_states.to(device, dtype=torch.float16),
-                                                     is_causal=True)
-        attn_weights = None
-    elif not self.training and not hidden_states.requires_grad and \
-            use_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
-        import xe_addons
-        attn_output = xe_addons.sdp(query_states, key_states, value_states,
-                                    attention_mask)
-        attn_output = attn_output.view(query_states.shape)
-        attn_weights = None
-    else:
-        attn_weights = torch.matmul(query_states,
-                                    key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
-
-        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
-            invalidInputError(False,
-                              f"Attention weights should be of size "
-                              f"{(bsz, self.num_heads, q_len, kv_seq_len)}"
-                              f", but is {attn_weights.size()}")
-
-        if attention_mask is not None:
-            invalidInputError(attention_mask.size() == (bsz, 1, q_len, kv_seq_len),
-                              f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, "
-                              f"but is {attention_mask.size()}")
-            attn_weights = attn_weights + attention_mask
-            attn_weights = torch.max(attn_weights,
-                                     torch.tensor(torch.finfo(attn_weights.dtype).min))
-
-        # upcast attention to fp32
-        attn_weights = nn.functional.softmax(attn_weights, dim=-1,
-                                             dtype=torch.float32).to(query_states.dtype)
-        attn_output = torch.matmul(attn_weights, value_states)
-
-        invalidInputError(attn_output.size() == (bsz, self.num_heads, q_len, self.head_dim),
-                          f"`attn_output` should be of size "
-                          f"{(bsz, self.num_heads, q_len, self.head_dim)},"
-                          f"but is {attn_output.size()}")
-
-    attn_output = attn_output.transpose(1, 2)
+):
+    bsz, q_len, _ = hidden_states.size()
+    device = hidden_states.device
+
+    qkv = self.W_pack(hidden_states)
+    qkv = qkv.view(bsz, q_len, self.num_heads * 3, self.head_dim)
+    qkv = qkv.transpose(1, 2)
+    query_states, key_states, value_states = qkv.split([self.num_heads,
+                                                        self.num_heads,
+                                                        self.num_heads], dim=1)
+
+    kv_seq_len = key_states.shape[2]
+    if past_key_value is not None:
+        kv_seq_len += past_key_value[0].shape[2]
+
+    # IPEX-LLM OPT: fuse rope
+    if should_use_fuse_rope(hidden_states, position_ids, self.training):
+        import xe_addons
+        xe_addons.rotary_half_inplaced(self.rotary_emb.inv_freq, position_ids,
+                                       query_states, key_states)
+    else:
+        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
+                                                        cos, sin, position_ids, "baichuan")
+        query_states = query_states.to(hidden_states.dtype)
+        key_states = key_states.to(hidden_states.dtype)
+
+    # IPEX-LLM OPT: kv cache and quantize kv
+    use_quantize_kv = use_quantize_kv_cache(self.W_pack, hidden_states)
+    key_states, value_states = update_past_key_value(
+        past_key_value, key_states, value_states,
+        kv_seq_len, use_quantize_kv, device
+    )
+    past_key_value = (key_states, value_states) if use_cache else None
+
+    if self.training:
+        warnings.warn("xops is not supported on Intel GPU, so just use normal implementation")
+
+    # IPEX-LLM OPT: sdp
+    attn_weights = None
+    if not self.training and not hidden_states.requires_grad and \
+            use_flash_attention(query_states, key_states, attention_mask):
+        attn_output = F.scaled_dot_product_attention(query_states.to(dtype=torch.float16),
+                                                     key_states.to(dtype=torch.float16),
+                                                     value_states.to(dtype=torch.float16),
+                                                     is_causal=True).to(hidden_states.dtype)
+    elif use_sdp(q_len, kv_seq_len, self.head_dim, query_states):
+        import xe_addons
+        if use_quantize_kv:
+            attn_output = xe_addons.sdp_fp8(query_states, key_states, value_states,
+                                            attention_mask)
+        else:
+            attn_output = xe_addons.sdp(query_states, key_states, value_states,
+                                        attention_mask)
+    elif use_sdp_causal(q_len, kv_seq_len, self.head_dim, query_states, self.training):
+        import xe_addons
+        if use_quantize_kv:
+            attn_output = xe_addons.sdp_fp8_causal(query_states, key_states,
+                                                   value_states, attention_mask)
+        else:
+            attn_output = xe_addons.sdp_causal(query_states, key_states,
+                                               value_states, attention_mask)
+    else:
+        if use_quantize_kv:
+            key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
+                                                            query_states.dtype)
+        attn_weights = torch.matmul(query_states,
+                                    key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+        if attention_mask is not None:
+            attn_weights = attn_weights + attention_mask
+        # upcast attention to fp32
+        attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1,
+                                                   dtype=torch.float32).to(value_states.dtype)
+        attn_output = torch.matmul(attn_weights, value_states)
+
+    attn_output = attn_output.transpose(1, 2).contiguous()
     attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
     attn_output = self.o_proj(attn_output)

@@ -318,7 +162,7 @@ def baichuan_attention_forward_7b_origin(
     if not output_attentions:
         attn_weights = None
 
-    return attn_output.to(hidden_states.dtype), attn_weights, past_key_value
+    return attn_output, attn_weights, past_key_value
 
 
 def baichuan_attention_forward_13b(
@@ -329,101 +173,57 @@ def baichuan_attention_forward_13b(
     output_attentions: bool = False,
     use_cache: bool = False,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-    if use_quantize_kv_cache(self.W_pack, hidden_states):
-        forward_function = baichuan_attention_forward_13b_quantized
-    else:
-        forward_function = baichuan_attention_forward_13b_origin
-    return forward_function(
-        self=self,
-        hidden_states=hidden_states,
-        attention_mask=attention_mask,
-        past_key_value=past_key_value,
-        output_attentions=output_attentions,
-        use_cache=use_cache
-    )
-
-
-def baichuan_attention_forward_13b_quantized(
-    self,
-    hidden_states: torch.Tensor,
-    attention_mask: Optional[torch.Tensor] = None,
-    past_key_value: Optional[Tuple[torch.Tensor]] = None,
-    output_attentions: bool = False,
-    use_cache: bool = False,
-) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-    bsz, q_len, _ = hidden_states.size()
-    device = hidden_states.device
-
-    proj = self.W_pack(hidden_states)
-    proj = proj.unflatten(-1, (3, self.hidden_size)).unsqueeze(0).transpose(0, -2).squeeze(-2)
-    query_states = proj[0].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-    key_states = proj[1].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-    value_states = proj[2].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-
-    kv_seq_len = key_states.shape[-2]
-    if past_key_value is not None:
-        kv_seq_len += past_key_value[0].shape[-2]
-
-    if past_key_value is None:
-        attn_weights = torch.matmul(query_states,
-                                    key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
-
-        if attention_mask is not None:
-            if q_len == 1:  # inference with cache
-                if len(attention_mask.size()) == 4:
-                    attention_mask = attention_mask[:, :, -1:, :]
-                else:
-                    attention_mask = attention_mask[:, -1:, :]
-            attn_weights = attn_weights + attention_mask
-            attn_weights = torch.max(attn_weights,
-                                     torch.tensor(torch.finfo(attn_weights.dtype).min))
-        attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
-
-        attn_output = torch.matmul(attn_weights, value_states)
-
-        kv_seq_len = key_states.shape[-2]
-        if use_cache:
-            k_cache, v_cache = init_fp8_kv_cache(
-                bsz, self.num_heads, kv_seq_len, self.head_dim,
-                device=device
-            )
-            key_states, value_states = append_fp8_kv_cache(k_cache, v_cache,
-                                                           key_states, value_states)
-            past_key_value = (key_states, value_states)
-    else:
-        k_cache, v_cache = past_key_value
-        key_states, value_states = append_fp8_kv_cache(k_cache, v_cache,
-                                                       key_states, value_states)
-        kv_seq_len = key_states.shape[-2]
-        past_key_value = (key_states, value_states)
-
-        if query_states.size(2) != 1 or query_states.device.type != 'xpu':
-            key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
-                                                            query_states.dtype)
-            attn_weights = torch.matmul(query_states, key_states.transpose(2, 3))
-        else:
-            import xe_addons
-            attn_weights = xe_addons.query_key_fp8_matmul(query_states, key_states)
-        attn_weights = attn_weights / math.sqrt(self.head_dim)
-        if attention_mask is not None:
-            if q_len == 1:  # inference with cache
-                if len(attention_mask.size()) == 4:
-                    attention_mask = attention_mask[:, :, -1:, :]
-                else:
-                    attention_mask = attention_mask[:, -1:, :]
-            attn_weights = attn_weights + attention_mask
-            attn_weights = torch.max(attn_weights,
-                                     torch.tensor(torch.finfo(attn_weights.dtype).min))
-        attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
-        if query_states.size(2) != 1 or query_states.device.type != 'xpu':
-            attn_output = torch.matmul(attn_weights, value_states)
-        else:
-            import xe_addons
-            attn_output = xe_addons.attn_value_fp8_matmul(attn_weights,
-                                                          value_states)
+    bsz, q_len, _ = hidden_states.size()
+    device = hidden_states.device
+
+    qkv = self.W_pack(hidden_states)
+    qkv = qkv.view(bsz, q_len, self.num_heads * 3, self.head_dim)
+    qkv = qkv.transpose(1, 2)
+    query_states, key_states, value_states = qkv.split([self.num_heads,
+                                                        self.num_heads,
+                                                        self.num_heads], dim=1)
+
+    kv_seq_len = key_states.shape[2]
+    if past_key_value is not None:
+        kv_seq_len += past_key_value[0].shape[2]
+
+    # IPEX-LLM OPT: kv cache and quantize kv
+    use_quantize_kv = use_quantize_kv_cache(self.W_pack, hidden_states)
+    key_states, value_states = update_past_key_value(
+        past_key_value, key_states, value_states,
+        kv_seq_len, use_quantize_kv, device
+    )
+    past_key_value = (key_states, value_states) if use_cache else None
+
+    if self.training:
+        warnings.warn("xops is not supported on Intel GPU, so just use normal implementation")
+
+    if attention_mask is not None:
+        if len(attention_mask.size()) == 4:
+            attention_mask = attention_mask[:, :, -q_len:, :]
+        else:
+            attention_mask = attention_mask[None, :, -q_len:, :]
+
+    if use_sdp(q_len, kv_seq_len, self.head_dim, query_states):
+        import xe_addons
+        if use_quantize_kv:
+            attn_output = xe_addons.sdp_fp8(query_states, key_states, value_states,
+                                            attention_mask)
+        else:
+            attn_output = xe_addons.sdp(query_states, key_states, value_states,
+                                        attention_mask)
+        attn_weights = None
+    else:
+        if use_quantize_kv:
+            key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
+                                                            query_states.dtype)
+        attn_weights = torch.matmul(query_states,
+                                    key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+        if attention_mask is not None:
+            attn_weights = attn_weights + attention_mask
+            attn_weights = attn_weights.to(query_states.dtype)
+        attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
+        attn_output = torch.matmul(attn_weights.to(dtype=value_states.dtype), value_states)
 
     attn_output = attn_output.transpose(1, 2)
     attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
     attn_output = self.o_proj(attn_output)
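Note: both rewritten attention forwards replace the old unflatten/transpose/squeeze unpacking of the packed W_pack projection with a simpler view + split. A small, self-contained sketch showing the two unpacking styles produce the same query layout (shapes below are made up for the example):

# Illustrative check that view+split unpacking of the packed QKV projection
# yields the same [bsz, num_heads, q_len, head_dim] tensor as the old
# unflatten-based code. Sizes are illustrative only.
import torch

bsz, q_len, num_heads, head_dim = 2, 5, 4, 8
hidden_size = num_heads * head_dim
proj = torch.randn(bsz, q_len, 3 * hidden_size)   # stand-in for W_pack output

# old style: unflatten into (3, hidden_size), then index q/k/v
old = proj.unflatten(-1, (3, hidden_size)).unsqueeze(0).transpose(0, -2).squeeze(-2)
q_old = old[0].view(bsz, q_len, num_heads, head_dim).transpose(1, 2)

# new style: view as 3 * num_heads heads and split along the head axis
qkv = proj.view(bsz, q_len, num_heads * 3, head_dim).transpose(1, 2)
q_new, k_new, v_new = qkv.split([num_heads, num_heads, num_heads], dim=1)

assert torch.equal(q_old, q_new)
print(q_old.shape, q_new.shape)   # both torch.Size([2, 4, 5, 8])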
@@ -434,90 +234,92 @@ def baichuan_attention_forward_13b_quantized(
     return attn_output, attn_weights, past_key_value
 
 
-def baichuan_attention_forward_13b_origin(
-    self,
-    hidden_states: torch.Tensor,
-    attention_mask: Optional[torch.Tensor] = None,
-    past_key_value: Optional[Tuple[torch.Tensor]] = None,
-    output_attentions: bool = False,
-    use_cache: bool = False,
-) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-    bsz, q_len, _ = hidden_states.size()
-    device = hidden_states.device
-
-    proj = self.W_pack(hidden_states)
-    proj = proj.unflatten(-1, (3, self.hidden_size)).unsqueeze(0).transpose(0, -2).squeeze(-2)
-    query_states = proj[0].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-    key_states = proj[1].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-    value_states = proj[2].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-
-    kv_seq_len = key_states.shape[-2]
-    enough_kv_room = True
-    if past_key_value is not None:
-        enough_kv_room = is_enough_kv_cache_room_4_31(past_key_value, seq_len=kv_seq_len)
-        kv_seq_len += past_key_value[0].shape[-2]
-
-    # if past_key_value is not None:
-    #     # reuse k, v, self_attention
-    #     key_states = torch.cat([past_key_value[0], key_states], dim=2)
-    #     value_states = torch.cat([past_key_value[1], value_states], dim=2)
-    if past_key_value is not None:
-        # reuse k, v, self_attention
-        cache_k = past_key_value[0]
-        cache_v = past_key_value[1]
-        if not enough_kv_room:
-            # allocate new
-            new_cache_k, new_cache_v = extend_kv_cache(bsz,
-                                                       self.num_heads,
-                                                       self.head_dim,
-                                                       cache_k.size(2),
-                                                       kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH,
-                                                       dtype=cache_k.dtype,
-                                                       device=device)
-            new_cache_k[:] = cache_k
-            new_cache_v[:] = cache_v
-            cache_k = new_cache_k
-            cache_v = new_cache_v
-        key_states, value_states = append_kv_cache(cache_k, cache_v, key_states, value_states)
-    elif use_cache:
-        max_cache_length = kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH
-        new_key_states, new_value_states = init_kv_cache(bsz,
-                                                         self.num_heads,
-                                                         self.head_dim,
-                                                         kv_seq_len,
-                                                         max_cache_length,
-                                                         dtype=key_states.dtype,
-                                                         device=device)
-        new_key_states[:] = key_states
-        new_value_states[:] = value_states
-        key_states = new_key_states
-        value_states = new_value_states
-
-    past_key_value = (key_states, value_states) if use_cache else None
-
-    attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
-
-    if attention_mask is not None:
-        if q_len == 1:  # inference with cache
-            if len(attention_mask.size()) == 4:
-                attention_mask = attention_mask[:, :, -1:, :]
-            else:
-                attention_mask = attention_mask[:, -1:, :]
-        attn_weights = attn_weights + attention_mask
-        attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
-    attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
-
-    attn_output = torch.matmul(attn_weights, value_states)
-
-    attn_output = attn_output.transpose(1, 2)
-    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
-    attn_output = self.o_proj(attn_output)
-
-    if not output_attentions:
-        attn_weights = None
-
-    return attn_output.to(hidden_states.dtype), attn_weights, past_key_value
+def _get_interleave(n):
+    def _get_interleave_power_of_2(n):
+        start = 2 ** (-(2 ** -(math.log2(n) - 3)))
+        ratio = start
+        return [start * ratio**i for i in range(n)]
+
+    if math.log2(n).is_integer():
+        return _get_interleave_power_of_2(n)
+    else:
+        closest_power_of_2 = 2 ** math.floor(math.log2(n))
+        return (
+            _get_interleave_power_of_2(closest_power_of_2)
+            + _get_interleave(2 * closest_power_of_2)[0::2][: n - closest_power_of_2]
+        )
+
+
+def _fill_with_neg_inf(t):
+    """FP16-compatible function that fills a tensor with -inf."""
+    return t.float().fill_(float("-inf")).type_as(t)
+
+
+def _buffered_future_mask(tensor, maxpos, alibi, attn_heads):
+    _future_mask = torch.triu(_fill_with_neg_inf(torch.zeros([maxpos, maxpos])), 1)
+    _future_mask = _future_mask.unsqueeze(0) + alibi
+    new_future_mask = _future_mask.to(tensor)
+    return new_future_mask[: tensor.shape[0] * attn_heads, :maxpos, :maxpos]
+
+
+def baichuan_13b_gen_alibi_mask(tensor, n_head, max_pos):
+    slopes = torch.Tensor(_get_interleave(n_head)).to(tensor.dtype)
+    position_point = torch.arange(max_pos) - max_pos + 1
+    position_point = position_point.unsqueeze(0).unsqueeze(0).expand(n_head, -1, -1)
+    diag = torch.diag(position_point[0])
+    position_point = position_point - diag.unsqueeze(0).unsqueeze(0).transpose(-1, -2)
+    alibi = slopes.unsqueeze(1).unsqueeze(1) * position_point
+    alibi = alibi.view(n_head, 1, max_pos)
+    alibi_mask = torch.triu(
+        _fill_with_neg_inf(torch.zeros([max_pos, max_pos])), 1).to(tensor.dtype)
+    alibi_mask = alibi_mask.unsqueeze(0) + alibi
+    if tensor.device.type == "xpu":
+        alibi_mask = alibi_mask.to(tensor.device)
+    return alibi_mask
+
+
+MASK_BLOCK_SIZE = 512
+
+
+def baichuan_13b_get_alibi_mask(self, tensor, seq_length_with_past):
+    if self.training:
+        slopes = torch.Tensor(_get_interleave(self.n_head))
+        position_point = (
+            torch.arange(seq_length_with_past) - seq_length_with_past + 1
+        )
+        position_point = (
+            position_point.unsqueeze(0)
+            .unsqueeze(0)
+            .expand(self.n_head, seq_length_with_past, -1)
+        )
+        diag = torch.diag(position_point[0])
+        position_point = position_point - diag.unsqueeze(0).unsqueeze(0).transpose(
+            -1, -2
+        )
+        alibi = slopes.unsqueeze(1).unsqueeze(1) * position_point
+        mask = _buffered_future_mask(
+            tensor, seq_length_with_past, alibi, self.n_head
+        )
+    else:
+        if self.first_run:
+            # Override the default max_cache_pos=4096 for memory considerations
+            self.max_cache_pos = seq_length_with_past + MASK_BLOCK_SIZE
+            self.first_run = False
+            self.register_buffer(
+                "future_mask",
+                baichuan_13b_gen_alibi_mask(tensor, self.n_head, self.max_cache_pos),
+                persistent=False,
+            )
+        if seq_length_with_past > self.max_cache_pos:
+            # When max_cache_pos is not enough for current sequence length,
+            # increase by MASK_BLOCK_SIZE and recalculate future_mask.
+            self.max_cache_pos = seq_length_with_past + MASK_BLOCK_SIZE
+            self.register_buffer(
+                "future_mask",
+                baichuan_13b_gen_alibi_mask(tensor, self.n_head, self.max_cache_pos),
+                persistent=False,
+            )
+        mask = self.future_mask[
+            : self.n_head, :seq_length_with_past, :seq_length_with_past
+        ]
+    return mask
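Note: both new attention paths funnel cache handling through update_past_key_value(past_key_value, key_states, value_states, kv_seq_len, use_quantize_kv, device) from ipex_llm.transformers.models.utils. A hedged sketch of the contract implied by the call sites, not the actual helper (which also covers block-allocated and FP8 "quantize kv" caches):

# Rough sketch only: append the new K/V to the cached K/V and return the full
# tensors, as the call sites above expect. NOT the ipex-llm implementation.
import torch
from typing import Optional, Tuple

def update_past_key_value_sketch(past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]],
                                 key_states: torch.Tensor,
                                 value_states: torch.Tensor):
    if past_key_value is not None:
        key_states = torch.cat([past_key_value[0], key_states], dim=2)
        value_states = torch.cat([past_key_value[1], value_states], dim=2)
    return key_states, value_states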

ipex_llm/transformers/models/baichuan2.py (deleted)

@@ -1,325 +0,0 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is adapted from
# https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/cb7fc748b78b7ea99772e4cf76db155729ce774e/modeling_baichuan.py
# and
# https://huggingface.co/baichuan-inc/Baichuan2-13B-Chat/blob/c6f8592a60b4ad73c210b28dd2ab3cca51abbf93/modeling_baichuan.py
import math
from typing import Optional, Tuple
import torch
import torch.utils.checkpoint
from torch.nn import functional as F
from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache
from ipex_llm.transformers.models.utils import update_past_key_value
from ipex_llm.transformers.models.utils import should_use_fuse_rope
from ipex_llm.transformers.models.utils import use_flash_attention, use_sdp, use_sdp_causal
from ipex_llm.transformers.models.utils import apply_rotary_pos_emb, SILU
from ipex_llm.transformers.models.utils import mlp_fusion_check
import warnings
def pre_compute_inv_freq(module: torch.nn.Module):
if module.__class__.__name__ == "RotaryEmbedding":
inv_freq = module.inv_freq
del module.inv_freq
module.register_buffer("inv_freq", inv_freq, persistent=False)
def baichuan_13b_rms_norm_forward(self, hidden_states):
if hidden_states.device.type == "xpu" and not (self.training or hidden_states.requires_grad):
import xe_addons
x_2d = hidden_states.reshape(-1, hidden_states.size(-1)).contiguous()
output = xe_addons.rms_norm(self.weight, x_2d, self.epsilon)
return output.reshape(hidden_states.shape)
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.epsilon)
return self.weight * hidden_states.to(input_dtype)
def baichuan_mlp_forward(
self,
x: torch.Tensor,
) -> torch.Tensor:
x_2d = x.view(-1, x.shape[-1])
qtype = getattr(self.gate_proj, "qtype", None)
if mlp_fusion_check(x_2d, qtype, self.training) and not self.down_proj.enable_xetla:
import xe_linear
if not x_2d.is_contiguous():
x_2d = x_2d.contiguous()
return self.down_proj(xe_linear.mlp_forward_xpu(
x_2d, self.gate_proj.weight.data, self.up_proj.weight.data,
x_2d.shape[0], x_2d.shape[1], self.gate_proj.out_len,
SILU, qtype
))
return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
def baichuan_attention_forward_7b(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
):
bsz, q_len, _ = hidden_states.size()
device = hidden_states.device
qkv = self.W_pack(hidden_states)
qkv = qkv.view(bsz, q_len, self.num_heads * 3, self.head_dim)
qkv = qkv.transpose(1, 2)
query_states, key_states, value_states = qkv.split([self.num_heads,
self.num_heads,
self.num_heads], dim=1)
kv_seq_len = key_states.shape[2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[2]
# IPEX-LLM OPT: fuse rope
if should_use_fuse_rope(hidden_states, position_ids, self.training):
import xe_addons
xe_addons.rotary_half_inplaced(self.rotary_emb.inv_freq, position_ids,
query_states, key_states)
else:
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
cos, sin, position_ids, "baichuan")
query_states = query_states.to(hidden_states.dtype)
key_states = key_states.to(hidden_states.dtype)
# IPEX-LLM OPT: kv cache and quantize kv
use_quantize_kv = use_quantize_kv_cache(self.W_pack, hidden_states)
key_states, value_states = update_past_key_value(
past_key_value, key_states, value_states,
kv_seq_len, use_quantize_kv, device
)
past_key_value = (key_states, value_states) if use_cache else None
if self.training:
warnings.warn("xops is not supported on Intel GPU, so just use normal implementation")
# IPEX-LLM OPT: sdp
attn_weights = None
if not self.training and not hidden_states.requires_grad and \
use_flash_attention(query_states, key_states, attention_mask):
attn_output = F.scaled_dot_product_attention(query_states.to(dtype=torch.float16),
key_states.to(dtype=torch.float16),
value_states.to(dtype=torch.float16),
is_causal=True).to(hidden_states.dtype)
elif use_sdp(q_len, kv_seq_len, self.head_dim, query_states):
import xe_addons
if use_quantize_kv:
attn_output = xe_addons.sdp_fp8(query_states, key_states, value_states,
attention_mask)
else:
attn_output = xe_addons.sdp(query_states, key_states, value_states,
attention_mask)
elif use_sdp_causal(q_len, kv_seq_len, self.head_dim, query_states, self.training):
import xe_addons
if use_quantize_kv:
attn_output = xe_addons.sdp_fp8_causal(query_states, key_states,
value_states, attention_mask)
else:
attn_output = xe_addons.sdp_causal(query_states, key_states,
value_states, attention_mask)
else:
if use_quantize_kv:
key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
query_states.dtype)
attn_weights = torch.matmul(query_states,
key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
if attention_mask is not None:
attn_weights = attn_weights + attention_mask
# upcast attention to fp32
attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1,
dtype=torch.float32).to(value_states.dtype)
attn_output = torch.matmul(attn_weights, value_states)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
attn_output = self.o_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value
def baichuan_attention_forward_13b(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()
device = hidden_states.device
qkv = self.W_pack(hidden_states)
qkv = qkv.view(bsz, q_len, self.num_heads * 3, self.head_dim)
qkv = qkv.transpose(1, 2)
query_states, key_states, value_states = qkv.split([self.num_heads,
self.num_heads,
self.num_heads], dim=1)
kv_seq_len = key_states.shape[2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[2]
# IPEX-LLM OPT: kv cache and quantize kv
use_quantize_kv = use_quantize_kv_cache(self.W_pack, hidden_states)
key_states, value_states = update_past_key_value(
past_key_value, key_states, value_states,
kv_seq_len, use_quantize_kv, device
)
past_key_value = (key_states, value_states) if use_cache else None
if self.training:
warnings.warn("xops is not supported on Intel GPU, so just use normal implementation")
if attention_mask is not None:
if len(attention_mask.size()) == 4:
attention_mask = attention_mask[:, :, -q_len:, :]
else:
attention_mask = attention_mask[:, None, -q_len:, :]
if use_sdp(q_len, kv_seq_len, self.head_dim, query_states):
import xe_addons
if use_quantize_kv:
attn_output = xe_addons.sdp_fp8(query_states, key_states, value_states,
attention_mask)
else:
attn_output = xe_addons.sdp(query_states, key_states, value_states,
attention_mask)
attn_weights = None
else:
if use_quantize_kv:
key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
query_states.dtype)
attn_weights = torch.matmul(query_states,
key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
if attention_mask is not None:
attn_weights = attn_weights + attention_mask
attn_weights = attn_weights.to(query_states.dtype)
attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
attn_output = torch.matmul(attn_weights.to(dtype=value_states.dtype), value_states)
attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
attn_output = self.o_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value
def _get_interleave(n):
def _get_interleave_power_of_2(n):
start = 2 ** (-(2 ** -(math.log2(n) - 3)))
ratio = start
return [start * ratio**i for i in range(n)]
if math.log2(n).is_integer():
return _get_interleave_power_of_2(n)
else:
closest_power_of_2 = 2 ** math.floor(math.log2(n))
return (
_get_interleave_power_of_2(closest_power_of_2)
+ _get_interleave(2 * closest_power_of_2)[0::2][: n - closest_power_of_2]
)
def _fill_with_neg_inf(t):
"""FP16-compatible function that fills a tensor with -inf."""
return t.float().fill_(float("-inf")).type_as(t)
def _buffered_future_mask(tensor, maxpos, alibi, attn_heads):
_future_mask = torch.triu(_fill_with_neg_inf(torch.zeros([maxpos, maxpos])), 1)
_future_mask = _future_mask.unsqueeze(0) + alibi
new_future_mask = _future_mask.to(tensor)
return new_future_mask[: tensor.shape[0] * attn_heads, :maxpos, :maxpos]
def baichuan_13b_gen_alibi_mask(tensor, n_head, max_pos):
slopes = torch.Tensor(_get_interleave(n_head)).to(tensor.dtype)
position_point = torch.arange(max_pos) - max_pos + 1
position_point = position_point.unsqueeze(0).unsqueeze(0).expand(n_head, -1, -1)
diag = torch.diag(position_point[0])
position_point = position_point - diag.unsqueeze(0).unsqueeze(0).transpose(-1, -2)
alibi = slopes.unsqueeze(1).unsqueeze(1) * position_point
alibi = alibi.view(n_head, 1, max_pos)
alibi_mask = torch.triu(
_fill_with_neg_inf(torch.zeros([max_pos, max_pos])), 1).to(tensor.dtype)
alibi_mask = alibi_mask.unsqueeze(0) + alibi
if tensor.device.type == "xpu":
alibi_mask = alibi_mask.to(tensor.device)
return alibi_mask
MASK_BLOCK_SIZE = 512
def baichuan_13b_get_alibi_mask(self, tensor, seq_length_with_past):
if self.training:
slopes = torch.Tensor(_get_interleave(self.n_head))
position_point = (
torch.arange(seq_length_with_past) - seq_length_with_past + 1
)
position_point = (
position_point.unsqueeze(0)
.unsqueeze(0)
.expand(self.n_head, seq_length_with_past, -1)
)
diag = torch.diag(position_point[0])
position_point = position_point - diag.unsqueeze(0).unsqueeze(0).transpose(
-1, -2
)
alibi = slopes.unsqueeze(1).unsqueeze(1) * position_point
mask = _buffered_future_mask(
tensor, seq_length_with_past, alibi, self.n_head
)
else:
if self.first_run:
# Override the default max_cache_pos=4096 for memory considerations
self.max_cache_pos = seq_length_with_past + MASK_BLOCK_SIZE
self.first_run = False
self.register_buffer(
"future_mask",
baichuan_13b_gen_alibi_mask(tensor, self.n_head, self.max_cache_pos),
persistent=False,
)
if seq_length_with_past > self.max_cache_pos:
# When max_cache_pos is not enough for current sequence length,
# increase by MASK_BLOCK_SIZE and recalculate future_mask.
self.max_cache_pos = seq_length_with_past + MASK_BLOCK_SIZE
self.register_buffer(
"future_mask",
baichuan_13b_gen_alibi_mask(tensor, self.n_head, self.max_cache_pos),
persistent=False,
)
mask = self.future_mask[
: self.n_head, :seq_length_with_past, :seq_length_with_past
]
return mask
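Note: as a quick sanity check of the ALiBi helpers above, _get_interleave returns a geometric sequence of per-head slopes for a power-of-two head count, and baichuan_13b_gen_alibi_mask combines those slopes with relative positions and the causal -inf upper triangle. A tiny illustration, assuming the helpers defined above are in scope (values follow from the formula, not from running Baichuan):

# Illustration of the ALiBi pieces using a small head count.
import math
import torch

n_head, max_pos = 8, 4
slopes = _get_interleave(n_head)
# start = 2 ** (-(2 ** -(log2(8) - 3))) = 2 ** -1 and ratio = start,
# so slopes == [2**-1, 2**-2, ..., 2**-8]
assert math.isclose(slopes[0], 0.5) and math.isclose(slopes[-1], 2 ** -8)

mask = baichuan_13b_gen_alibi_mask(torch.zeros(1), n_head, max_pos)
print(mask.shape)   # torch.Size([8, 4, 4]): one biased causal mask per head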