diff --git a/python/llm/src/bigdl/llm/transformers/models/llama.py b/python/llm/src/bigdl/llm/transformers/models/llama.py
index f1e160a1..1409016a 100644
--- a/python/llm/src/bigdl/llm/transformers/models/llama.py
+++ b/python/llm/src/bigdl/llm/transformers/models/llama.py
@@ -40,6 +40,8 @@ import math
 import os
 import torch.nn.functional as F
 from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
+from bigdl.llm.transformers.models.utils import init_fp8_kv_cache, append_fp8_kv_cache, \
+    restore_fp8_kv_cache, use_quantize_kv_cache
 from bigdl.llm.transformers.models.utils import is_enough_kv_cache_room_4_31, \
     apply_rotary_pos_emb, is_enough_kv_cache_room_4_36
 from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
@@ -224,6 +226,226 @@ def llama_attention_forward_4_31(
     use_cache: bool = False,
     padding_mask: Optional[torch.LongTensor] = None,
     **kwargs,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+    if use_quantize_kv_cache(self.q_proj, hidden_states):
+        forward_function = llama_attention_forward_4_31_quantized
+    else:
+        forward_function = llama_attention_forward_4_31_original
+    return forward_function(
+        self=self,
+        hidden_states=hidden_states,
+        attention_mask=attention_mask,
+        position_ids=position_ids,
+        past_key_value=past_key_value,
+        output_attentions=output_attentions,
+        use_cache=use_cache,
+        padding_mask=padding_mask,
+        **kwargs
+    )
+
+
+def llama_attention_forward_4_31_quantized(
+    self,
+    hidden_states: torch.Tensor,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_value: Optional[Tuple[torch.Tensor]] = None,
+    output_attentions: bool = False,
+    use_cache: bool = False,
+    padding_mask: Optional[torch.LongTensor] = None,
+    **kwargs,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+    bsz, q_len, hidden_size = hidden_states.size()
+    device = hidden_states.device
+    # for flash attention
+    original_dtype = hidden_states.dtype
+
+    use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
+    enough_kv_room = is_enough_kv_cache_room_4_31(past_key_value, seq_len=q_len)
+    qtype = getattr(self.q_proj, "qtype", None)
+    qtype_check = qtype in [SYM_INT4, FP8E5]
+    no_tp = not self.config.pretraining_tp > 1
+    decoding_fast_path = (no_tp and qtype_check and use_fuse_rope
+                          and enough_kv_room and bsz * q_len == 1)
+
+    # single batch decoding fast path:
+    # forward_qkv performs the QKV projection and rotary position embedding,
+    # saves the key/value states to cache, then returns the query states and
+    # the extended key/value cache
+    if decoding_fast_path:
+        hidden_states = hidden_states.view(1, -1)
+        tmp_cache_k, tmp_cache_v = init_kv_cache(
+            bsz,
+            self.num_key_value_heads,
+            self.head_dim,
+            0,
+            1,
+            dtype=hidden_states.dtype,
+            device=device
+        )
+        import linear_q4_0
+        query_states, key_states, value_states = linear_q4_0.forward_qkv(hidden_states,
+                                                                         self.q_proj.weight,
+                                                                         self.k_proj.weight,
+                                                                         self.v_proj.weight,
+                                                                         position_ids,
+                                                                         tmp_cache_k, tmp_cache_v,
+                                                                         self.q_proj.weight.qtype,
+                                                                         0,
+                                                                         self.head_dim)
+    else:
+        query_states = self.q_proj(hidden_states)
+        key_states = self.k_proj(hidden_states)
+        value_states = self.v_proj(hidden_states)
+
+        query_states = query_states.view(bsz, q_len,
+                                         self.num_heads, self.head_dim).transpose(1, 2)
+        key_states = key_states.view(bsz, q_len,
+                                     self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        value_states = value_states.view(bsz, q_len,
+                                         self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+        kv_seq_len = key_states.shape[-2]
+        if past_key_value is not None:
+            kv_seq_len += past_key_value[0].shape[-2]
+
+        if use_fuse_rope:
+            query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
+                                                                         key_states,
+                                                                         position_ids,
+                                                                         "llama")
+        else:
+            cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+            query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
+                                                            cos, sin, position_ids, "llama")
+
+    if not self.training and not hidden_states.requires_grad:
+        fsdp_flag = use_flash_attention(query_states, key_states)
+    else:
+        fsdp_flag = False
+    if fsdp_flag:
+        attention_dtype = torch.float16  # use fp16 for flash attention
+    else:
+        attention_dtype = original_dtype
+
+    # the quantized kv cache path always uses native attention
+    kv_seq_len = key_states.shape[-2]
+    if past_key_value is None:
+        # repeat k/v heads if n_kv_heads < n_heads
+        repeated_key_states = repeat_kv(key_states, self.num_key_value_groups)
+        repeated_value_states = repeat_kv(value_states, self.num_key_value_groups)
+        attn_weights = torch.matmul(query_states,
+                                    repeated_key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+
+        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
+            invalidInputError(
+                False,
+                f"Attention weights should be of size "
+                f"{(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
+                f" {attn_weights.size()}"
+            )
+
+        if attention_mask is not None:
+            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
+                invalidInputError(
+                    False,
+                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)},"
+                    f" but is {attention_mask.size()}"
+                )
+            attn_weights = attn_weights + attention_mask
+
+        # upcast attention to fp32
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1,
+                                             dtype=torch.float32).to(query_states.dtype)
+        attn_output = torch.matmul(attn_weights, repeated_value_states)
+        if use_cache:
+            k_cache, v_cache = init_fp8_kv_cache(
+                bsz, self.num_key_value_heads, kv_seq_len, self.head_dim,
+                device=query_states.device
+            )
+            key_states, value_states = append_fp8_kv_cache(k_cache, v_cache,
+                                                           key_states, value_states)
+            past_key_value = (key_states, value_states)
+    else:
+        k_cache, v_cache = past_key_value
+        key_states, value_states = append_fp8_kv_cache(k_cache, v_cache,
+                                                       key_states, value_states)
+        kv_seq_len = key_states.shape[-2]
+        past_key_value = (key_states, value_states)
+
+        if query_states.size(2) != 1 or query_states.device.type != 'xpu':
+            key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
+                                                            query_states.dtype)
+            # repeat k/v heads if n_kv_heads < n_heads
+            key_states = repeat_kv(key_states,
+                                   self.num_key_value_groups).to(device, dtype=attention_dtype)
+            value_states = repeat_kv(value_states,
+                                     self.num_key_value_groups).to(device, dtype=attention_dtype)
+            attn_weights = torch.matmul(query_states, key_states.transpose(2, 3))
+        else:
+            import linear_q4_0
+            attn_weights = linear_q4_0.query_key_fp8_matmul(query_states, key_states)
+
+        attn_weights = attn_weights / math.sqrt(self.head_dim)
+
+        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
+            invalidInputError(
+                False,
+                f"Attention weights should be of size "
+                f"{(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
+                f" {attn_weights.size()}"
+            )
+
+        if attention_mask is not None:
+            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
+                invalidInputError(
+                    False,
+                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)},"
+                    f" but is {attention_mask.size()}"
+                )
+            attn_weights = attn_weights + attention_mask
+
+        # upcast attention to fp32
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1,
+                                             dtype=torch.float32).to(query_states.dtype)
+
+        if query_states.size(2) != 1 or query_states.device.type != 'xpu':
+            attn_output = torch.matmul(attn_weights, value_states)
+        else:
+            import linear_q4_0
+            attn_output = linear_q4_0.attn_value_fp8_matmul(attn_weights,
+                                                            value_states.transpose(-1, -2))
+
+    attn_output_size = (bsz, self.num_heads, q_len, self.head_dim)
+    if attn_output.size() != attn_output_size:
+        invalidInputError(False,
+                          f"`attn_output` should be of size {attn_output_size},"
+                          f" but is {attn_output.size()}")
+
+    attn_output = attn_output.transpose(1, 2).contiguous()
+    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+
+    if self.config.pretraining_tp > 1:
+        attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
+        o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp,
+                                                 dim=1)
+        attn_output = sum([F.linear(attn_output[i], o_proj_slices[i])
+                           for i in range(self.config.pretraining_tp)])
+    else:
+        attn_output = self.o_proj(attn_output)
+
+    if not output_attentions:
+        attn_weights = None
+
+    return attn_output.to(original_dtype), attn_weights, past_key_value
+
+
+def llama_attention_forward_4_31_original(
+    self,
+    hidden_states: torch.Tensor,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_value: Optional[Tuple[torch.Tensor]] = None,
+    output_attentions: bool = False,
+    use_cache: bool = False,
+    padding_mask: Optional[torch.LongTensor] = None,
+    **kwargs,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
     bsz, q_len, hidden_size = hidden_states.size()
     device = hidden_states.device
@@ -333,13 +555,15 @@ def llama_attention_forward_4_31(
         cache_v = past_key_value[1]
         if not enough_kv_room:
             # allocate new
-            new_cache_k, new_cache_v = extend_kv_cache(bsz,
-                                                       self.num_key_value_heads,  # Support GQA
-                                                       self.head_dim,
-                                                       cache_k.size(2),
-                                                       kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH,
-                                                       dtype=cache_k.dtype,
-                                                       device=device)
+            new_cache_k, new_cache_v = extend_kv_cache(
+                bsz,
+                self.num_key_value_heads,  # Support GQA
+                self.head_dim,
+                cache_k.size(2),
+                kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH,
+                dtype=cache_k.dtype,
+                device=device
+            )
             new_cache_k[:] = cache_k
             new_cache_v[:] = cache_v
             cache_k = new_cache_k
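
Note: for reviewers unfamiliar with the fp8 cache helpers, the sketch below approximates what the init/append/restore round trip does: key/value states are quantized to an 8-bit float when appended to the cache and dequantized back to the compute dtype before the native attention matmuls. This is a minimal stand-in, not the BigDL implementation: it assumes an e5m2 element type (suggested by the FP8E5 qtype in the diff) and torch.float8_e5m2 from PyTorch 2.1+, and it ignores the block pre-allocation and device handling that init_fp8_kv_cache/append_fp8_kv_cache actually perform. The sketch_* names are illustrative only.

import torch

def sketch_append_fp8(cache, new_states):
    # Quantize incoming key/value states to fp8 (e5m2) and append along the
    # sequence dimension -- the observable effect of append_fp8_kv_cache.
    return torch.cat([cache, new_states.to(torch.float8_e5m2)], dim=-2)

def sketch_restore_fp8(cache, dtype):
    # Dequantize the whole cache back to the compute dtype, as
    # restore_fp8_kv_cache does before the non-xpu matmul path.
    return cache.to(dtype)

bsz, n_kv_heads, seq_len, head_dim = 1, 32, 16, 128
k_cache = torch.empty(bsz, n_kv_heads, 0, head_dim, dtype=torch.float8_e5m2)
new_k = torch.randn(bsz, n_kv_heads, seq_len, head_dim, dtype=torch.float16)
k_cache = sketch_append_fp8(k_cache, new_k)
restored = sketch_restore_fp8(k_cache, torch.float16)
# e5m2 keeps fp16's exponent range but only 2 mantissa bits, so the cache
# halves kv memory at the cost of a small relative rounding error.
print((restored - new_k).abs().max())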
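
Note: on the xpu decode path (q_len == 1) the fp8 cache is consumed directly by linear_q4_0.query_key_fp8_matmul and attn_value_fp8_matmul, avoiding a full dequantization of the cache on every generated token. Those kernels are xpu-only; the sketch below just spells out the semantics that the native fallback branch implements, so the two branches should agree up to fp8 rounding. It assumes k/v have already been repeated to num_heads and omits the attention mask; the function name is illustrative, not part of the library.

import math
import torch

def reference_decode_attention(q, k_fp8, v_fp8):
    # q:     (bsz, num_heads, 1, head_dim), fp16/fp32
    # k_fp8: (bsz, num_heads, kv_len, head_dim), torch.float8_e5m2
    # v_fp8: (bsz, num_heads, kv_len, head_dim), torch.float8_e5m2
    k = k_fp8.to(q.dtype)  # what restore_fp8_kv_cache achieves
    v = v_fp8.to(q.dtype)
    w = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(q.size(-1))
    w = torch.softmax(w, dim=-1, dtype=torch.float32).to(q.dtype)
    return torch.matmul(w, v)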