LLM: add quantize kv support for llama transformer 4.36 (#10298)
* add quantize kv support for llama transformer 4.36 * fix style. * fix style.
This commit is contained in:
parent
57e211dab4
commit
ab9fc2485f
2 changed files with 243 additions and 0 deletions
|
|
@ -742,10 +742,15 @@ def _optimize_post(model, lightweight_bmm=False):
|
|||
if version.parse(trans_version) >= version.parse("4.36.0"):
|
||||
# transformers version >= 4.36.0
|
||||
from bigdl.llm.transformers.models.llama import llama_attention_forward_4_36
|
||||
from bigdl.llm.transformers.models.llama import llama_model_forward_4_36
|
||||
convert_forward(
|
||||
model,
|
||||
transformers.models.llama.modeling_llama.LlamaAttention,
|
||||
llama_attention_forward_4_36, )
|
||||
convert_forward(
|
||||
model,
|
||||
transformers.models.llama.modeling_llama.LlamaModel,
|
||||
llama_model_forward_4_36)
|
||||
else:
|
||||
# transformers version between 4.31.0 - 4.35.2
|
||||
convert_forward(
|
||||
|
|
|
|||
|
|
@ -49,6 +49,7 @@ from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xp
|
|||
from bigdl.llm.transformers.models.utils import use_flash_attention, use_esimd_sdp
|
||||
from bigdl.llm.transformers.models.utils import mlp_fusion_check, fp16_fusion_check
|
||||
from transformers.modeling_outputs import BaseModelOutputWithPast
|
||||
from transformers.models.llama.modeling_llama import LlamaModel
|
||||
from bigdl.llm.transformers.low_bit_linear import SYM_INT4, FP8E5, IQ2_XXS
|
||||
from bigdl.llm.ggml.quantize import ggml_tensor_qtype
|
||||
from bigdl.llm.utils.common import invalidInputError
|
||||
|
|
@ -84,6 +85,37 @@ def get_ipex_version():
|
|||
return _ipex_version
|
||||
|
||||
|
||||
def llama_model_forward_4_36(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[Tuple, BaseModelOutputWithPast]:
|
||||
from bigdl.llm.transformers.kv import DynamicFp8Cache
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
if use_cache and use_quantize_kv_cache(self.layers[0].mlp.up_proj, input_ids):
|
||||
if not isinstance(past_key_values, DynamicFp8Cache):
|
||||
past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
|
||||
return LlamaModel.forward(
|
||||
self=self,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
|
||||
def llama_rms_norm_forward(self, hidden_states):
|
||||
if hidden_states.device.type == "xpu" and not (self.training and hidden_states.requires_grad):
|
||||
import linear_q4_0
|
||||
|
|
@ -906,6 +938,212 @@ def llama_attention_forward_4_36(
|
|||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
**kwargs
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
if use_quantize_kv_cache(self.q_proj, hidden_states):
|
||||
forward_function = llama_attention_forward_4_36_quantized
|
||||
else:
|
||||
forward_function = llama_attention_forward_4_36_original
|
||||
return forward_function(
|
||||
self=self,
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
kwargs=kwargs
|
||||
)
|
||||
|
||||
|
||||
def llama_attention_forward_4_36_quantized(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
**kwargs
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
if "padding_mask" in kwargs:
|
||||
warnings.warn(
|
||||
"Passing `padding_mask` is deprecated and will be removed in v4.37. "
|
||||
"Please make sure use `attention_mask` instead.`"
|
||||
)
|
||||
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
device = hidden_states.device
|
||||
use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
|
||||
enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx, seq_len=q_len)
|
||||
qtype = getattr(self.q_proj, "qtype", None)
|
||||
qtype_check = qtype in [SYM_INT4, FP8E5]
|
||||
no_tp = not self.config.pretraining_tp > 1
|
||||
decoding_fast_path = (no_tp and qtype_check and use_fuse_rope
|
||||
and enough_kv_room and bsz * q_len == 1)
|
||||
if decoding_fast_path:
|
||||
hidden_states = hidden_states.view(1, -1)
|
||||
tmp_cache_k, tmp_cache_v = init_kv_cache(
|
||||
bsz,
|
||||
self.num_key_value_heads,
|
||||
self.head_dim,
|
||||
0,
|
||||
1,
|
||||
dtype=hidden_states.dtype,
|
||||
device=device
|
||||
)
|
||||
import linear_q4_0
|
||||
query_states, key_states, value_states = linear_q4_0.forward_qkv(hidden_states,
|
||||
self.q_proj.weight,
|
||||
self.k_proj.weight,
|
||||
self.v_proj.weight,
|
||||
position_ids,
|
||||
tmp_cache_k, tmp_cache_v,
|
||||
self.q_proj.weight.qtype,
|
||||
0,
|
||||
self.head_dim)
|
||||
else:
|
||||
query_states = self.q_proj(hidden_states)
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
|
||||
query_states = query_states.view(bsz, q_len,
|
||||
self.num_heads, self.head_dim).transpose(1, 2)
|
||||
key_states = key_states.view(bsz, q_len,
|
||||
self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
value_states = value_states.view(bsz, q_len,
|
||||
self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
kv_seq_len = key_states.shape[-2]
|
||||
if past_key_value is not None:
|
||||
if self.layer_idx is None:
|
||||
invalidInputError(
|
||||
False,
|
||||
f"The cache structure has changed since version v4.36."
|
||||
f" If you are using {self.__class__.__name__} "
|
||||
f"for auto-regressive decoding with k/v caching,"
|
||||
f" please make sure to initialize the attention class "
|
||||
"with a layer index."
|
||||
)
|
||||
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
|
||||
if use_fuse_rope:
|
||||
query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
|
||||
key_states,
|
||||
position_ids,
|
||||
"llama")
|
||||
else:
|
||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
|
||||
cos, sin, position_ids, "llama")
|
||||
kv_seq_len = key_states.shape[-2]
|
||||
|
||||
if len(past_key_value.key_cache) <= self.layer_idx:
|
||||
attn_weights = torch.matmul(query_states,
|
||||
key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
|
||||
|
||||
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
|
||||
invalidInputError(
|
||||
False,
|
||||
f"Attention weights should be of size "
|
||||
f"{(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
|
||||
f" {attn_weights.size()}"
|
||||
)
|
||||
|
||||
if attention_mask is not None:
|
||||
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
|
||||
invalidInputError(
|
||||
False,
|
||||
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)},"
|
||||
f" but is {attention_mask.size()}"
|
||||
)
|
||||
attn_weights = attn_weights + attention_mask
|
||||
|
||||
# upcast attention to fp32
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1,
|
||||
dtype=torch.float32).to(query_states.dtype)
|
||||
attn_output = torch.matmul(attn_weights, value_states)
|
||||
if use_cache:
|
||||
cache_kwargs = None
|
||||
key_states, value_states = past_key_value.update(key_states, value_states,
|
||||
self.layer_idx, cache_kwargs)
|
||||
else:
|
||||
cache_kwargs = None # Specific to RoPE models
|
||||
key_states, value_states = past_key_value.update(key_states, value_states,
|
||||
self.layer_idx, cache_kwargs)
|
||||
kv_seq_len = key_states.shape[-2]
|
||||
if query_states.size(2) != 1 or query_states.device.type != 'xpu':
|
||||
key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
|
||||
query_states.dtype)
|
||||
key_states = repeat_kv(key_states, self.num_key_value_groups)\
|
||||
.to(device, dtype=query_states.dtype)
|
||||
value_states = repeat_kv(value_states, self.num_key_value_groups)\
|
||||
.to(device, dtype=query_states.dtype)
|
||||
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3))
|
||||
else:
|
||||
import linear_q4_0
|
||||
attn_weights = linear_q4_0.query_key_fp8_matmul(query_states, key_states)
|
||||
attn_weights = attn_weights / math.sqrt(self.head_dim)
|
||||
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
|
||||
invalidInputError(
|
||||
False,
|
||||
f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)},"
|
||||
f" but is {attn_weights.size()}"
|
||||
)
|
||||
|
||||
if attention_mask is not None:
|
||||
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
|
||||
invalidInputError(
|
||||
False,
|
||||
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)},"
|
||||
f" but is {attention_mask.size()}"
|
||||
)
|
||||
attn_weights = attn_weights + attention_mask
|
||||
|
||||
# upcast attention to fp32
|
||||
attn_weights = nn.functional.softmax(attn_weights,
|
||||
dim=-1, dtype=torch.float32).to(query_states.dtype)
|
||||
|
||||
if query_states.size(2) != 1 or query_states.device.type != 'xpu':
|
||||
attn_output = torch.matmul(attn_weights, value_states)
|
||||
else:
|
||||
import linear_q4_0
|
||||
attn_output = linear_q4_0.attn_value_fp8_matmul(attn_weights,
|
||||
value_states.transpose(-1, -2))
|
||||
|
||||
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
|
||||
invalidInputError(
|
||||
False,
|
||||
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)},"
|
||||
f" but is {attn_output.size()}"
|
||||
)
|
||||
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
|
||||
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
||||
|
||||
if self.config.pretraining_tp > 1:
|
||||
attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
|
||||
o_proj_slices = self.o_proj.weight.split(self.hidden_size
|
||||
// self.config.pretraining_tp, dim=1)
|
||||
attn_output = sum([F.linear(attn_output[i],
|
||||
o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
|
||||
else:
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
if not output_attentions:
|
||||
attn_weights = None
|
||||
|
||||
return attn_output, attn_weights, past_key_value
|
||||
|
||||
|
||||
def llama_attention_forward_4_36_original(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
**kwargs
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
if "padding_mask" in kwargs:
|
||||
warnings.warn(
|
||||
|
|
|
|||
Loading…
Reference in a new issue