LLM: fix optimized kv cache for baichuan-13b (#9009)
* fix baichuan 13b * fix style * fix * fix style
This commit is contained in:
parent
c88f6ec457
commit
94a7f8917b
3 changed files with 242 additions and 17 deletions
|
|
@ -176,22 +176,44 @@ def optimize(model):
|
||||||
|
|
||||||
elif model.config.model_type == "baichuan" and model.config.vocab_size == 125696:
|
elif model.config.model_type == "baichuan" and model.config.vocab_size == 125696:
|
||||||
# baichuan2
|
# baichuan2
|
||||||
modeling_module_name = model.__class__.__module__
|
if model.config.hidden_size == 4096:
|
||||||
module = importlib.import_module(modeling_module_name)
|
# baichuan2-7B
|
||||||
from bigdl.llm.transformers.models.baichuan2 import baichuan_attention_forward
|
modeling_module_name = model.__class__.__module__
|
||||||
convert_forward(model,
|
module = importlib.import_module(modeling_module_name)
|
||||||
module.Attention,
|
from bigdl.llm.transformers.models.baichuan2 import baichuan_attention_forward_7b
|
||||||
baichuan_attention_forward
|
convert_forward(model,
|
||||||
)
|
module.Attention,
|
||||||
|
baichuan_attention_forward_7b
|
||||||
|
)
|
||||||
|
elif model.config.hidden_size == 5120:
|
||||||
|
# baichuan2-13B
|
||||||
|
modeling_module_name = model.__class__.__module__
|
||||||
|
module = importlib.import_module(modeling_module_name)
|
||||||
|
from bigdl.llm.transformers.models.baichuan2 import baichuan_attention_forward_13b
|
||||||
|
convert_forward(model,
|
||||||
|
module.BaichuanAttention,
|
||||||
|
baichuan_attention_forward_13b
|
||||||
|
)
|
||||||
|
|
||||||
elif model.config.model_type == "baichuan":
|
elif model.config.model_type == "baichuan":
|
||||||
# baichuan1
|
# baichuan1
|
||||||
modeling_module_name = model.__class__.__module__
|
if model.config.hidden_size == 4096:
|
||||||
module = importlib.import_module(modeling_module_name)
|
# baichuan-7B
|
||||||
from bigdl.llm.transformers.models.baichuan import baichuan_attention_forward
|
modeling_module_name = model.__class__.__module__
|
||||||
convert_forward(model,
|
module = importlib.import_module(modeling_module_name)
|
||||||
module.Attention,
|
from bigdl.llm.transformers.models.baichuan import baichuan_attention_forward_7b
|
||||||
baichuan_attention_forward
|
convert_forward(model,
|
||||||
)
|
module.Attention,
|
||||||
|
baichuan_attention_forward_7b
|
||||||
|
)
|
||||||
|
elif model.config.hidden_size == 5120:
|
||||||
|
# baichuan-13B
|
||||||
|
modeling_module_name = model.__class__.__module__
|
||||||
|
module = importlib.import_module(modeling_module_name)
|
||||||
|
from bigdl.llm.transformers.models.baichuan import baichuan_attention_forward_13b
|
||||||
|
convert_forward(model,
|
||||||
|
module.BaichuanAttention,
|
||||||
|
baichuan_attention_forward_13b
|
||||||
|
)
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
|
||||||
|
|
@ -15,6 +15,8 @@
|
||||||
|
|
||||||
# This file is adapted from
|
# This file is adapted from
|
||||||
# https://huggingface.co/baichuan-inc/Baichuan-7B/blob/c1a5c7d5b7f50ecc51bb0e08150a9f12e5656756/modeling_baichuan.py
|
# https://huggingface.co/baichuan-inc/Baichuan-7B/blob/c1a5c7d5b7f50ecc51bb0e08150a9f12e5656756/modeling_baichuan.py
|
||||||
|
# and
|
||||||
|
# https://huggingface.co/baichuan-inc/Baichuan-13B-Chat/blob/a4a558127068f2ce965aa56aeb826bf501a68970/modeling_baichuan.py
|
||||||
|
|
||||||
|
|
||||||
import math
|
import math
|
||||||
|
|
@ -30,7 +32,7 @@ from bigdl.llm.transformers.models.utils import rotate_half, apply_rotary_pos_em
|
||||||
KV_CACHE_ALLOC_BLOCK_LENGTH = 256
|
KV_CACHE_ALLOC_BLOCK_LENGTH = 256
|
||||||
|
|
||||||
|
|
||||||
def baichuan_attention_forward(
|
def baichuan_attention_forward_7b(
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
|
@ -133,3 +135,90 @@ def baichuan_attention_forward(
|
||||||
attn_weights = None
|
attn_weights = None
|
||||||
|
|
||||||
return attn_output, attn_weights, past_key_value
|
return attn_output, attn_weights, past_key_value
|
||||||
|
|
||||||
|
|
||||||
|
def baichuan_attention_forward_13b(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||||
|
output_attentions: bool = False,
|
||||||
|
use_cache: bool = False,
|
||||||
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||||
|
|
||||||
|
bsz, q_len, _ = hidden_states.size()
|
||||||
|
device = hidden_states.device
|
||||||
|
|
||||||
|
proj = self.W_pack(hidden_states)
|
||||||
|
proj = proj.unflatten(-1, (3, self.hidden_size)).unsqueeze(0).transpose(0, -2).squeeze(-2)
|
||||||
|
query_states = proj[0].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||||
|
key_states = proj[1].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||||
|
value_states = proj[2].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||||
|
|
||||||
|
kv_seq_len = key_states.shape[-2]
|
||||||
|
if past_key_value is not None:
|
||||||
|
kv_seq_len += past_key_value[0].shape[-2]
|
||||||
|
|
||||||
|
# if past_key_value is not None:
|
||||||
|
# # reuse k, v, self_attention
|
||||||
|
# key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
||||||
|
# value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
||||||
|
if past_key_value is not None:
|
||||||
|
# reuse k, v, self_attention
|
||||||
|
cache_k = past_key_value[0]
|
||||||
|
cache_v = past_key_value[1]
|
||||||
|
if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
|
||||||
|
# allocate new
|
||||||
|
new_cache_k, new_cache_v = create_kv_cache(bsz,
|
||||||
|
self.num_heads,
|
||||||
|
self.head_dim,
|
||||||
|
cache_k.size(2),
|
||||||
|
kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH,
|
||||||
|
dtype=cache_k.dtype,
|
||||||
|
device=device)
|
||||||
|
new_cache_k[:] = cache_k
|
||||||
|
new_cache_v[:] = cache_v
|
||||||
|
cache_k = new_cache_k
|
||||||
|
cache_v = new_cache_v
|
||||||
|
|
||||||
|
key_states, value_states = append_kv_cache(cache_k, cache_v, key_states, value_states)
|
||||||
|
|
||||||
|
elif use_cache:
|
||||||
|
max_cache_length = kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH
|
||||||
|
new_key_states, new_value_states = create_kv_cache(bsz,
|
||||||
|
self.num_heads,
|
||||||
|
self.head_dim,
|
||||||
|
kv_seq_len,
|
||||||
|
max_cache_length,
|
||||||
|
dtype=key_states.dtype,
|
||||||
|
device=device)
|
||||||
|
new_key_states[:] = key_states
|
||||||
|
new_value_states[:] = value_states
|
||||||
|
key_states = new_key_states
|
||||||
|
value_states = new_value_states
|
||||||
|
|
||||||
|
past_key_value = (key_states, value_states) if use_cache else None
|
||||||
|
|
||||||
|
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
|
||||||
|
|
||||||
|
if attention_mask is not None:
|
||||||
|
if q_len == 1: # inference with cache
|
||||||
|
if len(attention_mask.size()) == 4:
|
||||||
|
attention_mask = attention_mask[:, :, -1:, :]
|
||||||
|
else:
|
||||||
|
attention_mask = attention_mask[:, -1:, :]
|
||||||
|
attn_weights = attn_weights + attention_mask
|
||||||
|
attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
|
||||||
|
|
||||||
|
attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
|
||||||
|
|
||||||
|
attn_output = torch.matmul(attn_weights, value_states)
|
||||||
|
|
||||||
|
attn_output = attn_output.transpose(1, 2)
|
||||||
|
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
||||||
|
attn_output = self.o_proj(attn_output)
|
||||||
|
|
||||||
|
if not output_attentions:
|
||||||
|
attn_weights = None
|
||||||
|
|
||||||
|
return attn_output, attn_weights, past_key_value
|
||||||
|
|
|
||||||
|
|
@ -15,7 +15,8 @@
|
||||||
|
|
||||||
# This file is adapted from
|
# This file is adapted from
|
||||||
# https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/cb7fc748b78b7ea99772e4cf76db155729ce774e/modeling_baichuan.py
|
# https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/cb7fc748b78b7ea99772e4cf76db155729ce774e/modeling_baichuan.py
|
||||||
|
# and
|
||||||
|
# https://huggingface.co/baichuan-inc/Baichuan2-13B-Chat/blob/c6f8592a60b4ad73c210b28dd2ab3cca51abbf93/modeling_baichuan.py
|
||||||
|
|
||||||
import math
|
import math
|
||||||
from typing import List, Optional, Tuple, Union
|
from typing import List, Optional, Tuple, Union
|
||||||
|
|
@ -43,7 +44,7 @@ except ImportError:
|
||||||
KV_CACHE_ALLOC_BLOCK_LENGTH = 256
|
KV_CACHE_ALLOC_BLOCK_LENGTH = 256
|
||||||
|
|
||||||
|
|
||||||
def baichuan_attention_forward(
|
def baichuan_attention_forward_7b(
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
|
@ -133,3 +134,116 @@ def baichuan_attention_forward(
|
||||||
attn_weights = None
|
attn_weights = None
|
||||||
|
|
||||||
return attn_output, attn_weights, past_key_value
|
return attn_output, attn_weights, past_key_value
|
||||||
|
|
||||||
|
|
||||||
|
def baichuan_attention_forward_13b(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||||
|
output_attentions: bool = False,
|
||||||
|
use_cache: bool = False,
|
||||||
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||||
|
bsz, q_len, _ = hidden_states.size()
|
||||||
|
device = hidden_states.device
|
||||||
|
|
||||||
|
proj = self.W_pack(hidden_states)
|
||||||
|
proj = (
|
||||||
|
proj.unflatten(-1, (3, self.hidden_size))
|
||||||
|
.unsqueeze(0)
|
||||||
|
.transpose(0, -2)
|
||||||
|
.squeeze(-2)
|
||||||
|
)
|
||||||
|
query_states = (
|
||||||
|
proj[0].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||||
|
)
|
||||||
|
key_states = (
|
||||||
|
proj[1].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||||
|
)
|
||||||
|
value_states = (
|
||||||
|
proj[2].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||||
|
)
|
||||||
|
|
||||||
|
kv_seq_len = key_states.shape[-2]
|
||||||
|
if past_key_value is not None:
|
||||||
|
kv_seq_len += past_key_value[0].shape[-2]
|
||||||
|
|
||||||
|
# if past_key_value is not None:
|
||||||
|
# # reuse k, v, self_attention
|
||||||
|
# key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
||||||
|
# value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
||||||
|
if past_key_value is not None:
|
||||||
|
# reuse k, v, self_attention
|
||||||
|
cache_k = past_key_value[0]
|
||||||
|
cache_v = past_key_value[1]
|
||||||
|
if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
|
||||||
|
# allocate new
|
||||||
|
new_cache_k, new_cache_v = create_kv_cache(bsz,
|
||||||
|
self.num_heads,
|
||||||
|
self.head_dim,
|
||||||
|
cache_k.size(2),
|
||||||
|
kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH,
|
||||||
|
dtype=cache_k.dtype,
|
||||||
|
device=device)
|
||||||
|
new_cache_k[:] = cache_k
|
||||||
|
new_cache_v[:] = cache_v
|
||||||
|
cache_k = new_cache_k
|
||||||
|
cache_v = new_cache_v
|
||||||
|
|
||||||
|
key_states, value_states = append_kv_cache(cache_k, cache_v, key_states, value_states)
|
||||||
|
|
||||||
|
elif use_cache:
|
||||||
|
max_cache_length = kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH
|
||||||
|
new_key_states, new_value_states = create_kv_cache(bsz,
|
||||||
|
self.num_heads,
|
||||||
|
self.head_dim,
|
||||||
|
kv_seq_len,
|
||||||
|
max_cache_length,
|
||||||
|
dtype=key_states.dtype,
|
||||||
|
device=device)
|
||||||
|
new_key_states[:] = key_states
|
||||||
|
new_value_states[:] = value_states
|
||||||
|
key_states = new_key_states
|
||||||
|
value_states = new_value_states
|
||||||
|
|
||||||
|
past_key_value = (key_states, value_states) if use_cache else None
|
||||||
|
if xops is not None and self.training:
|
||||||
|
attn_weights = None
|
||||||
|
# query_states = query_states.transpose(1, 2)
|
||||||
|
# key_states = key_states.transpose(1, 2)
|
||||||
|
# value_states = value_states.transpose(1, 2)
|
||||||
|
# attn_output = xops.memory_efficient_attention(
|
||||||
|
# query_states, key_states, value_states, attn_bias=attention_mask
|
||||||
|
# )
|
||||||
|
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True,
|
||||||
|
enable_mem_efficient=True):
|
||||||
|
attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states,
|
||||||
|
attn_mask=attention_mask)
|
||||||
|
attn_output = attn_output.transpose(1, 2)
|
||||||
|
else:
|
||||||
|
attn_weights = torch.matmul(
|
||||||
|
query_states, key_states.transpose(2, 3)
|
||||||
|
) / math.sqrt(self.head_dim)
|
||||||
|
|
||||||
|
if attention_mask is not None:
|
||||||
|
if q_len == 1: # inference with cache
|
||||||
|
if len(attention_mask.size()) == 4:
|
||||||
|
attention_mask = attention_mask[:, :, -1:, :]
|
||||||
|
else:
|
||||||
|
attention_mask = attention_mask[:, -1:, :]
|
||||||
|
attn_weights = attn_weights + attention_mask
|
||||||
|
attn_weights = torch.max(
|
||||||
|
attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)
|
||||||
|
)
|
||||||
|
|
||||||
|
attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
|
||||||
|
attn_output = torch.matmul(attn_weights, value_states)
|
||||||
|
|
||||||
|
attn_output = attn_output.transpose(1, 2)
|
||||||
|
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
||||||
|
attn_output = self.o_proj(attn_output)
|
||||||
|
|
||||||
|
if not output_attentions:
|
||||||
|
attn_weights = None
|
||||||
|
|
||||||
|
return attn_output, attn_weights, past_key_value
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue