parent
9e763b049c
commit
bbd749dceb
2 changed files with 30 additions and 28 deletions
|
|
@ -29,6 +29,7 @@ class DynamicFp8Cache(DynamicCache):
|
||||||
value_states: torch.Tensor,
|
value_states: torch.Tensor,
|
||||||
layer_idx: int,
|
layer_idx: int,
|
||||||
cache_kwargs: Optional[Dict[str, Any]]=None,
|
cache_kwargs: Optional[Dict[str, Any]]=None,
|
||||||
|
new_layout=False,
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
|
||||||
batch_size, num_heads, seq_len, head_dim = key_states.shape
|
batch_size, num_heads, seq_len, head_dim = key_states.shape
|
||||||
|
|
@ -41,15 +42,18 @@ class DynamicFp8Cache(DynamicCache):
|
||||||
k_cache, v_cache = init_fp8_kv_cache(
|
k_cache, v_cache = init_fp8_kv_cache(
|
||||||
batch_size, num_heads, seq_len, head_dim,
|
batch_size, num_heads, seq_len, head_dim,
|
||||||
device=key_states.device,
|
device=key_states.device,
|
||||||
|
new_layout=new_layout,
|
||||||
)
|
)
|
||||||
k_cache, v_cache = append_fp8_kv_cache(k_cache, v_cache, key_states, value_states)
|
k_cache, v_cache = append_fp8_kv_cache(k_cache, v_cache, key_states, value_states,
|
||||||
|
new_layout=new_layout)
|
||||||
|
|
||||||
self.key_cache.append(k_cache)
|
self.key_cache.append(k_cache)
|
||||||
self.value_cache.append(v_cache)
|
self.value_cache.append(v_cache)
|
||||||
else:
|
else:
|
||||||
k_cache = self.key_cache[layer_idx]
|
k_cache = self.key_cache[layer_idx]
|
||||||
v_cache = self.value_cache[layer_idx]
|
v_cache = self.value_cache[layer_idx]
|
||||||
k_cache, v_cache = append_fp8_kv_cache(k_cache, v_cache, key_states, value_states)
|
k_cache, v_cache = append_fp8_kv_cache(k_cache, v_cache, key_states, value_states,
|
||||||
|
new_layout=new_layout)
|
||||||
self.key_cache[layer_idx] = k_cache
|
self.key_cache[layer_idx] = k_cache
|
||||||
self.value_cache[layer_idx] = v_cache
|
self.value_cache[layer_idx] = v_cache
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -354,43 +354,41 @@ def qwen2_attention_forward_quantized(
|
||||||
if past_key_value is not None:
|
if past_key_value is not None:
|
||||||
cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
|
cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
|
||||||
key_states, value_states = past_key_value.update(key_states, value_states,
|
key_states, value_states = past_key_value.update(key_states, value_states,
|
||||||
self.layer_idx, cache_kwargs)
|
self.layer_idx, cache_kwargs,
|
||||||
|
new_layout=True)
|
||||||
|
|
||||||
if q_len != 1:
|
if q_len == 1 and query_states.device.type == 'xpu' and not self.training \
|
||||||
|
and not hidden_states.requires_grad:
|
||||||
|
import linear_q4_0
|
||||||
|
attn_output = linear_q4_0.sdp_fp8(query_states, key_states, value_states,
|
||||||
|
attention_mask)
|
||||||
|
attn_weights = None
|
||||||
|
else:
|
||||||
key, value = restore_fp8_kv_cache(key_states, value_states, query_states.dtype)
|
key, value = restore_fp8_kv_cache(key_states, value_states, query_states.dtype)
|
||||||
key = repeat_kv(key, self.num_key_value_groups)
|
key = repeat_kv(key, self.num_key_value_groups)
|
||||||
value = repeat_kv(value, self.num_key_value_groups)
|
value = repeat_kv(value, self.num_key_value_groups)
|
||||||
attn_weights = torch.matmul(query_states, key.transpose(2, 3))
|
attn_weights = torch.matmul(query_states, key.transpose(2, 3))
|
||||||
else:
|
attn_weights = attn_weights / math.sqrt(self.head_dim)
|
||||||
import linear_q4_0
|
|
||||||
attn_weights = linear_q4_0.query_key_fp8_matmul(query_states, key_states)
|
|
||||||
|
|
||||||
attn_weights = attn_weights / math.sqrt(self.head_dim)
|
invalidInputError(attn_weights.size() == (bsz, self.num_heads, q_len, kv_seq_len),
|
||||||
|
("Attention weights should be of size "
|
||||||
|
f"{(bsz, self.num_heads, q_len, kv_seq_len)},"
|
||||||
|
"but is {attn_weights.size()}"))
|
||||||
|
|
||||||
invalidInputError(attn_weights.size() == (bsz, self.num_heads, q_len, kv_seq_len),
|
if attention_mask is not None:
|
||||||
("Attention weights should be of size "
|
invalidInputError(attention_mask.size() == (bsz, 1, q_len, kv_seq_len),
|
||||||
f"{(bsz, self.num_heads, q_len, kv_seq_len)},"
|
(f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}"
|
||||||
"but is {attn_weights.size()}"))
|
f" but is {attention_mask.size()}"))
|
||||||
|
|
||||||
if attention_mask is not None:
|
attn_weights = attn_weights + attention_mask
|
||||||
invalidInputError(attention_mask.size() == (bsz, 1, q_len, kv_seq_len),
|
|
||||||
(f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}"
|
|
||||||
f" but is {attention_mask.size()}"))
|
|
||||||
|
|
||||||
attn_weights = attn_weights + attention_mask
|
# upcast attention to fp32
|
||||||
|
attn_weights = nn.functional.softmax(attn_weights, dim=-1,
|
||||||
|
dtype=torch.float32).to(query_states.dtype)
|
||||||
|
attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout,
|
||||||
|
training=self.training)
|
||||||
|
|
||||||
# upcast attention to fp32
|
|
||||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1,
|
|
||||||
dtype=torch.float32).to(query_states.dtype)
|
|
||||||
attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout,
|
|
||||||
training=self.training)
|
|
||||||
|
|
||||||
if q_len != 1:
|
|
||||||
attn_output = torch.matmul(attn_weights, value)
|
attn_output = torch.matmul(attn_weights, value)
|
||||||
else:
|
|
||||||
import linear_q4_0
|
|
||||||
attn_output = linear_q4_0.attn_value_fp8_matmul(attn_weights,
|
|
||||||
value_states.transpose(-1, -2))
|
|
||||||
|
|
||||||
invalidInputError(attn_output.size() == (bsz, self.num_heads, q_len, self.head_dim),
|
invalidInputError(attn_output.size() == (bsz, self.num_heads, q_len, self.head_dim),
|
||||||
"`attn_output` should be of size "
|
"`attn_output` should be of size "
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue