diff --git a/python/llm/src/bigdl/llm/transformers/convert.py b/python/llm/src/bigdl/llm/transformers/convert.py
index 3e4f9f60..0ed34b64 100644
--- a/python/llm/src/bigdl/llm/transformers/convert.py
+++ b/python/llm/src/bigdl/llm/transformers/convert.py
@@ -176,22 +176,44 @@ def optimize(model):
     elif model.config.model_type == "baichuan" and model.config.vocab_size == 125696:
         # baichuan2
-        modeling_module_name = model.__class__.__module__
-        module = importlib.import_module(modeling_module_name)
-        from bigdl.llm.transformers.models.baichuan2 import baichuan_attention_forward
-        convert_forward(model,
-                        module.Attention,
-                        baichuan_attention_forward
-                        )
+        if model.config.hidden_size == 4096:
+            # baichuan2-7B
+            modeling_module_name = model.__class__.__module__
+            module = importlib.import_module(modeling_module_name)
+            from bigdl.llm.transformers.models.baichuan2 import baichuan_attention_forward_7b
+            convert_forward(model,
+                            module.Attention,
+                            baichuan_attention_forward_7b
+                            )
+        elif model.config.hidden_size == 5120:
+            # baichuan2-13B
+            modeling_module_name = model.__class__.__module__
+            module = importlib.import_module(modeling_module_name)
+            from bigdl.llm.transformers.models.baichuan2 import baichuan_attention_forward_13b
+            convert_forward(model,
+                            module.BaichuanAttention,
+                            baichuan_attention_forward_13b
+                            )
     elif model.config.model_type == "baichuan":
         # baichuan1
-        modeling_module_name = model.__class__.__module__
-        module = importlib.import_module(modeling_module_name)
-        from bigdl.llm.transformers.models.baichuan import baichuan_attention_forward
-        convert_forward(model,
-                        module.Attention,
-                        baichuan_attention_forward
-                        )
+        if model.config.hidden_size == 4096:
+            # baichuan-7B
+            modeling_module_name = model.__class__.__module__
+            module = importlib.import_module(modeling_module_name)
+            from bigdl.llm.transformers.models.baichuan import baichuan_attention_forward_7b
+            convert_forward(model,
+                            module.Attention,
+                            baichuan_attention_forward_7b
+                            )
+        elif model.config.hidden_size == 5120:
+            # baichuan-13B
+            modeling_module_name = model.__class__.__module__
+            module = importlib.import_module(modeling_module_name)
+            from bigdl.llm.transformers.models.baichuan import baichuan_attention_forward_13b
+            convert_forward(model,
+                            module.BaichuanAttention,
+                            baichuan_attention_forward_13b
+                            )
 
     return model
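For readers skimming the dispatch above: the split on `config.hidden_size` exists because the 7B and 13B checkpoints ship different modeling code, so the attention class is named `Attention` in the 7B files but `BaichuanAttention` in the 13B files, and each size gets its own optimized forward. As a rough illustration of what `convert_forward` does with those arguments (this sketch is hypothetical, not the actual helper in convert.py):

```python
import types

def convert_forward_sketch(model, target_class, new_forward):
    # Hypothetical stand-in for convert_forward: rebind `forward` on every
    # instance of the target attention class found in the module tree.
    for module in model.modules():
        if isinstance(module, target_class):
            module.forward = types.MethodType(new_forward, module)

# e.g. for a Baichuan2-13B checkpoint (hidden_size == 5120):
# convert_forward_sketch(model, module.BaichuanAttention, baichuan_attention_forward_13b)
```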
diff --git a/python/llm/src/bigdl/llm/transformers/models/baichuan.py b/python/llm/src/bigdl/llm/transformers/models/baichuan.py
index fb3f55d1..5d2d735c 100644
--- a/python/llm/src/bigdl/llm/transformers/models/baichuan.py
+++ b/python/llm/src/bigdl/llm/transformers/models/baichuan.py
@@ -15,6 +15,8 @@
 # This file is adapted from
 # https://huggingface.co/baichuan-inc/Baichuan-7B/blob/c1a5c7d5b7f50ecc51bb0e08150a9f12e5656756/modeling_baichuan.py
+# and
+# https://huggingface.co/baichuan-inc/Baichuan-13B-Chat/blob/a4a558127068f2ce965aa56aeb826bf501a68970/modeling_baichuan.py
 
 import math
@@ -30,7 +32,7 @@ from bigdl.llm.transformers.models.utils import rotate_half, apply_rotary_pos_em
 KV_CACHE_ALLOC_BLOCK_LENGTH = 256
 
 
-def baichuan_attention_forward(
+def baichuan_attention_forward_7b(
     self,
     hidden_states: torch.Tensor,
     attention_mask: Optional[torch.Tensor] = None,
@@ -133,3 +135,90 @@ def baichuan_attention_forward(
         attn_weights = None
 
     return attn_output, attn_weights, past_key_value
+
+
+def baichuan_attention_forward_13b(
+    self,
+    hidden_states: torch.Tensor,
+    attention_mask: Optional[torch.Tensor] = None,
+    past_key_value: Optional[Tuple[torch.Tensor]] = None,
+    output_attentions: bool = False,
+    use_cache: bool = False,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+
+    bsz, q_len, _ = hidden_states.size()
+    device = hidden_states.device
+
+    proj = self.W_pack(hidden_states)
+    proj = proj.unflatten(-1, (3, self.hidden_size)).unsqueeze(0).transpose(0, -2).squeeze(-2)
+    query_states = proj[0].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+    key_states = proj[1].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+    value_states = proj[2].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+
+    kv_seq_len = key_states.shape[-2]
+    if past_key_value is not None:
+        kv_seq_len += past_key_value[0].shape[-2]
+
+    # if past_key_value is not None:
+    #     # reuse k, v, self_attention
+    #     key_states = torch.cat([past_key_value[0], key_states], dim=2)
+    #     value_states = torch.cat([past_key_value[1], value_states], dim=2)
+    if past_key_value is not None:
+        # reuse k, v, self_attention
+        cache_k = past_key_value[0]
+        cache_v = past_key_value[1]
+        if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
+            # allocate new
+            new_cache_k, new_cache_v = create_kv_cache(bsz,
+                                                       self.num_heads,
+                                                       self.head_dim,
+                                                       cache_k.size(2),
+                                                       kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH,
+                                                       dtype=cache_k.dtype,
+                                                       device=device)
+            new_cache_k[:] = cache_k
+            new_cache_v[:] = cache_v
+            cache_k = new_cache_k
+            cache_v = new_cache_v
+
+        key_states, value_states = append_kv_cache(cache_k, cache_v, key_states, value_states)
+
+    elif use_cache:
+        max_cache_length = kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH
+        new_key_states, new_value_states = create_kv_cache(bsz,
+                                                           self.num_heads,
+                                                           self.head_dim,
+                                                           kv_seq_len,
+                                                           max_cache_length,
+                                                           dtype=key_states.dtype,
+                                                           device=device)
+        new_key_states[:] = key_states
+        new_value_states[:] = value_states
+        key_states = new_key_states
+        value_states = new_value_states
+
+    past_key_value = (key_states, value_states) if use_cache else None
+
+    attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+
+    if attention_mask is not None:
+        if q_len == 1:  # inference with cache
+            if len(attention_mask.size()) == 4:
+                attention_mask = attention_mask[:, :, -1:, :]
+            else:
+                attention_mask = attention_mask[:, -1:, :]
+        attn_weights = attn_weights + attention_mask
+        attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
+
+    attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
+
+    attn_output = torch.matmul(attn_weights, value_states)
+
+    attn_output = attn_output.transpose(1, 2)
+    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+    attn_output = self.o_proj(attn_output)
+
+    if not output_attentions:
+        attn_weights = None
+
+    return attn_output, attn_weights, past_key_value
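Both new `*_13b` forwards (here and in baichuan2.py below) replace the upstream `torch.cat` cache update, kept above as comments for reference, with BigDL's pre-allocated KV cache: `create_kv_cache` reserves `kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH` slots up front, and `append_kv_cache` writes new tokens into the spare capacity instead of reallocating on every decode step. The `cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3)` test detects a full buffer: while slack remains, a view into the over-allocated buffer keeps a per-head stride of `max_len * head_dim`, strictly larger than `seq_len * head_dim`. The real helpers live in `bigdl.llm.transformers.models.utils`; the following is only a sketch of that mechanism, not their actual code:

```python
import torch

def create_kv_cache(bsz, num_heads, head_dim, current_len, max_len, dtype, device):
    # Hypothetical sketch: allocate max_len slots once, return views over the
    # first current_len slots; later appends grow the view, not the storage.
    k = torch.empty(bsz, num_heads, max_len, head_dim, dtype=dtype, device=device)
    v = torch.empty(bsz, num_heads, max_len, head_dim, dtype=dtype, device=device)
    return k[:, :, :current_len, :], v[:, :, :current_len, :]

def append_kv_cache(cache_k, cache_v, key_states, value_states):
    # Hypothetical sketch: widen the views over the same storage, then copy
    # only the newly generated tokens into place (no torch.cat, no realloc).
    bsz, num_heads, cur_len, head_dim = cache_k.shape
    new_len = cur_len + key_states.size(2)
    new_k = cache_k.as_strided((bsz, num_heads, new_len, head_dim), cache_k.stride())
    new_v = cache_v.as_strided((bsz, num_heads, new_len, head_dim), cache_v.stride())
    new_k[:, :, cur_len:, :] = key_states
    new_v[:, :, cur_len:, :] = value_states
    return new_k, new_v
```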
diff --git a/python/llm/src/bigdl/llm/transformers/models/baichuan2.py b/python/llm/src/bigdl/llm/transformers/models/baichuan2.py
index 4bc0410a..b1179c55 100644
--- a/python/llm/src/bigdl/llm/transformers/models/baichuan2.py
+++ b/python/llm/src/bigdl/llm/transformers/models/baichuan2.py
@@ -15,7 +15,8 @@
 # This file is adapted from
 # https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/cb7fc748b78b7ea99772e4cf76db155729ce774e/modeling_baichuan.py
-
+# and
+# https://huggingface.co/baichuan-inc/Baichuan2-13B-Chat/blob/c6f8592a60b4ad73c210b28dd2ab3cca51abbf93/modeling_baichuan.py
 
 import math
 from typing import List, Optional, Tuple, Union
@@ -43,7 +44,7 @@ except ImportError:
 KV_CACHE_ALLOC_BLOCK_LENGTH = 256
 
 
-def baichuan_attention_forward(
+def baichuan_attention_forward_7b(
     self,
     hidden_states: torch.Tensor,
     attention_mask: Optional[torch.Tensor] = None,
@@ -133,3 +134,116 @@ def baichuan_attention_forward(
         attn_weights = None
 
     return attn_output, attn_weights, past_key_value
+
+
+def baichuan_attention_forward_13b(
+    self,
+    hidden_states: torch.Tensor,
+    attention_mask: Optional[torch.Tensor] = None,
+    past_key_value: Optional[Tuple[torch.Tensor]] = None,
+    output_attentions: bool = False,
+    use_cache: bool = False,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+    bsz, q_len, _ = hidden_states.size()
+    device = hidden_states.device
+
+    proj = self.W_pack(hidden_states)
+    proj = (
+        proj.unflatten(-1, (3, self.hidden_size))
+        .unsqueeze(0)
+        .transpose(0, -2)
+        .squeeze(-2)
+    )
+    query_states = (
+        proj[0].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+    )
+    key_states = (
+        proj[1].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+    )
+    value_states = (
+        proj[2].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+    )
+
+    kv_seq_len = key_states.shape[-2]
+    if past_key_value is not None:
+        kv_seq_len += past_key_value[0].shape[-2]
+
+    # if past_key_value is not None:
+    #     # reuse k, v, self_attention
+    #     key_states = torch.cat([past_key_value[0], key_states], dim=2)
+    #     value_states = torch.cat([past_key_value[1], value_states], dim=2)
+    if past_key_value is not None:
+        # reuse k, v, self_attention
+        cache_k = past_key_value[0]
+        cache_v = past_key_value[1]
+        if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
+            # allocate new
+            new_cache_k, new_cache_v = create_kv_cache(bsz,
+                                                       self.num_heads,
+                                                       self.head_dim,
+                                                       cache_k.size(2),
+                                                       kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH,
+                                                       dtype=cache_k.dtype,
+                                                       device=device)
+            new_cache_k[:] = cache_k
+            new_cache_v[:] = cache_v
+            cache_k = new_cache_k
+            cache_v = new_cache_v
+
+        key_states, value_states = append_kv_cache(cache_k, cache_v, key_states, value_states)
+
+    elif use_cache:
+        max_cache_length = kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH
+        new_key_states, new_value_states = create_kv_cache(bsz,
+                                                           self.num_heads,
+                                                           self.head_dim,
+                                                           kv_seq_len,
+                                                           max_cache_length,
+                                                           dtype=key_states.dtype,
+                                                           device=device)
+        new_key_states[:] = key_states
+        new_value_states[:] = value_states
+        key_states = new_key_states
+        value_states = new_value_states
+
+    past_key_value = (key_states, value_states) if use_cache else None
+    if xops is not None and self.training:
+        attn_weights = None
+        # query_states = query_states.transpose(1, 2)
+        # key_states = key_states.transpose(1, 2)
+        # value_states = value_states.transpose(1, 2)
+        # attn_output = xops.memory_efficient_attention(
+        #     query_states, key_states, value_states, attn_bias=attention_mask
+        # )
+        with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True,
+                                            enable_mem_efficient=True):
+            attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states,
                                                         attn_mask=attention_mask)
+        attn_output = attn_output.transpose(1, 2)
+    else:
+        attn_weights = torch.matmul(
+            query_states, key_states.transpose(2, 3)
+        ) / math.sqrt(self.head_dim)
+
+        if attention_mask is not None:
+            if q_len == 1:  # inference with cache
+                if len(attention_mask.size()) == 4:
+                    attention_mask = attention_mask[:, :, -1:, :]
+                else:
+                    attention_mask = attention_mask[:, -1:, :]
+            attn_weights = attn_weights + attention_mask
+            attn_weights = torch.max(
+                attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)
+            )
+
+        attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
+        attn_output = torch.matmul(attn_weights, value_states)
+
+        attn_output = attn_output.transpose(1, 2)
+    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+    attn_output = self.o_proj(attn_output)
+
+    if not output_attentions:
+        attn_weights = None
+
+    return attn_output, attn_weights, past_key_value
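Both 13B forwards also share the fused-QKV unpacking at the top: `W_pack` is a single linear layer emitting query, key, and value concatenated along the last dimension, and the `unflatten / unsqueeze / transpose / squeeze` chain merely reshuffles axes so that `proj[0]`, `proj[1]`, `proj[2]` index the three projections. A standalone shape check, with dimensions picked to match the 13B config and random tensors standing in for real activations:

```python
import torch

bsz, q_len, hidden_size, num_heads = 2, 5, 5120, 40
head_dim = hidden_size // num_heads

packed = torch.randn(bsz, q_len, 3 * hidden_size)  # stand-in for self.W_pack(hidden_states)
proj = packed.unflatten(-1, (3, hidden_size)).unsqueeze(0).transpose(0, -2).squeeze(-2)

# proj[i] is exactly the i-th third of the packed projection...
assert torch.equal(proj[0], packed[..., :hidden_size])
assert torch.equal(proj[2], packed[..., 2 * hidden_size:])

# ...which the forward then splits per attention head:
q = proj[0].view(bsz, q_len, num_heads, head_dim).transpose(1, 2)
assert q.shape == (bsz, num_heads, q_len, head_dim)
```

Note also that, unlike the 7B path, neither 13B forward applies rotary embeddings: Baichuan-13B encodes position through an ALiBi bias folded into `attention_mask`, which is why single-token decode steps (`q_len == 1`) slice the mask down to its last query row before adding it to the attention scores.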