diff --git a/python/llm/src/bigdl/llm/transformers/models/llama.py b/python/llm/src/bigdl/llm/transformers/models/llama.py
index f206a77d..91dab9b5 100644
--- a/python/llm/src/bigdl/llm/transformers/models/llama.py
+++ b/python/llm/src/bigdl/llm/transformers/models/llama.py
@@ -383,46 +383,18 @@ def llama_attention_forward_4_31_quantized(
     query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
                                                     cos, sin, position_ids, "llama")
 
-    if not self.training and not hidden_states.requires_grad:
-        fsdp_flag = use_flash_attention(query_states, key_states, attention_mask)
-    else:
-        fsdp_flag = False
-    if fsdp_flag:
-        attention_dtype = torch.float16  # use fp16 for flash attention
-    else:
-        attention_dtype = original_dtype
-
-    # otherwise, use native attention
-    kv_seq_len = key_states.shape[-2]
     if past_key_value is None:
-        attn_weights = torch.matmul(query_states,
-                                    key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
-
-        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
-            invalidInputError(
-                False,
-                f"Attention weights should be of size "
-                f"{(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
-                f" {attn_weights.size()}"
-            )
-
-        if attention_mask is not None:
-            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
-                invalidInputError(
-                    False,
-                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)},"
-                    f" but is {attention_mask.size()}"
-                )
-            attn_weights = attn_weights + attention_mask
-
-        # upcast attention to fp32
-        attn_weights = nn.functional.softmax(attn_weights, dim=-1,
-                                             dtype=torch.float32).to(query_states.dtype)
-        attn_output = torch.matmul(attn_weights, value_states)
+        kv_seq_len = key_states.shape[-2]
+        repeated_key_states = repeat_kv(key_states, self.num_key_value_groups)
+        repeated_value_states = repeat_kv(value_states, self.num_key_value_groups)
+        attn_output, attn_weights = native_sdp(query_states, repeated_key_states,
+                                               repeated_value_states, attention_mask,
+                                               bsz, q_len, kv_seq_len,
+                                               self.head_dim, self.num_heads)
         if use_cache:
             k_cache, v_cache = init_fp8_kv_cache(
                 bsz, self.num_key_value_heads, kv_seq_len, self.head_dim,
-                device=query_states.device
+                device=query_states.device, new_layout=True
             )
             key_states, value_states = append_fp8_kv_cache(k_cache, v_cache,
                                                            key_states, value_states)
@@ -430,7 +402,7 @@ def llama_attention_forward_4_31_quantized(
     else:
         k_cache, v_cache = past_key_value
         key_states, value_states = append_fp8_kv_cache(k_cache, v_cache,
-                                                       key_states, value_states)
+                                                       key_states, value_states, new_layout=True)
         kv_seq_len = key_states.shape[-2]
         past_key_value = (key_states, value_states)
 
@@ -438,49 +410,16 @@ def llama_attention_forward_4_31_quantized(
         key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
                                                         query_states.dtype)
         # repeat k/v heads if n_kv_heads < n_heads
-        key_states = repeat_kv(key_states,
-                               self.num_key_value_groups).to(device, dtype=attention_dtype)
-        value_states = repeat_kv(value_states,
-                                 self.num_key_value_groups).to(device, dtype=attention_dtype)
-        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3))
+        key_states = repeat_kv(key_states, self.num_key_value_groups)
+        value_states = repeat_kv(value_states, self.num_key_value_groups)
+        attn_output, attn_weights = native_sdp(query_states, key_states, value_states,
+                                               attention_mask,
+                                               bsz, q_len, kv_seq_len,
+                                               self.head_dim, self.num_heads)
     else:
         import linear_q4_0
-        attn_weights = linear_q4_0.query_key_fp8_matmul(query_states, key_states)
-
-        attn_weights = attn_weights / math.sqrt(self.head_dim)
-
-        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
-            invalidInputError(
-                False,
-                f"Attention weights should be of size "
-                f"{(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
-                f" {attn_weights.size()}"
-            )
-
-        if attention_mask is not None:
-            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
-                invalidInputError(
-                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)},"
-                    f" but is {attention_mask.size()}"
-                )
-            attn_weights = attn_weights + attention_mask
-
-        # upcast attention to fp32
-        attn_weights = nn.functional.softmax(attn_weights, dim=-1,
-                                             dtype=torch.float32).to(query_states.dtype)
-
-        if query_states.size(2) != 1 or query_states.device.type != 'xpu':
-            attn_output = torch.matmul(attn_weights, value_states)
-        else:
-            import linear_q4_0
-            attn_output = linear_q4_0.attn_value_fp8_matmul(attn_weights,
-                                                            value_states.transpose(-1, -2))
-
-        attn_output_size = (bsz, self.num_heads, q_len, self.head_dim)
-        if attn_output.size() != attn_output_size:
-            invalidInputError(False,
-                              f"`attn_output` should be of size {attn_output_size},"
-                              f" but is {attn_output.size()}")
+        attn_output = linear_q4_0.sdp_fp8(query_states, key_states, value_states)
+        attn_weights = None
 
     attn_output = attn_output.transpose(1, 2).contiguous()
     attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
diff --git a/python/llm/src/bigdl/llm/transformers/models/utils.py b/python/llm/src/bigdl/llm/transformers/models/utils.py
index 88e5792f..40230638 100644
--- a/python/llm/src/bigdl/llm/transformers/models/utils.py
+++ b/python/llm/src/bigdl/llm/transformers/models/utils.py
@@ -83,32 +83,36 @@ def kv_cache_device_check(x: torch.Tensor) -> bool:
         (get_xpu_device_type(x) == "arc" and 1 < x.size(0) and x.size(0) < 8)
 
 
-def init_fp8_kv_cache(batch_size, num_heads, current_length, head_dim, device):
+def init_fp8_kv_cache(batch_size, num_heads, current_length, head_dim, device, new_layout=False):
     max_length = current_length + FP8_KV_ALLOC_LENGTH
     k_cache_storage = torch.empty(batch_size, num_heads, max_length, head_dim,
                                   dtype=torch.uint8, device=device)
-
-    v_cache_storage = torch.empty(batch_size, num_heads, head_dim, max_length,
-                                  dtype=torch.uint8, device=device)
-
     k_cache = k_cache_storage.as_strided((batch_size, num_heads, 0, head_dim),
                                          k_cache_storage.stride(), storage_offset=0)
-    v_cache = v_cache_storage.as_strided((batch_size, num_heads, head_dim, 0),
-                                         v_cache_storage.stride(), storage_offset=0)
-
-    return k_cache, v_cache.transpose(-1, -2)
+    if new_layout:
+        v_cache_storage = torch.empty(batch_size, num_heads, max_length, head_dim,
+                                      dtype=torch.uint8, device=device)
+        v_cache = v_cache_storage.as_strided((batch_size, num_heads, 0, head_dim),
+                                             v_cache_storage.stride(), storage_offset=0)
+        return k_cache, v_cache
+    else:
+        v_cache_storage = torch.empty(batch_size, num_heads, head_dim, max_length,
+                                      dtype=torch.uint8, device=device)
+        v_cache = v_cache_storage.as_strided((batch_size, num_heads, head_dim, 0),
+                                             v_cache_storage.stride(), storage_offset=0)
+        return k_cache, v_cache.transpose(-1, -2)
 
 
-def append_fp8_kv_cache(k_cache, v_cache, key, value):
+def append_fp8_kv_cache(k_cache, v_cache, key, value, new_layout=False):
     batch_size, num_heads, cur_length, head_dim = k_cache.shape
     new_length = cur_length + key.size(2)
     new_size = (batch_size, num_heads, new_length, head_dim)
 
     if k_cache.stride(1) < new_length * k_cache.size(3):
         new_k_cache, new_v_cache = init_fp8_kv_cache(batch_size, num_heads, new_length,
-                                                     head_dim, key.device)
+                                                     head_dim, key.device, new_layout)
         new_k_cache = new_k_cache.as_strided(new_size, new_k_cache.stride(), storage_offset=0)
         new_v_cache = new_v_cache.as_strided(new_size, new_v_cache.stride(), storage_offset=0)
         new_k_cache[:, :, :cur_length, :] = k_cache
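Note for reviewers: the long inline attention block deleted from `llama_attention_forward_4_31_quantized` is consolidated into the shared `native_sdp` helper. Below is a minimal, self-contained sketch of the computation that block performed, assuming the helper keeps the removed code's semantics (scale by `sqrt(head_dim)`, add the mask, softmax in fp32, multiply by V); the name `native_sdp_sketch` and the plain `assert`s standing in for `invalidInputError` are illustrative, not the BigDL source:

```python
# Hedged sketch of the attention math folded into native_sdp; mirrors the
# removed inline code, with asserts replacing the invalidInputError checks.
import math

import torch
import torch.nn as nn


def native_sdp_sketch(query_states, key_states, value_states, attention_mask,
                      bsz, q_len, kv_seq_len, head_dim, num_heads):
    # q @ k^T, scaled by sqrt(head_dim), exactly as the removed lines did
    attn_weights = torch.matmul(query_states,
                                key_states.transpose(2, 3)) / math.sqrt(head_dim)
    assert attn_weights.size() == (bsz, num_heads, q_len, kv_seq_len)
    if attention_mask is not None:
        assert attention_mask.size() == (bsz, 1, q_len, kv_seq_len)
        attn_weights = attn_weights + attention_mask
    # upcast to fp32 for a numerically stable softmax, then cast back
    attn_weights = nn.functional.softmax(attn_weights, dim=-1,
                                         dtype=torch.float32).to(query_states.dtype)
    attn_output = torch.matmul(attn_weights, value_states)
    return attn_output, attn_weights


q = torch.randn(1, 32, 4, 128)
k = torch.randn(1, 32, 4, 128)
v = torch.randn(1, 32, 4, 128)
out, weights = native_sdp_sketch(q, k, v, None, 1, 4, 4, 128, 32)
print(out.shape)  # torch.Size([1, 32, 4, 128])
```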
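The `new_layout=True` path stores the fp8 V cache seq-major, i.e. with the same `(batch, heads, seq, head_dim)` layout as the K cache, instead of the old head_dim-major storage exposed through a transposed view (which suited the removed `attn_value_fp8_matmul(attn_weights, value_states.transpose(-1, -2))` call). A runnable sketch contrasting the two layouts follows; `FP8_KV_ALLOC_LENGTH` and the helper name are stand-ins for illustration, not the real constants:

```python
# Contrast of the old (transposed) and new (seq-major) fp8 v-cache layouts.
import torch

FP8_KV_ALLOC_LENGTH = 256  # placeholder; the real constant lives in utils.py


def init_v_cache(batch_size, num_heads, current_length, head_dim, device,
                 new_layout=False):
    max_length = current_length + FP8_KV_ALLOC_LENGTH
    if new_layout:
        # new layout: seq-major storage, same as the k cache, so rows for
        # appended tokens are contiguous and no transposed view is needed
        storage = torch.empty(batch_size, num_heads, max_length, head_dim,
                              dtype=torch.uint8, device=device)
        return storage.as_strided((batch_size, num_heads, 0, head_dim),
                                  storage.stride(), storage_offset=0)
    # old layout: head_dim-major storage, exposed through a transposed view
    storage = torch.empty(batch_size, num_heads, head_dim, max_length,
                          dtype=torch.uint8, device=device)
    v_cache = storage.as_strided((batch_size, num_heads, head_dim, 0),
                                 storage.stride(), storage_offset=0)
    return v_cache.transpose(-1, -2)


old = init_v_cache(1, 32, 0, 128, "cpu")
new = init_v_cache(1, 32, 0, 128, "cpu", new_layout=True)
print(old.shape, old.stride())  # same logical shape, last-dim stride 256
print(new.shape, new.stride())  # same logical shape, last-dim stride 1
```

Both layouts expose the same logical shape; only the strides differ. Keeping V seq-major lets `append_fp8_kv_cache` write new tokens into contiguous rows, which is presumably the layout the fused `linear_q4_0.sdp_fp8` kernel introduced in this diff expects, while the transposed layout remains the default for the existing callers.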