use fp8 sdp in llama (#10396)
parent 60043a3ae8
commit b268baafd6
2 changed files with 32 additions and 89 deletions
@@ -383,46 +383,18 @@ def llama_attention_forward_4_31_quantized(
     query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
                                                      cos, sin, position_ids, "llama")
 
-    if not self.training and not hidden_states.requires_grad:
-        fsdp_flag = use_flash_attention(query_states, key_states, attention_mask)
-    else:
-        fsdp_flag = False
-    if fsdp_flag:
-        attention_dtype = torch.float16  # use fp16 for flash attention
-    else:
-        attention_dtype = original_dtype
-
-    # otherwise, use native attention
-    kv_seq_len = key_states.shape[-2]
     if past_key_value is None:
-        attn_weights = torch.matmul(query_states,
-                                    key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
-
-        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
-            invalidInputError(
-                False,
-                f"Attention weights should be of size "
-                f"{(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
-                f" {attn_weights.size()}"
-            )
-
-        if attention_mask is not None:
-            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
-                invalidInputError(
-                    False,
-                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)},"
-                    f" but is {attention_mask.size()}"
-                )
-            attn_weights = attn_weights + attention_mask
-
-        # upcast attention to fp32
-        attn_weights = nn.functional.softmax(attn_weights, dim=-1,
-                                             dtype=torch.float32).to(query_states.dtype)
-        attn_output = torch.matmul(attn_weights, value_states)
+        kv_seq_len = key_states.shape[-2]
+        repeated_key_states = repeat_kv(key_states, self.num_key_value_groups)
+        repeated_value_states = repeat_kv(value_states, self.num_key_value_groups)
+        attn_output, attn_weights = native_sdp(query_states, repeated_key_states,
+                                               repeated_value_states, attention_mask,
+                                               bsz, q_len, kv_seq_len,
+                                               self.head_dim, self.num_heads)
         if use_cache:
             k_cache, v_cache = init_fp8_kv_cache(
                 bsz, self.num_key_value_heads, kv_seq_len, self.head_dim,
-                device=query_states.device
+                device=query_states.device, new_layout=True
             )
             key_states, value_states = append_fp8_kv_cache(k_cache, v_cache,
                                                            key_states, value_states)
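Note on the hunk above: the prefill path (past_key_value is None) no longer inlines the matmul / mask / softmax / matmul sequence, and the flash-attention dtype bookkeeping (fsdp_flag / attention_dtype) is dropped; the repeated K/V heads are instead handed to native_sdp. native_sdp itself is defined elsewhere in the repo and is not part of this diff. The sketch below is illustrative only (the name native_sdp_sketch is made up, and the body simply mirrors the removed lines minus the shape checks); it shows the computation the helper is assumed to perform.

import math
import torch


def native_sdp_sketch(query, key, value, attention_mask,
                      bsz, q_len, kv_seq_len, head_dim, num_heads):
    # bsz, q_len, kv_seq_len and num_heads are used for shape validation in the
    # real helper; they are unused in this minimal sketch.
    # (bsz, num_heads, q_len, head_dim) @ (bsz, num_heads, head_dim, kv_seq_len)
    attn_weights = torch.matmul(query, key.transpose(2, 3)) / math.sqrt(head_dim)
    if attention_mask is not None:
        # additive causal/padding mask, broadcast over heads
        attn_weights = attn_weights + attention_mask
    # upcast the softmax to fp32 for stability, then cast back to the query dtype
    attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1,
                                               dtype=torch.float32).to(query.dtype)
    attn_output = torch.matmul(attn_weights, value)
    return attn_output, attn_weights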
@@ -430,7 +402,7 @@ def llama_attention_forward_4_31_quantized(
     else:
         k_cache, v_cache = past_key_value
         key_states, value_states = append_fp8_kv_cache(k_cache, v_cache,
-                                                        key_states, value_states)
+                                                        key_states, value_states, new_layout=True)
         kv_seq_len = key_states.shape[-2]
         past_key_value = (key_states, value_states)
 
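Note on the hunk above: decode steps that append to an existing fp8 cache now also pass new_layout=True. As the utils hunk further down shows, append_fp8_kv_cache only consults this flag when the pre-allocated storage runs out and a larger cache has to be re-initialised, so the flag has to match the layout chosen at prefill time; see the layout sketch after that hunk.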
@@ -438,49 +410,16 @@ def llama_attention_forward_4_31_quantized(
         key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
                                                         query_states.dtype)
         # repeat k/v heads if n_kv_heads < n_heads
-        key_states = repeat_kv(key_states,
-                               self.num_key_value_groups).to(device, dtype=attention_dtype)
-        value_states = repeat_kv(value_states,
-                                 self.num_key_value_groups).to(device, dtype=attention_dtype)
-        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3))
+        key_states = repeat_kv(key_states, self.num_key_value_groups)
+        value_states = repeat_kv(value_states, self.num_key_value_groups)
+        attn_output, attn_weights = native_sdp(query_states, key_states, value_states,
+                                               attention_mask,
+                                               bsz, q_len, kv_seq_len,
+                                               self.head_dim, self.num_heads)
     else:
         import linear_q4_0
-        attn_weights = linear_q4_0.query_key_fp8_matmul(query_states, key_states)
-    attn_weights = attn_weights / math.sqrt(self.head_dim)
-
-    if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
-        invalidInputError(
-            False,
-            f"Attention weights should be of size "
-            f"{(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
-            f" {attn_weights.size()}"
-        )
-
-    if attention_mask is not None:
-        if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
-            invalidInputError(
-                False,
-                f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)},"
-                f" but is {attention_mask.size()}"
-            )
-        attn_weights = attn_weights + attention_mask
-
-    # upcast attention to fp32
-    attn_weights = nn.functional.softmax(attn_weights, dim=-1,
-                                         dtype=torch.float32).to(query_states.dtype)
-
-    if query_states.size(2) != 1 or query_states.device.type != 'xpu':
-        attn_output = torch.matmul(attn_weights, value_states)
-    else:
-        import linear_q4_0
-        attn_output = linear_q4_0.attn_value_fp8_matmul(attn_weights,
-                                                        value_states.transpose(-1, -2))
-
-    attn_output_size = (bsz, self.num_heads, q_len, self.head_dim)
-    if attn_output.size() != attn_output_size:
-        invalidInputError(False,
-                          f"`attn_output` should be of size {attn_output_size},"
-                          f" but is {attn_output.size()}")
+        attn_output = linear_q4_0.sdp_fp8(query_states, key_states, value_states)
+        attn_weights = None
 
     attn_output = attn_output.transpose(1, 2).contiguous()
     attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
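Note on the hunk above: on the cached-decode path the non-XPU branch now also delegates to native_sdp, while the XPU branch replaces the two-step query_key_fp8_matmul / attn_value_fp8_matmul sequence (with scaling, masking and softmax in between) by a single fused linear_q4_0.sdp_fp8 call and sets attn_weights = None, since the fused kernel does not expose the attention probabilities. The kernel's internals are not part of this diff; the reference below is a hedged assumption of what it computes (GQA head broadcasting and masking are omitted for brevity).

import torch


def sdp_fp8_reference(query, key_fp8, value_fp8, restore_fp8_kv_cache):
    # restore_fp8_kv_cache (from this repo) turns the uint8-backed fp8 cache
    # back into the query dtype; the fused XPU kernel avoids materializing it.
    key, value = restore_fp8_kv_cache(key_fp8, value_fp8, query.dtype)
    # single-token decode step: plain scaled-dot-product attention on the
    # dequantized cache, returning only the output tensor.
    return torch.nn.functional.scaled_dot_product_attention(query, key, value)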
@@ -83,32 +83,36 @@ def kv_cache_device_check(x: torch.Tensor) -> bool:
         (get_xpu_device_type(x) == "arc" and 1 < x.size(0) and x.size(0) < 8)
 
 
-def init_fp8_kv_cache(batch_size, num_heads, current_length, head_dim, device):
+def init_fp8_kv_cache(batch_size, num_heads, current_length, head_dim, device, new_layout=False):
     max_length = current_length + FP8_KV_ALLOC_LENGTH
 
     k_cache_storage = torch.empty(batch_size, num_heads, max_length, head_dim,
                                   dtype=torch.uint8, device=device)
 
-    v_cache_storage = torch.empty(batch_size, num_heads, head_dim, max_length,
-                                  dtype=torch.uint8, device=device)
-
     k_cache = k_cache_storage.as_strided((batch_size, num_heads, 0, head_dim),
                                          k_cache_storage.stride(), storage_offset=0)
 
+    if new_layout:
+        v_cache_storage = torch.empty(batch_size, num_heads, max_length, head_dim,
+                                      dtype=torch.uint8, device=device)
+        v_cache = v_cache_storage.as_strided((batch_size, num_heads, 0, head_dim),
+                                             v_cache_storage.stride(), storage_offset=0)
+        return k_cache, v_cache
+    else:
+        v_cache_storage = torch.empty(batch_size, num_heads, head_dim, max_length,
+                                      dtype=torch.uint8, device=device)
     v_cache = v_cache_storage.as_strided((batch_size, num_heads, head_dim, 0),
                                          v_cache_storage.stride(), storage_offset=0)
 
     return k_cache, v_cache.transpose(-1, -2)
 
 
-def append_fp8_kv_cache(k_cache, v_cache, key, value):
+def append_fp8_kv_cache(k_cache, v_cache, key, value, new_layout=False):
     batch_size, num_heads, cur_length, head_dim = k_cache.shape
     new_length = cur_length + key.size(2)
     new_size = (batch_size, num_heads, new_length, head_dim)
 
     if k_cache.stride(1) < new_length * k_cache.size(3):
         new_k_cache, new_v_cache = init_fp8_kv_cache(batch_size, num_heads, new_length,
-                                                     head_dim, key.device)
+                                                     head_dim, key.device, new_layout)
         new_k_cache = new_k_cache.as_strided(new_size, new_k_cache.stride(), storage_offset=0)
         new_v_cache = new_v_cache.as_strided(new_size, new_v_cache.stride(), storage_offset=0)
         new_k_cache[:, :, :cur_length, :] = k_cache
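Note on the hunk above: both kv-cache helpers now take a new_layout flag. With new_layout=False the fp8 value cache keeps its legacy layout, allocated as (batch, heads, head_dim, seq) and returned as a transposed view; with new_layout=True it is allocated as (batch, heads, seq, head_dim), the same layout as the key cache, which is presumably the layout the fused sdp_fp8 kernel consumes. The sketch below contrasts the two layouts; the shapes are illustrative and it only mirrors the allocation logic shown in the hunk, it is not part of the patch.

import torch

bsz, heads, max_len, head_dim = 1, 8, 64, 128

# legacy layout: storage is (B, H, D, L); the returned view is transposed back
# to (B, H, L, D), so stepping along the sequence dimension has stride 1
v_legacy_storage = torch.empty(bsz, heads, head_dim, max_len, dtype=torch.uint8)
v_legacy = v_legacy_storage.as_strided((bsz, heads, head_dim, 0),
                                       v_legacy_storage.stride()).transpose(-1, -2)

# new layout: storage is (B, H, L, D), same as the key cache, so each appended
# token's values are contiguous along head_dim and the cache grows row by row
v_new_storage = torch.empty(bsz, heads, max_len, head_dim, dtype=torch.uint8)
v_new = v_new_storage.as_strided((bsz, heads, 0, head_dim),
                                 v_new_storage.stride())

print(v_legacy.shape, v_legacy.stride())  # sequence dim stride is 1
print(v_new.shape, v_new.stride())        # sequence dim stride is head_dim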