use fp8 sdp in llama (#10396)

Yishuo Wang 2024-03-13 16:45:38 +08:00 committed by GitHub
parent 60043a3ae8
commit b268baafd6
2 changed files with 32 additions and 89 deletions


@@ -383,46 +383,18 @@ def llama_attention_forward_4_31_quantized(
query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
cos, sin, position_ids, "llama")
if not self.training and not hidden_states.requires_grad:
fsdp_flag = use_flash_attention(query_states, key_states, attention_mask)
else:
fsdp_flag = False
if fsdp_flag:
attention_dtype = torch.float16 # use fp16 for flash attention
else:
attention_dtype = original_dtype
# otherwise, use native attention
kv_seq_len = key_states.shape[-2]
if past_key_value is None:
attn_weights = torch.matmul(query_states,
key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
invalidInputError(
False,
f"Attention weights should be of size "
f"{(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
f" {attn_weights.size()}"
)
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
invalidInputError(
False,
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)},"
f" but is {attention_mask.size()}"
)
attn_weights = attn_weights + attention_mask
# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1,
dtype=torch.float32).to(query_states.dtype)
attn_output = torch.matmul(attn_weights, value_states)
kv_seq_len = key_states.shape[-2]
repeated_key_states = repeat_kv(key_states, self.num_key_value_groups)
repeated_value_states = repeat_kv(value_states, self.num_key_value_groups)
attn_output, attn_weights = native_sdp(query_states, repeated_key_states,
repeated_value_states, attention_mask,
bsz, q_len, kv_seq_len,
self.head_dim, self.num_heads)
if use_cache:
k_cache, v_cache = init_fp8_kv_cache(
bsz, self.num_key_value_heads, kv_seq_len, self.head_dim,
device=query_states.device
device=query_states.device, new_layout=True
)
key_states, value_states = append_fp8_kv_cache(k_cache, v_cache,
key_states, value_states)
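For the first-token (prefill) path, the inline attention math above is replaced by a single native_sdp call. A minimal sketch of what that helper computes, inferred from the removed inline code and from the call site in this hunk (the real helper also validates the attention-weight and mask shapes with invalidInputError):

```python
import math
import torch
import torch.nn as nn

def native_sdp_sketch(query, key, value, attention_mask,
                      bsz, q_len, kv_seq_len, head_dim, num_heads):
    # Same math as the inline block this hunk removes: scaled QK^T, optional
    # additive mask, fp32 softmax, then a weighted sum over the values.
    attn_weights = torch.matmul(query, key.transpose(2, 3)) / math.sqrt(head_dim)
    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask
    attn_weights = nn.functional.softmax(attn_weights, dim=-1,
                                         dtype=torch.float32).to(query.dtype)
    attn_output = torch.matmul(attn_weights, value)
    return attn_output, attn_weights
```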
@@ -430,7 +402,7 @@ def llama_attention_forward_4_31_quantized(
else:
k_cache, v_cache = past_key_value
key_states, value_states = append_fp8_kv_cache(k_cache, v_cache,
key_states, value_states)
key_states, value_states, new_layout=True)
kv_seq_len = key_states.shape[-2]
past_key_value = (key_states, value_states)
@@ -438,49 +410,16 @@ def llama_attention_forward_4_31_quantized(
key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
query_states.dtype)
# repeat k/v heads if n_kv_heads < n_heads
key_states = repeat_kv(key_states,
self.num_key_value_groups).to(device, dtype=attention_dtype)
value_states = repeat_kv(value_states,
self.num_key_value_groups).to(device, dtype=attention_dtype)
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3))
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
attn_output, attn_weights = native_sdp(query_states, key_states, value_states,
attention_mask,
bsz, q_len, kv_seq_len,
self.head_dim, self.num_heads)
else:
import linear_q4_0
attn_weights = linear_q4_0.query_key_fp8_matmul(query_states, key_states)
attn_weights = attn_weights / math.sqrt(self.head_dim)
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
invalidInputError(
False,
f"Attention weights should be of size "
f"{(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
f" {attn_weights.size()}"
)
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
invalidInputError(
False,
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)},"
f" but is {attention_mask.size()}"
)
attn_weights = attn_weights + attention_mask
# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1,
dtype=torch.float32).to(query_states.dtype)
if query_states.size(2) != 1 or query_states.device.type != 'xpu':
attn_output = torch.matmul(attn_weights, value_states)
else:
import linear_q4_0
attn_output = linear_q4_0.attn_value_fp8_matmul(attn_weights,
value_states.transpose(-1, -2))
attn_output_size = (bsz, self.num_heads, q_len, self.head_dim)
if attn_output.size() != attn_output_size:
invalidInputError(False,
f"`attn_output` should be of size {attn_output_size},"
f" but is {attn_output.size()}")
attn_output = linear_q4_0.sdp_fp8(query_states, key_states, value_states)
attn_weights = None
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
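In the decode path on XPU, the separate query_key_fp8_matmul / softmax / attn_value_fp8_matmul steps are collapsed into one fused linear_q4_0.sdp_fp8 call that operates directly on the fp8 (uint8) KV cache; no mask argument is needed because with q_len == 1 the causal mask hides nothing. A rough pure-PyTorch sketch of the intended semantics, assuming K/V have already been dequantized to the query dtype (sdp_fp8_reference is an illustrative name, not the library API):

```python
import math
import torch

def sdp_fp8_reference(query_states, key_states, value_states):
    # Reference semantics only; the fused kernel avoids materializing
    # full-precision K/V and the intermediate attention matrix.
    head_dim = query_states.size(-1)
    attn_weights = torch.matmul(query_states, key_states.transpose(2, 3))
    attn_weights = attn_weights / math.sqrt(head_dim)
    attn_weights = torch.softmax(attn_weights, dim=-1, dtype=torch.float32)
    attn_weights = attn_weights.to(query_states.dtype)
    return torch.matmul(attn_weights, value_states)
```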


@@ -83,32 +83,36 @@ def kv_cache_device_check(x: torch.Tensor) -> bool:
(get_xpu_device_type(x) == "arc" and 1 < x.size(0) and x.size(0) < 8)
def init_fp8_kv_cache(batch_size, num_heads, current_length, head_dim, device):
def init_fp8_kv_cache(batch_size, num_heads, current_length, head_dim, device, new_layout=False):
max_length = current_length + FP8_KV_ALLOC_LENGTH
k_cache_storage = torch.empty(batch_size, num_heads, max_length, head_dim,
dtype=torch.uint8, device=device)
v_cache_storage = torch.empty(batch_size, num_heads, head_dim, max_length,
dtype=torch.uint8, device=device)
k_cache = k_cache_storage.as_strided((batch_size, num_heads, 0, head_dim),
k_cache_storage.stride(), storage_offset=0)
v_cache = v_cache_storage.as_strided((batch_size, num_heads, head_dim, 0),
v_cache_storage.stride(), storage_offset=0)
return k_cache, v_cache.transpose(-1, -2)
if new_layout:
v_cache_storage = torch.empty(batch_size, num_heads, max_length, head_dim,
dtype=torch.uint8, device=device)
v_cache = v_cache_storage.as_strided((batch_size, num_heads, 0, head_dim),
v_cache_storage.stride(), storage_offset=0)
return k_cache, v_cache
else:
v_cache_storage = torch.empty(batch_size, num_heads, head_dim, max_length,
dtype=torch.uint8, device=device)
v_cache = v_cache_storage.as_strided((batch_size, num_heads, head_dim, 0),
v_cache_storage.stride(), storage_offset=0)
return k_cache, v_cache.transpose(-1, -2)
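With new_layout=True both caches are allocated as (batch, num_heads, max_length, head_dim) and returned directly, so the sequence dimension stays in the layout the fused sdp_fp8 kernel consumes; the old path keeps V as (batch, num_heads, head_dim, max_length) and hands back a transposed view. In both cases the zero-length as_strided view is what lets the cache grow in place. A small standalone illustration of that trick (shapes chosen arbitrarily):

```python
import torch

batch_size, num_heads, max_length, head_dim = 1, 32, 256, 128
storage = torch.empty(batch_size, num_heads, max_length, head_dim, dtype=torch.uint8)

# Expose an empty view over the pre-allocated storage ...
cache = storage.as_strided((batch_size, num_heads, 0, head_dim),
                           storage.stride(), storage_offset=0)
print(cache.shape)   # torch.Size([1, 32, 0, 128])

# ... and later widen the view along the sequence dimension without reallocating.
grown = storage.as_strided((batch_size, num_heads, 16, head_dim),
                           storage.stride(), storage_offset=0)
print(grown.shape)   # torch.Size([1, 32, 16, 128])
```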
def append_fp8_kv_cache(k_cache, v_cache, key, value):
def append_fp8_kv_cache(k_cache, v_cache, key, value, new_layout=False):
batch_size, num_heads, cur_length, head_dim = k_cache.shape
new_length = cur_length + key.size(2)
new_size = (batch_size, num_heads, new_length, head_dim)
if k_cache.stride(1) < new_length * k_cache.size(3):
new_k_cache, new_v_cache = init_fp8_kv_cache(batch_size, num_heads, new_length,
head_dim, key.device)
head_dim, key.device, new_layout)
new_k_cache = new_k_cache.as_strided(new_size, new_k_cache.stride(), storage_offset=0)
new_v_cache = new_v_cache.as_strided(new_size, new_v_cache.stride(), storage_offset=0)
new_k_cache[:, :, :cur_length, :] = k_cache
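Putting the two files together, the call pattern in llama_attention_forward_4_31_quantized becomes roughly the following. This is a condensed sketch of the code shown in the first file (prefill also guards the append behind use_cache), not new behaviour:

```python
# Prefill: allocate the fp8 cache with the new layout; decode: reuse it.
if past_key_value is None:
    k_cache, v_cache = init_fp8_kv_cache(
        bsz, self.num_key_value_heads, kv_seq_len, self.head_dim,
        device=query_states.device, new_layout=True
    )
else:
    k_cache, v_cache = past_key_value

# Quantize and append the new K/V tokens into the pre-allocated storage.
key_states, value_states = append_fp8_kv_cache(k_cache, v_cache,
                                               key_states, value_states,
                                               new_layout=True)
past_key_value = (key_states, value_states)
```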