LLama optimize_model to support transformers 4.36 (#9818)
* supoort 4.36 * style * update * update * update
This commit is contained in:
parent
4269a585b2
commit
248ae7fad2
2 changed files with 228 additions and 14 deletions
|
|
@ -432,10 +432,6 @@ def _optimize_post(model, lightweight_bmm=False):
|
||||||
|
|
||||||
trans_version = transformers.__version__
|
trans_version = transformers.__version__
|
||||||
if version.parse(trans_version) >= version.parse("4.31.0"):
|
if version.parse(trans_version) >= version.parse("4.31.0"):
|
||||||
convert_forward(
|
|
||||||
model,
|
|
||||||
transformers.models.llama.modeling_llama.LlamaAttention,
|
|
||||||
llama_attention_forward_4_31,)
|
|
||||||
convert_forward(
|
convert_forward(
|
||||||
model,
|
model,
|
||||||
transformers.models.llama.modeling_llama.LlamaRMSNorm,
|
transformers.models.llama.modeling_llama.LlamaRMSNorm,
|
||||||
|
|
@ -443,6 +439,19 @@ def _optimize_post(model, lightweight_bmm=False):
|
||||||
convert_forward(model,
|
convert_forward(model,
|
||||||
transformers.models.llama.modeling_llama.LlamaMLP,
|
transformers.models.llama.modeling_llama.LlamaMLP,
|
||||||
llama_mlp_forward)
|
llama_mlp_forward)
|
||||||
|
if version.parse(trans_version) >= version.parse("4.36.0"):
|
||||||
|
# transformers version >= 4.36.0
|
||||||
|
from bigdl.llm.transformers.models.llama import llama_attention_forward_4_36
|
||||||
|
convert_forward(
|
||||||
|
model,
|
||||||
|
transformers.models.llama.modeling_llama.LlamaAttention,
|
||||||
|
llama_attention_forward_4_36, )
|
||||||
|
else:
|
||||||
|
# transformers version between 4.31.0 - 4.35.2
|
||||||
|
convert_forward(
|
||||||
|
model,
|
||||||
|
transformers.models.llama.modeling_llama.LlamaAttention,
|
||||||
|
llama_attention_forward_4_31, )
|
||||||
if enable_vllm_se_batching:
|
if enable_vllm_se_batching:
|
||||||
convert_forward(
|
convert_forward(
|
||||||
model,
|
model,
|
||||||
|
|
|
||||||
|
|
@ -32,15 +32,16 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
import warnings
|
||||||
import importlib
|
import importlib
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from typing import Optional, Tuple, Union, List
|
from typing import Optional, Tuple, Union, List
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from bigdl.llm.utils.common import invalidInputError
|
|
||||||
from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
|
from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
|
||||||
from bigdl.llm.transformers.models.utils import is_enough_kv_cache_room_4_31, apply_rotary_pos_emb
|
from bigdl.llm.transformers.models.utils import is_enough_kv_cache_room_4_31, \
|
||||||
|
apply_rotary_pos_emb, is_enough_kv_cache_room_4_36
|
||||||
from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
|
from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
|
||||||
from bigdl.llm.transformers.models.utils import use_flash_attention, use_esimd_sdp
|
from bigdl.llm.transformers.models.utils import use_flash_attention, use_esimd_sdp
|
||||||
from transformers.modeling_outputs import BaseModelOutputWithPast
|
from transformers.modeling_outputs import BaseModelOutputWithPast
|
||||||
|
|
@ -510,6 +511,210 @@ def llama_attention_selective_batching_forward_4_31(
|
||||||
return attn_output.to(original_dtype), attn_weights, updated_past_key_values
|
return attn_output.to(original_dtype), attn_weights, updated_past_key_values
|
||||||
|
|
||||||
|
|
||||||
|
def llama_attention_forward_4_36(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||||
|
output_attentions: bool = False,
|
||||||
|
use_cache: bool = False,
|
||||||
|
**kwargs
|
||||||
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||||
|
if "padding_mask" in kwargs:
|
||||||
|
warnings.warn(
|
||||||
|
"Passing `padding_mask` is deprecated and will be removed in v4.37. "
|
||||||
|
"Please make sure use `attention_mask` instead.`"
|
||||||
|
)
|
||||||
|
|
||||||
|
bsz, q_len, _ = hidden_states.size()
|
||||||
|
device = hidden_states.device
|
||||||
|
# for flash attention
|
||||||
|
original_dtype = hidden_states.dtype
|
||||||
|
if not self.training and not hidden_states.requires_grad:
|
||||||
|
fsdp_flag = use_flash_attention(hidden_states)
|
||||||
|
else:
|
||||||
|
fsdp_flag = False
|
||||||
|
if fsdp_flag:
|
||||||
|
attention_dtype = torch.float16 # use fp16 for flash attention
|
||||||
|
else:
|
||||||
|
attention_dtype = original_dtype
|
||||||
|
|
||||||
|
use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
|
||||||
|
enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx)
|
||||||
|
qtype = getattr(self.q_proj, "qtype", None)
|
||||||
|
is_q4_0 = qtype == SYM_INT4
|
||||||
|
no_tp = not self.config.pretraining_tp > 1
|
||||||
|
decoding_fast_path = (no_tp and is_q4_0 and use_fuse_rope and
|
||||||
|
enough_kv_room and bsz * q_len == 1)
|
||||||
|
|
||||||
|
# single batch decoding fast path
|
||||||
|
# forward_qkv takes will perform QKV projection, rotary position embedding
|
||||||
|
# and save the key/value states to cache, then return query states and the
|
||||||
|
# extended key/value cache
|
||||||
|
if decoding_fast_path:
|
||||||
|
hidden_states = hidden_states.view(1, -1)
|
||||||
|
cache_k = past_key_value.key_cache[self.layer_idx]
|
||||||
|
cache_v = past_key_value.value_cache[self.layer_idx]
|
||||||
|
kv_seq_len = cache_k.shape[-2]
|
||||||
|
import linear_q4_0
|
||||||
|
query_states, key_states, value_states = linear_q4_0.forward_qkv(hidden_states,
|
||||||
|
self.q_proj.weight,
|
||||||
|
self.k_proj.weight,
|
||||||
|
self.v_proj.weight,
|
||||||
|
position_ids,
|
||||||
|
cache_k, cache_v,
|
||||||
|
self.q_proj.weight.qtype,
|
||||||
|
kv_seq_len,
|
||||||
|
self.head_dim)
|
||||||
|
kv_seq_len += 1
|
||||||
|
# update past_key_value's seem_tokens and kv caches.
|
||||||
|
if self.layer_idx == 0:
|
||||||
|
past_key_value.seen_tokens = kv_seq_len
|
||||||
|
past_key_value.key_cache[self.layer_idx] = key_states
|
||||||
|
past_key_value.value_cache[self.layer_idx] = value_states
|
||||||
|
|
||||||
|
else:
|
||||||
|
if self.config.pretraining_tp > 1:
|
||||||
|
key_value_slicing = ((self.num_key_value_heads * self.head_dim) //
|
||||||
|
self.config.pretraining_tp)
|
||||||
|
query_slices = self.q_proj.weight.split((self.num_heads * self.head_dim)
|
||||||
|
// self.config.pretraining_tp, dim=0)
|
||||||
|
key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
|
||||||
|
value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
|
||||||
|
|
||||||
|
query_states = [F.linear(hidden_states, query_slices[i])
|
||||||
|
for i in range(self.config.pretraining_tp)]
|
||||||
|
query_states = torch.cat(query_states, dim=-1)
|
||||||
|
|
||||||
|
key_states = [F.linear(hidden_states, key_slices[i])
|
||||||
|
for i in range(self.config.pretraining_tp)]
|
||||||
|
key_states = torch.cat(key_states, dim=-1)
|
||||||
|
|
||||||
|
value_states = [F.linear(hidden_states, value_slices[i])
|
||||||
|
for i in range(self.config.pretraining_tp)]
|
||||||
|
value_states = torch.cat(value_states, dim=-1)
|
||||||
|
else:
|
||||||
|
query_states = self.q_proj(hidden_states)
|
||||||
|
key_states = self.k_proj(hidden_states)
|
||||||
|
value_states = self.v_proj(hidden_states)
|
||||||
|
|
||||||
|
query_states = query_states.view(bsz, q_len,
|
||||||
|
self.num_heads, self.head_dim).transpose(1, 2)
|
||||||
|
key_states = key_states.view(bsz, q_len,
|
||||||
|
self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||||
|
value_states = value_states.view(bsz, q_len,
|
||||||
|
self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||||
|
|
||||||
|
kv_seq_len = key_states.shape[-2]
|
||||||
|
if past_key_value is not None:
|
||||||
|
if self.layer_idx is None:
|
||||||
|
invalidInputError(False,
|
||||||
|
"The cache structure has changed since version v4.36. "
|
||||||
|
f"If you are using {self.__class__.__name__} for "
|
||||||
|
"auto-regressive decodingwith k/v caching, please make sure "
|
||||||
|
"to initialize the attention class with a layer index.")
|
||||||
|
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
|
||||||
|
|
||||||
|
if use_fuse_rope:
|
||||||
|
query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
|
||||||
|
key_states,
|
||||||
|
position_ids,
|
||||||
|
"llama")
|
||||||
|
else:
|
||||||
|
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||||
|
query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
|
||||||
|
cos, sin, position_ids, "llama")
|
||||||
|
|
||||||
|
if past_key_value is not None:
|
||||||
|
# update the number of seen tokens
|
||||||
|
if self.layer_idx == 0:
|
||||||
|
past_key_value.seen_tokens += key_states.shape[-2]
|
||||||
|
|
||||||
|
# reuse k, v, self_attention
|
||||||
|
# update `past_key_value` with `key_states` and `value_states` for layer `layer_idx`
|
||||||
|
if len(past_key_value.key_cache) <= self.layer_idx:
|
||||||
|
past_key_value.key_cache.append(key_states)
|
||||||
|
past_key_value.value_cache.append(value_states)
|
||||||
|
else:
|
||||||
|
cache_k = past_key_value.key_cache[self.layer_idx]
|
||||||
|
cache_v = past_key_value.value_cache[self.layer_idx]
|
||||||
|
|
||||||
|
if not enough_kv_room:
|
||||||
|
# allocate new
|
||||||
|
new_c_k, new_c_v = extend_kv_cache(bsz,
|
||||||
|
self.num_key_value_heads, # Support GQA
|
||||||
|
self.head_dim,
|
||||||
|
cache_k.size(2),
|
||||||
|
kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH,
|
||||||
|
dtype=cache_k.dtype,
|
||||||
|
device=device)
|
||||||
|
|
||||||
|
new_c_k[:] = cache_k
|
||||||
|
new_c_v[:] = cache_v
|
||||||
|
cache_k = new_c_k
|
||||||
|
cache_v = new_c_v
|
||||||
|
|
||||||
|
key_states, value_states = append_kv_cache(cache_k,
|
||||||
|
cache_v,
|
||||||
|
key_states,
|
||||||
|
value_states)
|
||||||
|
|
||||||
|
# update past_key_value
|
||||||
|
past_key_value.key_cache[self.layer_idx] = key_states
|
||||||
|
past_key_value.value_cache[self.layer_idx] = value_states
|
||||||
|
|
||||||
|
# repeat k/v heads if n_kv_heads < n_heads
|
||||||
|
key_states = repeat_kv(key_states, self.num_key_value_groups).to(device,
|
||||||
|
dtype=attention_dtype)
|
||||||
|
value_states = repeat_kv(value_states, self.num_key_value_groups).to(device,
|
||||||
|
dtype=attention_dtype)
|
||||||
|
|
||||||
|
if fsdp_flag:
|
||||||
|
# now only use flash attention for first token
|
||||||
|
attn_output = F.scaled_dot_product_attention(query_states.to(dtype=attention_dtype),
|
||||||
|
key_states,
|
||||||
|
value_states,
|
||||||
|
is_causal=True)
|
||||||
|
attn_weights = None
|
||||||
|
elif use_esimd_sdp(q_len, self.head_dim, query_states):
|
||||||
|
import linear_fp16_esimd
|
||||||
|
attn_output = linear_fp16_esimd.sdp_forward(query_states,
|
||||||
|
key_states.contiguous(),
|
||||||
|
value_states.contiguous())
|
||||||
|
attn_output = attn_output.view(query_states.shape)
|
||||||
|
attn_weights = None
|
||||||
|
else:
|
||||||
|
# otherwise, use native attention
|
||||||
|
attn_output, attn_weights = native_sdp(query_states, key_states, value_states,
|
||||||
|
attention_mask,
|
||||||
|
bsz, q_len, kv_seq_len,
|
||||||
|
self.head_dim, self.num_heads)
|
||||||
|
|
||||||
|
attn_output_size = (bsz, self.num_heads, q_len, self.head_dim)
|
||||||
|
if attn_output.size() != attn_output_size:
|
||||||
|
invalidInputError(False,
|
||||||
|
f"`attn_output` should be of size {attn_output_size},"
|
||||||
|
f" but is {attn_output.size()}")
|
||||||
|
|
||||||
|
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||||
|
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
||||||
|
|
||||||
|
if self.config.pretraining_tp > 1:
|
||||||
|
attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
|
||||||
|
o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp,
|
||||||
|
dim=1)
|
||||||
|
attn_output = sum([F.linear(attn_output[i], o_proj_slices[i])
|
||||||
|
for i in range(self.config.pretraining_tp)])
|
||||||
|
else:
|
||||||
|
attn_output = self.o_proj(attn_output)
|
||||||
|
|
||||||
|
if not output_attentions:
|
||||||
|
attn_weights = None
|
||||||
|
|
||||||
|
return attn_output.to(original_dtype), attn_weights, past_key_value
|
||||||
|
|
||||||
|
|
||||||
def native_sdp(query, key, value, attention_mask,
|
def native_sdp(query, key, value, attention_mask,
|
||||||
bsz, q_len, kv_seq_len, head_dim, num_heads):
|
bsz, q_len, kv_seq_len, head_dim, num_heads):
|
||||||
attn_weights = torch.matmul(query,
|
attn_weights = torch.matmul(query,
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue