parent 9e763b049c
commit bbd749dceb

2 changed files with 30 additions and 28 deletions
@@ -29,6 +29,7 @@ class DynamicFp8Cache(DynamicCache):
         value_states: torch.Tensor,
         layer_idx: int,
         cache_kwargs: Optional[Dict[str, Any]]=None,
+        new_layout=False,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
 
         batch_size, num_heads, seq_len, head_dim = key_states.shape
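This hunk threads a `new_layout` flag into `DynamicFp8Cache.update`; defaulting it to `False` keeps every existing call site working unchanged. A minimal sketch of the signature after the change (body elided; `DynamicCache` is the transformers base class named in the hunk header):

from typing import Any, Dict, Optional, Tuple

import torch
from transformers.cache_utils import DynamicCache


class DynamicFp8Cache(DynamicCache):
    # Sketch: only the signature is shown; the FP8 cache logic
    # appears in the hunks below.
    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
        new_layout=False,  # new in this commit, off by default
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        batch_size, num_heads, seq_len, head_dim = key_states.shape
        ...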
@@ -41,15 +42,18 @@ class DynamicFp8Cache(DynamicCache):
             k_cache, v_cache = init_fp8_kv_cache(
                 batch_size, num_heads, seq_len, head_dim,
                 device=key_states.device,
+                new_layout=new_layout,
             )
-            k_cache, v_cache = append_fp8_kv_cache(k_cache, v_cache, key_states, value_states)
+            k_cache, v_cache = append_fp8_kv_cache(k_cache, v_cache, key_states, value_states,
+                                                   new_layout=new_layout)
 
             self.key_cache.append(k_cache)
             self.value_cache.append(v_cache)
         else:
             k_cache = self.key_cache[layer_idx]
             v_cache = self.value_cache[layer_idx]
-            k_cache, v_cache = append_fp8_kv_cache(k_cache, v_cache, key_states, value_states)
+            k_cache, v_cache = append_fp8_kv_cache(k_cache, v_cache, key_states, value_states,
+                                                   new_layout=new_layout)
             self.key_cache[layer_idx] = k_cache
             self.value_cache[layer_idx] = v_cache
 
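Both `append_fp8_kv_cache` call sites now forward the flag, and `init_fp8_kv_cache` allocates the cache in the requested layout up front, so the prefill and decode paths stay consistent. For intuition only (this is not the repo's helper), an FP8 KV append amounts to quantizing the incoming K/V and growing the cache along the sequence axis; the sketch assumes a PyTorch build (>= 2.1) whose `float8_e4m3fn` dtype supports `cat`, and it ignores `new_layout` and pre-allocation:

import torch


def append_fp8_kv_cache_sketch(k_cache, v_cache, key_states, value_states):
    # Illustrative stand-in for append_fp8_kv_cache: lossy-cast the new
    # K/V to 8-bit floats and concatenate on the sequence dimension.
    # All tensors are [batch, num_heads, seq_len, head_dim].
    k_fp8 = key_states.to(torch.float8_e4m3fn)
    v_fp8 = value_states.to(torch.float8_e4m3fn)
    k_cache = torch.cat([k_cache, k_fp8], dim=2)
    v_cache = torch.cat([v_cache, v_fp8], dim=2)
    return k_cache, v_cache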
@@ -354,17 +354,20 @@ def qwen2_attention_forward_quantized(
     if past_key_value is not None:
         cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
         key_states, value_states = past_key_value.update(key_states, value_states,
-                                                         self.layer_idx, cache_kwargs)
+                                                         self.layer_idx, cache_kwargs,
+                                                         new_layout=True)
 
-    if q_len != 1:
+    if q_len == 1 and query_states.device.type == 'xpu' and not self.training \
+            and not hidden_states.requires_grad:
+        import linear_q4_0
+        attn_output = linear_q4_0.sdp_fp8(query_states, key_states, value_states,
+                                          attention_mask)
+        attn_weights = None
+    else:
         key, value = restore_fp8_kv_cache(key_states, value_states, query_states.dtype)
         key = repeat_kv(key, self.num_key_value_groups)
         value = repeat_kv(value, self.num_key_value_groups)
         attn_weights = torch.matmul(query_states, key.transpose(2, 3))
-    else:
-        import linear_q4_0
-        attn_weights = linear_q4_0.query_key_fp8_matmul(query_states, key_states)
 
         attn_weights = attn_weights / math.sqrt(self.head_dim)
 
         invalidInputError(attn_weights.size() == (bsz, self.num_heads, q_len, kv_seq_len),
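The rewritten branch routes single-token decode on XPU (inference only) to the fused `linear_q4_0.sdp_fp8` kernel, while prefill and every other case falls back to dequantizing the cache and running attention in plain PyTorch. For reference, what the fused call computes can be written out as below; this is a sketch that assumes `restore_fp8_kv_cache` simply casts back to the query dtype and `repeat_kv` expands grouped KV heads, whereas the real kernel never materializes the dequantized K/V:

import math

import torch


def sdp_fp8_reference(query_states, key_states, value_states,
                      attention_mask, num_key_value_groups):
    # Dequantize the fp8 cache back to the query dtype
    # (stand-in for restore_fp8_kv_cache).
    key = key_states.to(query_states.dtype)
    value = value_states.to(query_states.dtype)
    # Expand grouped KV heads to match the query heads (repeat_kv equivalent).
    key = key.repeat_interleave(num_key_value_groups, dim=1)
    value = value.repeat_interleave(num_key_value_groups, dim=1)
    # Standard scaled-dot-product attention, softmax in float32.
    head_dim = query_states.size(-1)
    attn_weights = torch.matmul(query_states, key.transpose(2, 3))
    attn_weights = attn_weights / math.sqrt(head_dim)
    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask
    attn_weights = torch.softmax(attn_weights, dim=-1,
                                 dtype=torch.float32).to(query_states.dtype)
    return torch.matmul(attn_weights, value)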
@@ -385,12 +388,7 @@ def qwen2_attention_forward_quantized(
         attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout,
                                              training=self.training)
 
-    if q_len != 1:
         attn_output = torch.matmul(attn_weights, value)
-    else:
-        import linear_q4_0
-        attn_output = linear_q4_0.attn_value_fp8_matmul(attn_weights,
-                                                        value_states.transpose(-1, -2))
 
     invalidInputError(attn_output.size() == (bsz, self.num_heads, q_len, self.head_dim),
                       "`attn_output` should be of size "
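With decode fully handled by `sdp_fp8` in the previous hunk, the trailing `q_len` branch and its `attn_value_fp8_matmul` fallback become dead code, so the output matmul survives only inside the dequantized branch. Either path must still produce `attn_output` of size `[bsz, num_heads, q_len, head_dim]`, which the retained `invalidInputError` check enforces; a toy demonstration with made-up shapes:

import torch

# Hypothetical shapes for a single decoded token (q_len == 1).
bsz, num_heads, q_len, head_dim, kv_seq_len = 1, 32, 1, 128, 64
attn_weights = torch.rand(bsz, num_heads, q_len, kv_seq_len)
value = torch.rand(bsz, num_heads, kv_seq_len, head_dim)
attn_output = torch.matmul(attn_weights, value)
assert attn_output.size() == (bsz, num_heads, q_len, head_dim)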