Llama optimize_model to support transformers 4.36 (#9818)

* support 4.36

* style

* update

* update

* update
Jiao Wang 2024-01-05 11:30:18 -08:00 committed by GitHub
parent 4269a585b2
commit 248ae7fad2
2 changed files with 228 additions and 14 deletions


@@ -432,10 +432,6 @@ def _optimize_post(model, lightweight_bmm=False):
     trans_version = transformers.__version__
     if version.parse(trans_version) >= version.parse("4.31.0"):
-        convert_forward(
-            model,
-            transformers.models.llama.modeling_llama.LlamaAttention,
-            llama_attention_forward_4_31,)
         convert_forward(
             model,
             transformers.models.llama.modeling_llama.LlamaRMSNorm,
@@ -443,17 +439,30 @@ def _optimize_post(model, lightweight_bmm=False):
         convert_forward(model,
                         transformers.models.llama.modeling_llama.LlamaMLP,
                         llama_mlp_forward)
-        if enable_vllm_se_batching:
-            convert_forward(
-                model,
-                transformers.models.llama.modeling_llama.LlamaModel,
-                llama_model_selective_batching_forward_4_31,
-            )
-            convert_forward(
-                model,
-                transformers.models.llama.modeling_llama.LlamaAttention,
-                llama_attention_selective_batching_forward_4_31,
-            )
+        if version.parse(trans_version) >= version.parse("4.36.0"):
+            # transformers version >= 4.36.0
+            from bigdl.llm.transformers.models.llama import llama_attention_forward_4_36
+            convert_forward(
+                model,
+                transformers.models.llama.modeling_llama.LlamaAttention,
+                llama_attention_forward_4_36, )
+        else:
+            # transformers version between 4.31.0 - 4.35.2
+            convert_forward(
+                model,
+                transformers.models.llama.modeling_llama.LlamaAttention,
+                llama_attention_forward_4_31, )
+            if enable_vllm_se_batching:
+                convert_forward(
+                    model,
+                    transformers.models.llama.modeling_llama.LlamaModel,
+                    llama_model_selective_batching_forward_4_31,
+                )
+                convert_forward(
+                    model,
+                    transformers.models.llama.modeling_llama.LlamaAttention,
+                    llama_attention_selective_batching_forward_4_31,
+                )
     else:
         # todo implement 4.28.0 ~ 4.30.2
         pass
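For context, a minimal standalone sketch of the version gate introduced in the hunk above. It only mirrors the `version.parse` comparison used in `_optimize_post`; it does not patch any module, and the printed names simply label which forward function would be selected.

# Minimal sketch of the dispatch logic added above (illustrative only).
# Assumes `transformers` and `packaging` are installed; no model is patched here.
import transformers
from packaging import version

trans_version = transformers.__version__

if version.parse(trans_version) >= version.parse("4.36.0"):
    # transformers >= 4.36 stores past_key_value in a Cache object,
    # so the dedicated 4.36 attention forward is selected.
    selected = "llama_attention_forward_4_36"
elif version.parse(trans_version) >= version.parse("4.31.0"):
    # transformers 4.31.0 - 4.35.2 keep the tuple-based KV cache path.
    selected = "llama_attention_forward_4_31"
else:
    # 4.28.0 ~ 4.30.2 are still unhandled (see the `todo` above).
    selected = None

print(f"transformers {trans_version} -> {selected}")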


@@ -32,15 +32,16 @@
 # limitations under the License.
 import torch
+import warnings
 import importlib
 import torch.nn as nn
 from typing import Optional, Tuple, Union, List
 import math
 import os
 import torch.nn.functional as F
 from bigdl.llm.utils.common import invalidInputError
 from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
-from bigdl.llm.transformers.models.utils import is_enough_kv_cache_room_4_31, apply_rotary_pos_emb
+from bigdl.llm.transformers.models.utils import is_enough_kv_cache_room_4_31, \
+    apply_rotary_pos_emb, is_enough_kv_cache_room_4_36
 from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
 from bigdl.llm.transformers.models.utils import use_flash_attention, use_esimd_sdp
 from transformers.modeling_outputs import BaseModelOutputWithPast
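The new `is_enough_kv_cache_room_4_36` import exists because transformers 4.36 replaced the per-layer `(key, value)` tuples with a Cache object that keeps `key_cache`/`value_cache` lists indexed by `layer_idx`. A hedged illustration of that layout follows; the `DynamicCache` usage and the tensor shapes are assumptions based on transformers 4.36, not part of this patch.

# Illustration of the transformers >= 4.36 KV-cache layout that
# llama_attention_forward_4_36 reads and writes (not part of this commit).
import torch
from transformers.cache_utils import DynamicCache  # new in transformers 4.36

cache = DynamicCache()
# (batch, num_kv_heads, seq_len, head_dim) -- example sizes only
key_states = torch.zeros(1, 8, 1, 128)
value_states = torch.zeros(1, 8, 1, 128)

# update() appends to key_cache / value_cache for the given layer
cache.update(key_states, value_states, layer_idx=0)

print(len(cache.key_cache))           # one entry per layer seen so far
print(cache.key_cache[0].shape)       # torch.Size([1, 8, 1, 128])
print(cache.get_usable_length(1, 0))  # cached length already held for layer 0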
@@ -510,6 +511,210 @@ def llama_attention_selective_batching_forward_4_31(
     return attn_output.to(original_dtype), attn_weights, updated_past_key_values
+
+
+def llama_attention_forward_4_36(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        output_attentions: bool = False,
+        use_cache: bool = False,
+        **kwargs
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+    if "padding_mask" in kwargs:
+        warnings.warn(
+            "Passing `padding_mask` is deprecated and will be removed in v4.37. "
+            "Please make sure to use `attention_mask` instead."
+        )
+
+    bsz, q_len, _ = hidden_states.size()
+    device = hidden_states.device
+    # for flash attention
+    original_dtype = hidden_states.dtype
+    if not self.training and not hidden_states.requires_grad:
+        fsdp_flag = use_flash_attention(hidden_states)
+    else:
+        fsdp_flag = False
+    if fsdp_flag:
+        attention_dtype = torch.float16  # use fp16 for flash attention
+    else:
+        attention_dtype = original_dtype
+
+    use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
+    enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx)
+    qtype = getattr(self.q_proj, "qtype", None)
+    is_q4_0 = qtype == SYM_INT4
+    no_tp = not self.config.pretraining_tp > 1
+    decoding_fast_path = (no_tp and is_q4_0 and use_fuse_rope and
+                          enough_kv_room and bsz * q_len == 1)
+
+    # single batch decoding fast path
+    # forward_qkv performs the QKV projection and rotary position embedding,
+    # saves the key/value states to the cache, then returns the query states
+    # and the extended key/value cache
+    if decoding_fast_path:
+        hidden_states = hidden_states.view(1, -1)
+        cache_k = past_key_value.key_cache[self.layer_idx]
+        cache_v = past_key_value.value_cache[self.layer_idx]
+        kv_seq_len = cache_k.shape[-2]
+        import linear_q4_0
+        query_states, key_states, value_states = linear_q4_0.forward_qkv(hidden_states,
+                                                                         self.q_proj.weight,
+                                                                         self.k_proj.weight,
+                                                                         self.v_proj.weight,
+                                                                         position_ids,
+                                                                         cache_k, cache_v,
+                                                                         self.q_proj.weight.qtype,
+                                                                         kv_seq_len,
+                                                                         self.head_dim)
+        kv_seq_len += 1
+
+        # update past_key_value's seen_tokens and kv caches.
+        if self.layer_idx == 0:
+            past_key_value.seen_tokens = kv_seq_len
+        past_key_value.key_cache[self.layer_idx] = key_states
+        past_key_value.value_cache[self.layer_idx] = value_states
+
+    else:
+        if self.config.pretraining_tp > 1:
+            key_value_slicing = ((self.num_key_value_heads * self.head_dim) //
+                                 self.config.pretraining_tp)
+            query_slices = self.q_proj.weight.split((self.num_heads * self.head_dim)
+                                                    // self.config.pretraining_tp, dim=0)
+            key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
+            value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
+
+            query_states = [F.linear(hidden_states, query_slices[i])
+                            for i in range(self.config.pretraining_tp)]
+            query_states = torch.cat(query_states, dim=-1)
+
+            key_states = [F.linear(hidden_states, key_slices[i])
+                          for i in range(self.config.pretraining_tp)]
+            key_states = torch.cat(key_states, dim=-1)
+
+            value_states = [F.linear(hidden_states, value_slices[i])
+                            for i in range(self.config.pretraining_tp)]
+            value_states = torch.cat(value_states, dim=-1)
+        else:
+            query_states = self.q_proj(hidden_states)
+            key_states = self.k_proj(hidden_states)
+            value_states = self.v_proj(hidden_states)
+
+        query_states = query_states.view(bsz, q_len,
+                                         self.num_heads, self.head_dim).transpose(1, 2)
+        key_states = key_states.view(bsz, q_len,
+                                     self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        value_states = value_states.view(bsz, q_len,
+                                         self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+        kv_seq_len = key_states.shape[-2]
+        if past_key_value is not None:
+            if self.layer_idx is None:
+                invalidInputError(False,
+                                  "The cache structure has changed since version v4.36. "
+                                  f"If you are using {self.__class__.__name__} for "
+                                  "auto-regressive decoding with k/v caching, please make sure "
+                                  "to initialize the attention class with a layer index.")
+            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+
+        if use_fuse_rope:
+            query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
+                                                                         key_states,
+                                                                         position_ids,
+                                                                         "llama")
+        else:
+            cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+            query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
+                                                            cos, sin, position_ids, "llama")
+
+        if past_key_value is not None:
+            # update the number of seen tokens
+            if self.layer_idx == 0:
+                past_key_value.seen_tokens += key_states.shape[-2]
+
+            # reuse k, v, self_attention
+            # update `past_key_value` with `key_states` and `value_states` for layer `layer_idx`
+            if len(past_key_value.key_cache) <= self.layer_idx:
+                past_key_value.key_cache.append(key_states)
+                past_key_value.value_cache.append(value_states)
+            else:
+                cache_k = past_key_value.key_cache[self.layer_idx]
+                cache_v = past_key_value.value_cache[self.layer_idx]
+
+                if not enough_kv_room:
+                    # allocate new
+                    new_c_k, new_c_v = extend_kv_cache(bsz,
+                                                       self.num_key_value_heads,  # Support GQA
+                                                       self.head_dim,
+                                                       cache_k.size(2),
+                                                       kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH,
+                                                       dtype=cache_k.dtype,
+                                                       device=device)
+
+                    new_c_k[:] = cache_k
+                    new_c_v[:] = cache_v
+                    cache_k = new_c_k
+                    cache_v = new_c_v
+
+                key_states, value_states = append_kv_cache(cache_k,
+                                                           cache_v,
+                                                           key_states,
+                                                           value_states)
+
+                # update past_key_value
+                past_key_value.key_cache[self.layer_idx] = key_states
+                past_key_value.value_cache[self.layer_idx] = value_states
+
+    # repeat k/v heads if n_kv_heads < n_heads
+    key_states = repeat_kv(key_states, self.num_key_value_groups).to(device,
+                                                                     dtype=attention_dtype)
+    value_states = repeat_kv(value_states, self.num_key_value_groups).to(device,
+                                                                         dtype=attention_dtype)
+
+    if fsdp_flag:
+        # now only use flash attention for first token
+        attn_output = F.scaled_dot_product_attention(query_states.to(dtype=attention_dtype),
+                                                     key_states,
+                                                     value_states,
+                                                     is_causal=True)
+        attn_weights = None
+    elif use_esimd_sdp(q_len, self.head_dim, query_states):
+        import linear_fp16_esimd
+        attn_output = linear_fp16_esimd.sdp_forward(query_states,
+                                                    key_states.contiguous(),
+                                                    value_states.contiguous())
+        attn_output = attn_output.view(query_states.shape)
+        attn_weights = None
+    else:
+        # otherwise, use native attention
+        attn_output, attn_weights = native_sdp(query_states, key_states, value_states,
+                                               attention_mask,
+                                               bsz, q_len, kv_seq_len,
+                                               self.head_dim, self.num_heads)
+
+    attn_output_size = (bsz, self.num_heads, q_len, self.head_dim)
+    if attn_output.size() != attn_output_size:
+        invalidInputError(False,
+                          f"`attn_output` should be of size {attn_output_size},"
+                          f" but is {attn_output.size()}")
+
+    attn_output = attn_output.transpose(1, 2).contiguous()
+    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+
+    if self.config.pretraining_tp > 1:
+        attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
+        o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp,
+                                                 dim=1)
+        attn_output = sum([F.linear(attn_output[i], o_proj_slices[i])
+                           for i in range(self.config.pretraining_tp)])
+    else:
+        attn_output = self.o_proj(attn_output)
+
+    if not output_attentions:
+        attn_weights = None
+
+    return attn_output.to(original_dtype), attn_weights, past_key_value
+
+
 def native_sdp(query, key, value, attention_mask,
                bsz, q_len, kv_seq_len, head_dim, num_heads):
     attn_weights = torch.matmul(query,
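Finally, a hedged end-to-end usage sketch: once optimize_model runs via bigdl-llm's from_pretrained, a Llama checkpoint under transformers >= 4.36 is routed through llama_attention_forward_4_36 by the convert_forward calls above. The model id and generation settings below are illustrative, not taken from this commit.

# Illustrative usage only; any local or Hub Llama checkpoint works here.
from bigdl.llm.transformers import AutoModelForCausalLM
from transformers import AutoTokenizer

model_path = "meta-llama/Llama-2-7b-chat-hf"  # example checkpoint
# load_in_4bit=True triggers optimize_model, which patches LlamaAttention
# with llama_attention_forward_4_36 when transformers >= 4.36 is installed.
model = AutoModelForCausalLM.from_pretrained(model_path, load_in_4bit=True)
tokenizer = AutoTokenizer.from_pretrained(model_path)

inputs = tokenizer("What is AI?", return_tensors="pt")
output = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(output[0], skip_special_tokens=True))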