[LLM] Support transformers-v4.36.0 on mistral model (#9744)
* add support transformers-v4.36.0 on mistral model * python/llm/src/bigdl/llm/transformers/models/mistral.py * make the redundant implementation as utils * fix code style * fix * fix style * update with utils enough_kv_room
This commit is contained in:
parent
e36111e713
commit
ba0b939579
2 changed files with 205 additions and 48 deletions
|
|
@ -652,19 +652,34 @@ def _optimize_post(model, lightweight_bmm=False):
|
|||
module.MistralRMSNorm,
|
||||
llama_rms_norm_forward)
|
||||
else:
|
||||
modeling_module_name = model.__class__.__module__
|
||||
module = importlib.import_module(modeling_module_name)
|
||||
from bigdl.llm.transformers.models.mistral import mistral_attention_forward
|
||||
convert_forward(model,
|
||||
module.MistralAttention,
|
||||
mistral_attention_forward
|
||||
)
|
||||
convert_forward(model,
|
||||
module.MistralRMSNorm,
|
||||
llama_rms_norm_forward)
|
||||
convert_forward(model,
|
||||
module.MistralMLP,
|
||||
llama_mlp_forward)
|
||||
if version.parse(trans_version) >= version.parse("4.36.0"):
|
||||
modeling_module_name = model.__class__.__module__
|
||||
module = importlib.import_module(modeling_module_name)
|
||||
from bigdl.llm.transformers.models.mistral import mistral_attention_forward_4_36
|
||||
convert_forward(model,
|
||||
module.MistralAttention,
|
||||
mistral_attention_forward_4_36
|
||||
)
|
||||
convert_forward(model,
|
||||
module.MistralRMSNorm,
|
||||
llama_rms_norm_forward)
|
||||
convert_forward(model,
|
||||
module.MistralMLP,
|
||||
llama_mlp_forward)
|
||||
else:
|
||||
modeling_module_name = model.__class__.__module__
|
||||
module = importlib.import_module(modeling_module_name)
|
||||
from bigdl.llm.transformers.models.mistral import mistral_attention_forward
|
||||
convert_forward(model,
|
||||
module.MistralAttention,
|
||||
mistral_attention_forward
|
||||
)
|
||||
convert_forward(model,
|
||||
module.MistralRMSNorm,
|
||||
llama_rms_norm_forward)
|
||||
convert_forward(model,
|
||||
module.MistralMLP,
|
||||
llama_mlp_forward)
|
||||
elif model.config.model_type == "Yi":
|
||||
modeling_module_name = model.__class__.__module__
|
||||
module = importlib.import_module(modeling_module_name)
|
||||
|
|
|
|||
|
|
@ -44,7 +44,8 @@ from bigdl.llm.utils.common import invalidInputError
|
|||
from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
|
||||
from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb,\
|
||||
apply_rotary_pos_emb_no_cache_xpu
|
||||
from bigdl.llm.transformers.models.utils import is_enough_kv_cache_room_4_31
|
||||
from bigdl.llm.transformers.models.utils import is_enough_kv_cache_room_4_31,\
|
||||
is_enough_kv_cache_room_4_36
|
||||
from bigdl.llm.transformers.low_bit_linear import SYM_INT4
|
||||
|
||||
KV_CACHE_ALLOC_BLOCK_LENGTH = 256
|
||||
|
|
@ -75,6 +76,46 @@ def use_decoding_fast_path(q_type, use_fuse_rope, enough_kv_room, bs):
|
|||
return q_type == SYM_INT4 and use_fuse_rope and enough_kv_room and bs == 1
|
||||
|
||||
|
||||
def compute_attn_outputs_weights(query_states, key_states, value_states, bsz, q_len, kv_seq_len,
|
||||
num_heads, head_dim, hidden_size, attention_mask):
|
||||
attn_weights = torch.matmul(
|
||||
query_states,
|
||||
key_states.transpose(2, 3)) / math.sqrt(head_dim)
|
||||
|
||||
if attn_weights.size() != (bsz, num_heads, q_len, kv_seq_len):
|
||||
invalidInputError(
|
||||
False,
|
||||
f"Attention weights should be of size {(bsz, num_heads, q_len, kv_seq_len)},"
|
||||
f" but is {attn_weights.size()}"
|
||||
)
|
||||
|
||||
if attention_mask is not None:
|
||||
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
|
||||
invalidInputError(
|
||||
False,
|
||||
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)},"
|
||||
f" but is {attention_mask.size()}"
|
||||
)
|
||||
|
||||
attn_weights = attn_weights + attention_mask
|
||||
|
||||
# upcast attention to fp32
|
||||
attn_weights = nn.functional.\
|
||||
softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
|
||||
attn_output = torch.matmul(attn_weights, value_states)
|
||||
|
||||
if attn_output.size() != (bsz, num_heads, q_len, head_dim):
|
||||
invalidInputError(
|
||||
f"`attn_output` should be of size {(bsz, num_heads, q_len, head_dim)},"
|
||||
f" but is {attn_output.size()}"
|
||||
)
|
||||
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
attn_output = attn_output.reshape(bsz, q_len, hidden_size)
|
||||
|
||||
return attn_output, attn_weights
|
||||
|
||||
|
||||
def mistral_attention_forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
|
|
@ -177,40 +218,141 @@ def mistral_attention_forward(
|
|||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||
|
||||
attn_weights = torch.matmul(
|
||||
query_states,
|
||||
key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
|
||||
|
||||
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
|
||||
invalidInputError(
|
||||
False,
|
||||
f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)},"
|
||||
f" but is {attn_weights.size()}"
|
||||
)
|
||||
|
||||
if attention_mask is not None:
|
||||
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
|
||||
invalidInputError(
|
||||
False,
|
||||
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)},"
|
||||
f" but is {attention_mask.size()}"
|
||||
)
|
||||
|
||||
attn_weights = attn_weights + attention_mask
|
||||
|
||||
# upcast attention to fp32
|
||||
attn_weights = nn.functional.\
|
||||
softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
|
||||
attn_output = torch.matmul(attn_weights, value_states)
|
||||
|
||||
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
|
||||
invalidInputError(
|
||||
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)},"
|
||||
f" but is {attn_output.size()}"
|
||||
)
|
||||
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
||||
attn_output, attn_weights = compute_attn_outputs_weights(query_states, key_states, value_states,
|
||||
bsz, q_len, kv_seq_len,
|
||||
self.num_heads, self.head_dim,
|
||||
self.hidden_size, attention_mask)
|
||||
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
if not output_attentions:
|
||||
attn_weights = None
|
||||
|
||||
return attn_output, attn_weights, past_key_value
|
||||
|
||||
|
||||
def mistral_attention_forward_4_36(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor]=None,
|
||||
position_ids: Optional[torch.LongTensor]=None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]]=None,
|
||||
output_attentions: bool=False,
|
||||
use_cache: bool=False,
|
||||
padding_mask: Optional[torch.Tensor]=None,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
device = hidden_states.device
|
||||
|
||||
use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
|
||||
enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx)
|
||||
decoding_fast_path = use_decoding_fast_path(self.q_proj.qtype,
|
||||
use_fuse_rope,
|
||||
enough_kv_room,
|
||||
bsz * q_len)
|
||||
|
||||
if decoding_fast_path:
|
||||
hidden_states = hidden_states.view(1, -1)
|
||||
|
||||
cache_k = past_key_value.key_cache[self.layer_idx]
|
||||
cache_v = past_key_value.value_cache[self.layer_idx]
|
||||
|
||||
kv_seq_len = cache_k.shape[-2]
|
||||
|
||||
import linear_q4_0
|
||||
query_states, key_states, value_states = linear_q4_0.forward_qkv(hidden_states,
|
||||
self.q_proj.weight,
|
||||
self.k_proj.weight,
|
||||
self.v_proj.weight,
|
||||
position_ids,
|
||||
cache_k, cache_v,
|
||||
self.q_proj.weight.qtype,
|
||||
kv_seq_len,
|
||||
self.head_dim)
|
||||
kv_seq_len += 1
|
||||
|
||||
# update past_key_value's seem_tokens and kv caches.
|
||||
if self.layer_idx == 0:
|
||||
past_key_value.seen_tokens = kv_seq_len
|
||||
past_key_value.key_cache[self.layer_idx] = key_states
|
||||
past_key_value.value_cache[self.layer_idx] = value_states
|
||||
|
||||
else:
|
||||
query_states = self.q_proj(hidden_states)
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
|
||||
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
key_states = key_states.view(bsz, q_len,
|
||||
self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
value_states = value_states.view(bsz, q_len,
|
||||
self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
kv_seq_len = key_states.shape[-2]
|
||||
|
||||
if past_key_value is not None:
|
||||
if self.layer_idx is None:
|
||||
invalidInputError(False,
|
||||
"The cache structure has changed since version v4.36. "
|
||||
f"If you are using {self.__class__.__name__} for "
|
||||
"auto-regressive decodingwith k/v caching, please make sure "
|
||||
"to initialize the attention class with a layer index.")
|
||||
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
|
||||
|
||||
if use_fuse_rope:
|
||||
query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
|
||||
key_states,
|
||||
position_ids,
|
||||
"mistral")
|
||||
else:
|
||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
|
||||
cos, sin, position_ids, "mistral")
|
||||
|
||||
if past_key_value is not None:
|
||||
# update the number of seen tokens
|
||||
if self.layer_idx == 0:
|
||||
past_key_value.seen_tokens += key_states.shape[-2]
|
||||
|
||||
# reuse k, v, self_attention
|
||||
# update `past_key_value` with `key_states` and `value_states` for layer `layer_idx`
|
||||
if len(past_key_value.key_cache) <= self.layer_idx:
|
||||
past_key_value.key_cache.append(key_states)
|
||||
past_key_value.value_cache.append(value_states)
|
||||
else:
|
||||
cache_k = past_key_value.key_cache[self.layer_idx]
|
||||
cache_v = past_key_value.value_cache[self.layer_idx]
|
||||
|
||||
if not enough_kv_room:
|
||||
# allocate new
|
||||
new_c_k, new_c_v = extend_kv_cache(bsz,
|
||||
self.num_key_value_heads, # Support GQA
|
||||
self.head_dim,
|
||||
cache_k.size(2),
|
||||
kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH,
|
||||
dtype=cache_k.dtype,
|
||||
device=device)
|
||||
|
||||
new_c_k[:] = cache_k
|
||||
new_c_v[:] = cache_v
|
||||
cache_k = new_c_k
|
||||
cache_v = new_c_v
|
||||
|
||||
key_states, value_states = append_kv_cache(cache_k, cache_v,
|
||||
key_states, value_states)
|
||||
|
||||
# update past_key_value
|
||||
past_key_value.key_cache[self.layer_idx] = key_states
|
||||
past_key_value.value_cache[self.layer_idx] = value_states
|
||||
|
||||
# repeat k/v heads if n_kv_heads < n_heads
|
||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||
|
||||
attn_output, attn_weights = compute_attn_outputs_weights(query_states, key_states, value_states,
|
||||
bsz, q_len, kv_seq_len,
|
||||
self.num_heads, self.head_dim,
|
||||
self.hidden_size, attention_mask)
|
||||
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue