[LLM] support quantize kv cache to fp8 (#9812)
parent 248ae7fad2
commit afaa871144

2 changed files with 212 additions and 67 deletions
@@ -37,7 +37,9 @@ except ImportError:
     rearrange = None
 
 from bigdl.llm.transformers.models.utils import extend_kv_cache, init_kv_cache, append_kv_cache
-from bigdl.llm.transformers.models.utils import rotate_half
+from bigdl.llm.transformers.models.utils import init_fp8_kv_cache, extend_fp8_kv_cache, \
+    append_fp8_kv_cache, restore_fp8_kv_cache
+from bigdl.llm.transformers.models.utils import rotate_half, quantize_kv_cache
 from bigdl.llm.utils.common import invalidInputError, invalidOperationError
 from bigdl.llm.ggml.quantize import ggml_tensor_qtype
 
@@ -83,29 +85,18 @@ def qwen_attention_forward(
     query = self._split_heads(query, self.num_heads, self.head_dim)
     key = self._split_heads(key, self.num_heads, self.head_dim)
     value = self._split_heads(value, self.num_heads, self.head_dim)
-    kv_seq_len = hidden_states.size()[1]
+    # query, key, value's shape: [bs, seq_len, num_heads, head_dim]
 
     if rotary_pos_emb_list is not None:
         cur_len = query.shape[1]
         if len(rotary_pos_emb_list) == 1:
-            if query.device.type == 'xpu':
-                cos, sin = rotary_pos_emb_list[0]
-                cos = cos[:, -cur_len:, :, :]
-                sin = sin[:, -cur_len:, :, :]
-                rot_dim = cos.shape[-1]
-                query_cur = query[..., :rot_dim]
-                key_cur = key[..., :rot_dim]
-                torch.ops.torch_ipex.apply_rotary_embedding(query_cur, sin, cos, query_cur)
-                torch.ops.torch_ipex.apply_rotary_embedding(key_cur, sin, cos, key_cur)
-            else:
-                rotary_pos_emb = rotary_pos_emb_list[0]
-                rotary_pos_emb = [i[:, -cur_len:, :, :] for i in rotary_pos_emb]
-                rotary_pos_emb = (rotary_pos_emb,) * 2
-                q_pos_emb, k_pos_emb = rotary_pos_emb
-                # Slice the pos emb for current inference
-                query = apply_rotary_pos_emb(query, q_pos_emb)
-                key = apply_rotary_pos_emb(key, k_pos_emb)
+            rotary_pos_emb = rotary_pos_emb_list[0]
+            rotary_pos_emb = [i[:, -cur_len:, :, :] for i in rotary_pos_emb]
+            rotary_pos_emb = (rotary_pos_emb,) * 2
+            q_pos_emb, k_pos_emb = rotary_pos_emb
+            # Slice the pos emb for current inference
+            query = apply_rotary_pos_emb(query, q_pos_emb)
+            key = apply_rotary_pos_emb(key, k_pos_emb)
         else:
             query_list = []
             key_list = []
@@ -119,62 +110,106 @@ def qwen_attention_forward(
             query = torch.cat(query_list, dim=0)
             key = torch.cat(key_list, dim=0)
 
-    bsz, _, n_heads, head_dim = key.size()
+    query_size, key_size = query.size(1), key.size(1)
+    kv_seq_len = key_size if layer_past is None else key_size + layer_past[0].size(1)
 
-    if layer_past is not None:
-        cache_k, cache_v = layer_past[0], layer_past[1]
-        cache_k = cache_k.transpose(1, 2)
-        cache_v = cache_v.transpose(1, 2)
-        kv_seq_len += cache_k.shape[2]
-        if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
-            # allocate new
-            new_cache_k, new_cache_v = extend_kv_cache(bsz,
-                                                       self.num_heads,
-                                                       self.head_dim,
-                                                       cache_k.size(2),
-                                                       kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH,
-                                                       dtype=cache_k.dtype,
-                                                       device=hidden_states.device)
-            new_cache_k[:] = cache_k
-            new_cache_v[:] = cache_v
-            cache_k = new_cache_k
-            cache_v = new_cache_v
-
-        key_states, value_states = append_kv_cache(cache_k, cache_v,
-                                                   key.transpose(1, 2), value.transpose(1, 2))
-        key = key_states
-        value = value_states
-    elif use_cache:
-        max_cache_length = kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH
-        new_key_states, new_value_states = init_kv_cache(bsz,
-                                                         self.num_heads,
-                                                         self.head_dim,
-                                                         kv_seq_len,
-                                                         max_cache_length,
-                                                         dtype=key.dtype,
-                                                         device=hidden_states.device)
-        new_key_states[:] = key.transpose(1, 2)
-        new_value_states[:] = value.transpose(1, 2)
-        key = new_key_states
-        value = new_value_states
-
-    query_size, key_size = query.size(1), key.size(2)
-    if key_size > self.seq_length and self.use_logn_attn and not self.training:
-        seq_start = key_size - query_size
-        seq_end = key_size
+    if kv_seq_len > self.seq_length and self.use_logn_attn and not self.training:
+        seq_start = kv_seq_len - query_size
+        seq_end = kv_seq_len
         logn_tensor = self.logn_tensor[:, seq_start:seq_end, :, :].type_as(query)
         query = query * logn_tensor.expand_as(query)
-    if query_size == key_size:
+    if key_size == kv_seq_len:
         causal_mask = torch.tril(
             torch.ones((key_size, key_size), dtype=torch.bool, device=query.device)
         ).view(1, 1, key_size, key_size)
     else:
         causal_mask = None
-    query = query.transpose(1, 2)
 
-    attn_output, attn_weight = self._attn(
-        query, key, value, causal_mask, attention_mask, head_mask
-    )
+    if quantize_kv_cache(self.c_attn, hidden_states):
+        query, key, value = query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2)
+        # query, key, value's shape: [bs, num_heads, seq_len, head_dim]
+
+        if layer_past is None:
+            # For first token, use original attn
+            attn_output, attn_weight = self._attn(
+                query, key, value, causal_mask, attention_mask, head_mask
+            )
+            if use_cache:
+                max_cache_length = kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH
+                k_cache, v_cache = init_fp8_kv_cache(
+                    query.size(0), self.num_heads, self.head_dim,
+                    0, max_cache_length,
+                    device=query.device,
+                )
+                key, value = append_fp8_kv_cache(k_cache, v_cache, key, value)
+        else:
+            k_cache, v_cache = layer_past[0], layer_past[1]
+            k_cache = k_cache.transpose(1, 2)
+            v_cache = v_cache.transpose(1, 2)
+            # k_cache and v_cache's shape: [bs, num_heads, context_length, head_dim]
+
+            if k_cache.stride(1) < kv_seq_len * k_cache.size(3):
+                # allocate new
+                k_cache, v_cache = extend_fp8_kv_cache(
+                    k_cache, v_cache,
+                    kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH,
+                    device=query.device,
+                )
+                # empty cache to reduce gpu memory
+                if v_cache.device.type == 'xpu':
+                    torch.xpu.empty_cache()
+
+            key, value = append_fp8_kv_cache(k_cache, v_cache, key, value)
+
+            attn_output, attn_weight = core_attn(
+                self, query, key, value, causal_mask, attention_mask, head_mask
+            )
+
+    else:
+        bsz = key.size(0)
+        if layer_past is not None:
+            cache_k, cache_v = layer_past[0], layer_past[1]
+            cache_k = cache_k.transpose(1, 2)
+            cache_v = cache_v.transpose(1, 2)
+            kv_seq_len += cache_k.shape[2]
+            if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
+                # allocate new
+                new_cache_k, new_cache_v = extend_kv_cache(bsz,
+                                                           self.num_heads,
+                                                           self.head_dim,
+                                                           cache_k.size(2),
+                                                           kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH,
+                                                           dtype=cache_k.dtype,
+                                                           device=hidden_states.device)
+                new_cache_k[:] = cache_k
+                new_cache_v[:] = cache_v
+                cache_k = new_cache_k
+                cache_v = new_cache_v
+
+            key_states, value_states = append_kv_cache(cache_k, cache_v,
+                                                       key.transpose(1, 2), value.transpose(1, 2))
+            key = key_states
+            value = value_states
+        elif use_cache:
+            max_cache_length = kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH
+            new_key_states, new_value_states = init_kv_cache(bsz,
+                                                             self.num_heads,
+                                                             self.head_dim,
+                                                             kv_seq_len,
+                                                             max_cache_length,
+                                                             dtype=key.dtype,
+                                                             device=hidden_states.device)
+            new_key_states[:] = key.transpose(1, 2)
+            new_value_states[:] = value.transpose(1, 2)
+            key = new_key_states
+            value = new_value_states
+
+        query = query.transpose(1, 2)
+
+        attn_output, attn_weight = self._attn(
+            query, key, value, causal_mask, attention_mask, head_mask
+        )
 
     context_layer = self._merge_heads(
         attn_output, self.num_heads, self.head_dim
     )
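The quantized branch above keeps the K and V caches as torch.uint8 buffers (via init_fp8_kv_cache / append_fp8_kv_cache), one byte per element instead of the two bytes an fp16 cache needs, so KV-cache memory roughly halves for the same max_cache_length. A rough sizing sketch; the dimensions below are purely illustrative and not taken from this commit:

import torch

bs, num_heads, max_len, head_dim = 1, 32, 4096, 128
elems = bs * num_heads * max_len * head_dim                      # elements in one K (or V) buffer

fp16_bytes = 2 * elems * torch.finfo(torch.float16).bits // 8    # K + V held in fp16
fp8_bytes = 2 * elems * 1                                        # K + V held as uint8

print(f"{fp16_bytes / 2**20:.0f} MiB vs {fp8_bytes / 2**20:.0f} MiB per layer")  # 64 MiB vs 32 MiB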
@@ -191,6 +226,54 @@ def qwen_attention_forward(
     return outputs
 
 
+def core_attn(self, query, key, value, causal_mask=None, attention_mask=None, head_mask=None):
+    if query.size(2) != 1 or query.device.type != 'xpu':
+        # We have no CPU fp8 matmul implementation for now, so just upscale to fp32
+        key, value = restore_fp8_kv_cache(key, value, query.dtype)
+        attn_weights = torch.matmul(query, key.transpose(-1, -2))
+    else:
+        import linear_q4_0
+        attn_weights = linear_q4_0.query_key_fp8_matmul(query, key)
+
+    if self.scale_attn_weights:
+        if self.use_cache_quantization:
+            size_temp = value[0].size(-1)
+        else:
+            size_temp = value.size(-1)
+        attn_weights = attn_weights / (size_temp ** 0.5)
+
+    mask_value = torch.finfo(attn_weights.dtype).min
+    if causal_mask is not None:
+        attn_weights = torch.where(
+            causal_mask, attn_weights.to(attn_weights.dtype), mask_value
+        )
+
+    if attention_mask is not None:
+        attn_weights = attn_weights + attention_mask
+
+    if self.softmax_in_fp32:
+        attn_weights = torch.nn.functional.softmax(attn_weights.float(), dim=-1)
+    else:
+        attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
+
+    attn_weights = attn_weights.type(query.dtype)
+    attn_weights = self.attn_dropout(attn_weights)
+
+    if head_mask is not None:
+        attn_weights = attn_weights * head_mask
+
+    if query.size(2) != 1 or query.device.type != 'xpu':
+        # We have no CPU fp8 matmul implementation for now, so just upscale to fp32
+        attn_output = torch.matmul(attn_weights, value)
+    else:
+        import linear_q4_0
+        attn_output = linear_q4_0.attn_value_fp8_matmul(attn_weights, value.transpose(-1, -2))
+
+    attn_output = attn_output.transpose(1, 2)
+
+    return attn_output, attn_weights
+
+
 def qwen_mlp_forward(self, x: torch.Tensor) -> torch.Tensor:
     x_2d = x.view(-1, x.shape[-1])
     if x_2d.shape[0] == 1 and x.device.type == 'xpu' \
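core_attn uses the linear_q4_0 fp8 matmul kernels only for single-token decoding on XPU (query.size(2) == 1); on any other path it restores the cache with restore_fp8_kv_cache and falls back to a plain matmul. The second changed file, the bigdl.llm.transformers.models.utils module the new imports refer to, adds those cache helpers; its hunks follow. For reference, a hedged end-to-end sketch of turning the feature on: the gating function reads BIGDL_QUANTIZE_KV_CACHE ("1" forces the fp8 cache, "0" forces it off), while the model id, prompt, and generation arguments below are placeholders following the usual bigdl-llm GPU example pattern, not part of this commit:

import os
os.environ["BIGDL_QUANTIZE_KV_CACHE"] = "1"   # force-enable; "0" would force the regular fp16 cache

import torch
import intel_extension_for_pytorch as ipex    # needed for the 'xpu' device
from transformers import AutoTokenizer
from bigdl.llm.transformers import AutoModelForCausalLM

model_path = "Qwen/Qwen-7B-Chat"              # placeholder model id
model = AutoModelForCausalLM.from_pretrained(model_path, load_in_4bit=True,
                                             trust_remote_code=True).to('xpu')
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

prompt = "What does an fp8 KV cache save?"
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to('xpu')
with torch.inference_mode():
    output = model.generate(input_ids, max_new_tokens=32)
print(tokenizer.decode(output[0], skip_special_tokens=True))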
@@ -14,8 +14,10 @@
 # limitations under the License.
 #
 
+import os
 import torch
 from bigdl.llm.utils.common import invalidInputError
+from bigdl.llm.ggml.quantize import ggml_tensor_qtype
 from bigdl.llm.transformers.utils import get_ipex_version
 
 
@@ -57,6 +59,66 @@ def append_kv_cache(cache_k, cache_v, key_states, value_states):
     return new_cache_k, new_cache_v
 
 
+def quantize_kv_cache(linear: torch.nn.Module, x: torch.Tensor) -> bool:
+    if os.environ.get("BIGDL_QUANTIZE_KV_CACHE", None) is not None:
+        return os.environ["BIGDL_QUANTIZE_KV_CACHE"] == "1"
+    else:
+        return x.device.type == 'xpu' and hasattr(linear, "qtype") and \
+            linear.qtype != ggml_tensor_qtype["fp16"] and linear.qtype != ggml_tensor_qtype["bf16"]
+
+
+def init_fp8_kv_cache(batch_size, num_heads, head_dim, current_length, max_length, device):
+    k_cache_storage = torch.empty(batch_size, num_heads, max_length, head_dim,
+                                  dtype=torch.uint8, device=device)
+
+    v_cache_storage = torch.empty(batch_size, num_heads, head_dim, max_length,
+                                  dtype=torch.uint8, device=device)
+
+    k_cache = k_cache_storage.as_strided((batch_size, num_heads, current_length, head_dim),
+                                         k_cache_storage.stride(), storage_offset=0)
+
+    v_cache = v_cache_storage.as_strided((batch_size, num_heads, head_dim, current_length),
+                                         v_cache_storage.stride(), storage_offset=0)
+
+    return k_cache, v_cache.transpose(-1, -2)
+
+
+def extend_fp8_kv_cache(k_cache, v_cache, max_length, device):
+    batch_size, num_heads, cur_length, head_dim = k_cache.shape
+    new_k_cache, new_v_cache = init_fp8_kv_cache(batch_size, num_heads, head_dim,
+                                                 cur_length, max_length, device)
+    new_k_cache[:] = k_cache
+    new_v_cache[:] = v_cache
+    return new_k_cache, new_v_cache
+
+
+def append_fp8_kv_cache(k_cache, v_cache, key, value):
+    batch_size, num_heads, cur_length, head_dim = k_cache.shape
+    new_length = cur_length + key.size(2)
+    new_size = (batch_size, num_heads, new_length, head_dim)
+
+    new_k_cache = k_cache.as_strided(new_size, k_cache.stride(), storage_offset=0)
+    new_v_cache = v_cache.as_strided(new_size, v_cache.stride(), storage_offset=0)
+
+    fp8_key = key.half().view(torch.uint8)[:, :, :, 1::2]
+    new_k_cache[:, :, cur_length:new_length, :] = fp8_key
+    fp8_value = value.half().view(torch.uint8)[:, :, :, 1::2]
+    new_v_cache[:, :, cur_length:new_length, :] = fp8_value
+
+    return new_k_cache, new_v_cache
+
+
+def restore_fp8_kv_cache(k_cache, v_cache, dtype):
+    new_k_cache = torch.full(k_cache.shape, 128, dtype=torch.int16, device=k_cache.device)
+    new_k_cache.view(torch.uint8)[:, :, :, 1::2] = k_cache
+    new_k_cache = new_k_cache.view(torch.half)
+    new_v_cache = torch.full(v_cache.shape, 128, dtype=torch.int16, device=v_cache.device)
+    new_v_cache.view(torch.uint8)[:, :, :, 1::2] = v_cache
+    new_v_cache = new_v_cache.view(torch.half)
+
+    return new_k_cache.to(dtype=dtype), new_v_cache.to(dtype=dtype)
+
+
 def rotate_half(x):
     """Rotates half the hidden dims of the input."""
     x1 = x[..., :x.shape[-1] // 2]
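The "fp8" stored by append_fp8_kv_cache is simply the high byte of each fp16 value (sign, 5 exponent bits, 2 mantissa bits, the same 1-5-2 split as fp8 e5m2), and restore_fp8_kv_cache rebuilds fp16 by refilling the dropped low byte with 128 before reinterpreting the byte pairs. A standalone round-trip sketch of that packing, independent of the helpers above; it assumes a little-endian platform, as the 1::2 byte indexing in the real code does:

import torch

x = torch.randn(8, dtype=torch.float16)

# "Quantize": keep only the high byte of each fp16 value.
packed = x.view(torch.uint8)[1::2].clone()            # dtype uint8, one byte per element

# "Restore": refill 16-bit slots, low byte = 128 (midpoint of the truncated mantissa bits).
buf = torch.full(packed.shape, 128, dtype=torch.int16)
buf.view(torch.uint8)[1::2] = packed
restored = buf.view(torch.float16)

print(x)
print(restored)                                       # matches x to ~2 mantissa bits
print((x - restored).abs().max())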