remove new_layout parameter (#10906)
parent fbcd7bc737
commit d884c62dc4

11 changed files with 25 additions and 38 deletions
@@ -32,7 +32,6 @@ class DynamicFp8Cache(DynamicCache):
         value_states: torch.Tensor,
         layer_idx: int,
         cache_kwargs: Optional[Dict[str, Any]]=None,
-        new_layout=False,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
 
         batch_size, num_heads, seq_len, head_dim = key_states.shape
@@ -50,18 +49,15 @@ class DynamicFp8Cache(DynamicCache):
             k_cache, v_cache = init_fp8_kv_cache(
                 batch_size, num_heads, seq_len, head_dim,
                 device=key_states.device,
-                new_layout=new_layout,
             )
-            k_cache, v_cache = append_fp8_kv_cache(k_cache, v_cache, key_states, value_states,
-                                                   new_layout=new_layout)
+            k_cache, v_cache = append_fp8_kv_cache(k_cache, v_cache, key_states, value_states)
 
             self.key_cache.append(k_cache)
             self.value_cache.append(v_cache)
         else:
             k_cache = self.key_cache[layer_idx]
             v_cache = self.value_cache[layer_idx]
-            k_cache, v_cache = append_fp8_kv_cache(k_cache, v_cache, key_states, value_states,
-                                                   new_layout=new_layout)
+            k_cache, v_cache = append_fp8_kv_cache(k_cache, v_cache, key_states, value_states)
             self.key_cache[layer_idx] = k_cache
             self.value_cache[layer_idx] = v_cache
 
@@ -77,7 +73,6 @@ class DynamicNormalCache(DynamicCache):
         value_states: torch.Tensor,
         layer_idx: int,
         cache_kwargs: Optional[Dict[str, Any]]=None,
-        new_layout=False,   # useless, just keep same as DynamicFp8Cache
     ) -> Tuple[torch.Tensor, torch.Tensor]:
 
         batch_size, num_heads, seq_len, head_dim = key_states.shape
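Note: after this change, update() on both cache classes takes only (key_states, value_states, layer_idx, cache_kwargs). The stand-in below is an illustrative sketch of that call contract only, not the real DynamicFp8Cache: it stores plain tensors with torch.cat instead of the fp8 init/append helpers.

# --- illustrative sketch, not part of the diff ---
from typing import Any, Dict, List, Optional, Tuple

import torch


class CacheSketch:
    def __init__(self):
        self.key_cache: List[torch.Tensor] = []
        self.value_cache: List[torch.Tensor] = []

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # First call for a layer stores the tensors; later calls append along
        # the sequence dimension (dim=2), mirroring the simplified signature.
        if len(self.key_cache) <= layer_idx:
            self.key_cache.append(key_states)
            self.value_cache.append(value_states)
        else:
            self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=2)
            self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=2)
        return self.key_cache[layer_idx], self.value_cache[layer_idx]

# Caller-side pattern after this commit (as in the llama/mistral/qwen2 hunks below):
# key_states, value_states = past_key_value.update(key_states, value_states,
#                                                   self.layer_idx, cache_kwargs)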
@@ -128,15 +128,15 @@ def baichuan_attention_forward_7b_quantized(
         if use_cache:
             k_cache, v_cache = init_fp8_kv_cache(
                 bsz, self.num_heads, kv_seq_len, self.head_dim,
-                device=device, new_layout=True
+                device=device
             )
             key_states, value_states = append_fp8_kv_cache(k_cache, v_cache, key_states,
-                                                           value_states, new_layout=True)
+                                                           value_states)
             past_key_value = (key_states, value_states)
     else:
         k_cache, v_cache = past_key_value
         key_states, value_states = append_fp8_kv_cache(k_cache, v_cache,
-                                                       key_states, value_states, new_layout=True)
+                                                       key_states, value_states)
         kv_seq_len = key_states.shape[-2]
         past_key_value = (key_states, value_states)
         if query_states.size(2) != 1 or query_states.device.type != 'xpu':
@@ -142,12 +142,12 @@ def baichuan_attention_forward_7b_quantized(
         kv_seq_len = key_states.shape[-2]
         k_cache, v_cache = init_fp8_kv_cache(
             bsz, self.num_heads, kv_seq_len, self.head_dim,
-            device=device, new_layout=True
+            device=device
         )
     else:
         k_cache, v_cache = past_key_value
     key_states, value_states = append_fp8_kv_cache(k_cache, v_cache,
-                                                   key_states, value_states, new_layout=True)
+                                                   key_states, value_states)
 
     past_key_value = (key_states, value_states) if use_cache else None
 
@@ -280,8 +280,7 @@ def chatglm2_quantized_attention_forward_8eb45c(
                                                  n_kv_head,
                                                  seq_len,
                                                  head_dim,
-                                                 query_layer.device,
-                                                 new_layout=True)
+                                                 query_layer.device)
             k_cache, v_cache = append_fp8_kv_cache(k_cache, v_cache, key_layer, value_layer)
     else:
         k_cache, v_cache = kv_cache
@@ -289,8 +288,7 @@ def chatglm2_quantized_attention_forward_8eb45c(
         v_cache = v_cache.permute(1, 2, 0, 3)
         # k_cache, v_cache's shape: [bs, n_kv_head, seq_len, head_dim]
 
-        k_cache, v_cache = append_fp8_kv_cache(k_cache, v_cache, key_layer, value_layer,
-                                               new_layout=True)
+        k_cache, v_cache = append_fp8_kv_cache(k_cache, v_cache, key_layer, value_layer)
 
         if attention_mask is not None:
             attention_mask = ~attention_mask
@@ -438,7 +438,7 @@ def llama_attention_forward_4_31_quantized(
         if use_cache:
             k_cache, v_cache = init_fp8_kv_cache(
                 bsz, self.num_key_value_heads, kv_seq_len, self.head_dim,
-                device=query_states.device, new_layout=True
+                device=query_states.device
             )
             key_states, value_states = append_fp8_kv_cache(k_cache, v_cache,
                                                             key_states, value_states)
@@ -446,7 +446,7 @@ def llama_attention_forward_4_31_quantized(
     else:
         k_cache, v_cache = past_key_value
         key_states, value_states = append_fp8_kv_cache(k_cache, v_cache,
-                                                       key_states, value_states, new_layout=True)
+                                                       key_states, value_states)
         kv_seq_len = key_states.shape[-2]
         past_key_value = (key_states, value_states)
 
@@ -1067,13 +1067,11 @@ def llama_attention_forward_4_36_quantized(
         if use_cache:
             cache_kwargs = None
             key_states, value_states = past_key_value.update(key_states, value_states,
-                                                             self.layer_idx, cache_kwargs,
-                                                             new_layout=True)
+                                                             self.layer_idx, cache_kwargs)
     else:
         cache_kwargs = None  # Specific to RoPE models
         key_states, value_states = past_key_value.update(key_states, value_states,
-                                                         self.layer_idx, cache_kwargs,
-                                                         new_layout=True)
+                                                         self.layer_idx, cache_kwargs)
         kv_seq_len = key_states.shape[-2]
         if not use_sdp_fp8(q_len, key_states.shape[2], query_states):
             key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
@@ -299,7 +299,7 @@ def mistral_attention_forward_quantized(
         if use_cache:
             k_cache, v_cache = init_fp8_kv_cache(
                 bsz, self.num_heads, kv_seq_len, self.head_dim,
-                device=query_states.device, new_layout=True
+                device=query_states.device
             )
             key_states, value_states = append_fp8_kv_cache(k_cache, v_cache,
                                                             key_states, value_states)
@@ -307,7 +307,7 @@ def mistral_attention_forward_quantized(
     else:
         k_cache, v_cache = past_key_value
         key_states, value_states = append_fp8_kv_cache(k_cache, v_cache,
-                                                       key_states, value_states, new_layout=True)
+                                                       key_states, value_states)
         kv_seq_len = key_states.shape[-2]
         past_key_value = (key_states, value_states)
 
@@ -680,13 +680,11 @@ def mistral_attention_forward_4_36_quantized(
         if use_cache:
             cache_kwargs = None
             key_states, value_states = past_key_value.update(key_states, value_states,
-                                                             self.layer_idx, cache_kwargs,
-                                                             new_layout=True)
+                                                             self.layer_idx, cache_kwargs)
     else:
         cache_kwargs = None  # Specific to RoPE models
         key_states, value_states = past_key_value.update(key_states, value_states,
-                                                         self.layer_idx, cache_kwargs,
-                                                         new_layout=True)
+                                                         self.layer_idx, cache_kwargs)
         kv_seq_len = key_states.shape[-2]
         if not use_sdp_fp8(q_len, key_states.shape[2], query_states):
             key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
@@ -129,7 +129,7 @@ def attention_forward(
     invalidInputError(past_key_value is not None,
                       "`past_key_value` cannot be None")
     key_states, value_states = past_key_value.update(key_states, value_states,
-                                                     self.layer_idx, None, new_layout=True)
+                                                     self.layer_idx, None)
 
     key_states = repeat_kv(key_states, self.num_key_value_groups)
     value_states = repeat_kv(value_states, self.num_key_value_groups)
@@ -446,7 +446,7 @@ def qwen_attention_forward_quantized(
             max_cache_length = kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH
             k_cache, v_cache = init_fp8_kv_cache(
                 query.size(0), self.num_heads, kv_seq_len, self.head_dim,
-                device=query.device, new_layout=True
+                device=query.device
             )
             key, value = append_fp8_kv_cache(k_cache, v_cache, key, value)
     else:
@@ -461,7 +461,7 @@ def qwen_attention_forward_quantized(
         v_cache = v_cache.transpose(1, 2)
         # k_cache and v_cache's shape: [bs, num_heads, context_length, head_dim]
 
-        key, value = append_fp8_kv_cache(k_cache, v_cache, key, value, new_layout=True)
+        key, value = append_fp8_kv_cache(k_cache, v_cache, key, value)
 
         attn_output, attn_weight = core_attn(
             self, query, key, value, causal_mask, attention_mask, head_mask
@@ -358,8 +358,7 @@ def qwen2_attention_forward_quantized(
     if past_key_value is not None:
         cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
         key_states, value_states = past_key_value.update(key_states, value_states,
-                                                         self.layer_idx, cache_kwargs,
-                                                         new_layout=True)
+                                                         self.layer_idx, cache_kwargs)
 
     if q_len == 1 and query_states.device.type == 'xpu' and not self.training \
             and not hidden_states.requires_grad:
@@ -132,7 +132,7 @@ def attention_forward(
     use_quantize_kv = use_quantize_kv_cache(self.o_proj, hidden_states)
 
     key_states, value_states = past_key_value.update(key_states, value_states,
-                                                     self.layer_idx, None, new_layout=True)
+                                                     self.layer_idx, None)
 
     if use_quantize_kv and q_len == 1:
         import linear_q4_0
@@ -96,7 +96,7 @@ def kv_cache_device_check(x: torch.Tensor) -> bool:
             1 < x.size(0) and x.size(0) <= 8)
 
 
-def init_fp8_kv_cache(batch_size, num_heads, current_length, head_dim, device, new_layout=False):
+def init_fp8_kv_cache(batch_size, num_heads, current_length, head_dim, device):
     max_length = current_length + FP8_KV_ALLOC_LENGTH
 
     k_cache_storage = torch.empty(batch_size, num_heads, max_length, head_dim,
@@ -104,7 +104,6 @@ def init_fp8_kv_cache(batch_size, num_heads, current_length, head_dim, device, new_layout=False):
     k_cache = k_cache_storage.as_strided((batch_size, num_heads, 0, head_dim),
                                          k_cache_storage.stride(), storage_offset=0)
 
-    # ignore `new_layout`, will remove it in next PR
     v_cache_storage = torch.empty(batch_size, num_heads, max_length, head_dim,
                                   dtype=torch.uint8, device=device)
     v_cache = v_cache_storage.as_strided((batch_size, num_heads, 0, head_dim),
@@ -112,14 +111,14 @@ def init_fp8_kv_cache(batch_size, num_heads, current_length, head_dim, device, new_layout=False):
     return k_cache, v_cache
 
 
-def append_fp8_kv_cache(k_cache, v_cache, key, value, new_layout=False):
+def append_fp8_kv_cache(k_cache, v_cache, key, value):
     batch_size, num_heads, cur_length, head_dim = k_cache.shape
     new_length = cur_length + key.size(2)
     new_size = (batch_size, num_heads, new_length, head_dim)
 
     if k_cache.stride(1) < new_length * k_cache.size(3):
         new_k_cache, new_v_cache = init_fp8_kv_cache(batch_size, num_heads, new_length,
-                                                     head_dim, key.device, new_layout)
+                                                     head_dim, key.device)
         new_k_cache = new_k_cache.as_strided(new_size, new_k_cache.stride(), storage_offset=0)
         new_v_cache = new_v_cache.as_strided(new_size, new_v_cache.stride(), storage_offset=0)
         new_k_cache[:, :, :cur_length, :] = k_cache
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in a new issue