qwen2 fp8 cache (#10446)

* qwen2 fp8 cache

* fix style check
This commit is contained in:
Xin Qiu 2024-03-19 08:32:39 +08:00 committed by GitHub
parent 9e763b049c
commit bbd749dceb
2 changed files with 30 additions and 28 deletions

View file

@ -29,6 +29,7 @@ class DynamicFp8Cache(DynamicCache):
value_states: torch.Tensor,
layer_idx: int,
cache_kwargs: Optional[Dict[str, Any]]=None,
new_layout=False,
) -> Tuple[torch.Tensor, torch.Tensor]:
batch_size, num_heads, seq_len, head_dim = key_states.shape
@ -41,15 +42,18 @@ class DynamicFp8Cache(DynamicCache):
k_cache, v_cache = init_fp8_kv_cache(
batch_size, num_heads, seq_len, head_dim,
device=key_states.device,
new_layout=new_layout,
)
k_cache, v_cache = append_fp8_kv_cache(k_cache, v_cache, key_states, value_states)
k_cache, v_cache = append_fp8_kv_cache(k_cache, v_cache, key_states, value_states,
new_layout=new_layout)
self.key_cache.append(k_cache)
self.value_cache.append(v_cache)
else:
k_cache = self.key_cache[layer_idx]
v_cache = self.value_cache[layer_idx]
k_cache, v_cache = append_fp8_kv_cache(k_cache, v_cache, key_states, value_states)
k_cache, v_cache = append_fp8_kv_cache(k_cache, v_cache, key_states, value_states,
new_layout=new_layout)
self.key_cache[layer_idx] = k_cache
self.value_cache[layer_idx] = v_cache

View file

@ -354,17 +354,20 @@ def qwen2_attention_forward_quantized(
if past_key_value is not None:
cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
key_states, value_states = past_key_value.update(key_states, value_states,
self.layer_idx, cache_kwargs)
self.layer_idx, cache_kwargs,
new_layout=True)
if q_len != 1:
if q_len == 1 and query_states.device.type == 'xpu' and not self.training \
and not hidden_states.requires_grad:
import linear_q4_0
attn_output = linear_q4_0.sdp_fp8(query_states, key_states, value_states,
attention_mask)
attn_weights = None
else:
key, value = restore_fp8_kv_cache(key_states, value_states, query_states.dtype)
key = repeat_kv(key, self.num_key_value_groups)
value = repeat_kv(value, self.num_key_value_groups)
attn_weights = torch.matmul(query_states, key.transpose(2, 3))
else:
import linear_q4_0
attn_weights = linear_q4_0.query_key_fp8_matmul(query_states, key_states)
attn_weights = attn_weights / math.sqrt(self.head_dim)
invalidInputError(attn_weights.size() == (bsz, self.num_heads, q_len, kv_seq_len),
@ -385,12 +388,7 @@ def qwen2_attention_forward_quantized(
attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout,
training=self.training)
if q_len != 1:
attn_output = torch.matmul(attn_weights, value)
else:
import linear_q4_0
attn_output = linear_q4_0.attn_value_fp8_matmul(attn_weights,
value_states.transpose(-1, -2))
invalidInputError(attn_output.size() == (bsz, self.num_heads, q_len, self.head_dim),
"`attn_output` should be of size "