Add quantize kv cache support for chatglm2/3 (#9996)
parent 86055d76d5
commit bf65548d29

2 changed files with 158 additions and 10 deletions
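Summary: this change adds an FP8 (quantized) KV-cache path to the chatglm2/3 attention patch. `_optimize_post` now installs a single dispatcher, `chatglm2_attention_forward`, on `SelfAttention`; per call it picks between the new `chatglm2_quantized_attention_forward_8eb45c` and the existing `chatglm2_attention_forward_8eb45c`, and the separate `CoreAttention` patch is dropped in favour of calling `core_attn_forward_8eb45c` directly. The first hunk below touches the conversion logic; the remaining hunks modify the chatglm2 model file (`bigdl.llm.transformers.models.chatglm2`, inferred from the imports).

A minimal sketch of exercising the new path, assuming the quantized cache is toggled through a BIGDL_QUANTIZE_KV_CACHE environment variable read by quantize_kv_cache (the variable name is not shown in this diff):

import os
os.environ["BIGDL_QUANTIZE_KV_CACHE"] = "1"   # assumption: the switch checked by quantize_kv_cache

from bigdl.llm.transformers import AutoModel

model = AutoModel.from_pretrained("THUDM/chatglm2-6b",
                                  load_in_4bit=True,
                                  trust_remote_code=True)
model = model.to("xpu")   # the fp8 matmul kernels (linear_q4_0) target Intel GPUs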
				
			
@@ -637,17 +637,12 @@ def _optimize_post(model, lightweight_bmm=False):
             # chatglm2-6b
             modeling_module_name = model.__class__.__module__
             module = importlib.import_module(modeling_module_name)
-            from bigdl.llm.transformers.models.chatglm2 import chatglm2_attention_forward_8eb45c
-            from bigdl.llm.transformers.models.chatglm2 import core_attn_forward_8eb45c
+            from bigdl.llm.transformers.models.chatglm2 import chatglm2_attention_forward
             from bigdl.llm.transformers.models.chatglm2 import chatglm_rms_norm_forward
             from bigdl.llm.transformers.models.chatglm2 import chatglm2_model_forward
             convert_forward(model,
                             module.SelfAttention,
-                            chatglm2_attention_forward_8eb45c
-                            )
-            convert_forward(model,
-                            module.CoreAttention,
-                            core_attn_forward_8eb45c)
+                            chatglm2_attention_forward)
             convert_forward(model,
                             module.ChatGLMModel,
                             chatglm2_model_forward)
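For reference, convert_forward(model, module_cls, new_forward) rebinds the forward of every matching module. A simplified sketch of the idea (an assumption for illustration, not the actual bigdl.llm implementation):

import types

def convert_forward(model, target_module_class, new_forward):
    # Walk the module tree and rebind `forward` so the patched attention
    # and model functions above are used at inference time.
    for module in model.modules():
        if isinstance(module, target_module_class):
            module.forward = types.MethodType(new_forward, module)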
@@ -23,6 +23,8 @@ from typing import Optional, Tuple, List
 import torch.nn.functional as F
 from transformers.modeling_outputs import BaseModelOutputWithPast
 from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
+from bigdl.llm.transformers.models.utils import init_fp8_kv_cache, extend_fp8_kv_cache, \
+    append_fp8_kv_cache, restore_fp8_kv_cache, quantize_kv_cache
 from bigdl.llm.transformers.models.utils import use_flash_attention
 from bigdl.llm.transformers.models.llama import get_ipex_version
 
@@ -78,6 +80,21 @@ def apply_rotary_pos_emb_chatglm(x: torch.Tensor, rope_cache: torch.Tensor) -> t
     return torch.cat((x_out2, x_pass), dim=-1)
 
 
+def repeat_kv(key: torch.Tensor, value: torch.Tensor, n_head: int) -> (torch.Tensor, torch.Tensor):
+    # key, value's shape: [bs, n_kv_head, seq_len, head_dim] -> [bs, n_head, seq_len, head_dim]
+    batch_size, n_kv_head, seq_len, head_dim = key.shape
+
+    key = key.unsqueeze(2)
+    key = key.expand(-1, -1, n_head // n_kv_head, -1, -1)
+    key = key.contiguous().view(batch_size, n_head, seq_len, head_dim)
+
+    value = value.unsqueeze(2)
+    value = value.expand(-1, -1, n_head // n_kv_head, -1, -1)
+    value = value.contiguous().view(batch_size, n_head, seq_len, head_dim)
+
+    return key, value
+
+
 def chatglm_rms_norm_forward(self, hidden_states):
     if hidden_states.device.type == "xpu" and not (self.training and hidden_states.requires_grad):
         import linear_q4_0
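A quick standalone check of the repeat_kv shape contract added above, using chatglm2-6b-like head counts (32 query heads, 2 KV groups) as an example:

import torch

bs, n_kv_head, n_head, seq_len, head_dim = 1, 2, 32, 8, 128
key = torch.randn(bs, n_kv_head, seq_len, head_dim)

# Same expand-then-reshape trick as repeat_kv: each KV head is shared by
# n_head // n_kv_head query heads.
key = key.unsqueeze(2).expand(-1, -1, n_head // n_kv_head, -1, -1)
key = key.contiguous().view(bs, n_head, seq_len, head_dim)
print(key.shape)  # torch.Size([1, 32, 8, 128])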
@@ -169,6 +186,142 @@ def chatglm2_model_forward(
     )
 
 
+def chatglm2_attention_forward(
+    self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True
+):
+    if quantize_kv_cache(self.query_key_value, hidden_states):
+        forward_function = chatglm2_quantized_attention_forward_8eb45c
+    else:
+        forward_function = chatglm2_attention_forward_8eb45c
+    return forward_function(
+        self=self,
+        hidden_states=hidden_states,
+        attention_mask=attention_mask,
+        rotary_pos_emb=rotary_pos_emb,
+        kv_cache=kv_cache,
+        use_cache=use_cache
+    )
+
+
+def chatglm2_quantized_attention_forward_8eb45c(
+    self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True
+):
+    # hidden_states: [seq_len, bs, head_dim]
+    mixed_x_layer = self.query_key_value(hidden_states)
+
+    n_head = self.num_attention_heads_per_partition
+    n_kv_head = self.num_multi_query_groups_per_partition if self.multi_query_attention else n_head
+    head_dim = self.hidden_size_per_attention_head
+
+    query_layer, key_layer, value_layer = mixed_x_layer.split(
+        [n_head * head_dim, n_kv_head * head_dim, n_kv_head * head_dim],
+        dim=-1,
+    )
+    query_layer = query_layer.view(query_layer.shape[:-1] + (n_head, head_dim))
+    key_layer = key_layer.view(key_layer.shape[:-1] + (n_kv_head, head_dim))
+    value_layer = value_layer.view(value_layer.shape[:-1] + (n_kv_head, head_dim))
+    # query, key, value's shape: [seq_len, bs, n_head/n_kv_head, head_dim]
+
+    # apply relative positional encoding (rotary embedding)
+    if rotary_pos_emb is not None:
+        if len(rotary_pos_emb) == 2 and isinstance(rotary_pos_emb, tuple):
+            # use_fuse_rope, see chatglm2_model_forward
+            cos, sin = rotary_pos_emb
+            rot_dim = cos.shape[-1]
+            query_layer = query_layer.transpose(0, 1)
+            key_layer = key_layer.transpose(0, 1)
+            query_layer_cur = query_layer[..., :rot_dim]
+            key_layer_cur = key_layer[..., :rot_dim]
+            # ipex's apply_rotary_embedding can change the origin storage, so query_layer will get
+            # the result directly.
+            torch.ops.torch_ipex.apply_rotary_embedding(query_layer_cur, sin, cos, query_layer_cur)
+            torch.ops.torch_ipex.apply_rotary_embedding(key_layer_cur, sin, cos, key_layer_cur)
+            query_layer = query_layer.transpose(0, 1)
+            key_layer = key_layer.transpose(0, 1)
+        else:
+            query_layer = apply_rotary_pos_emb_chatglm(query_layer, rotary_pos_emb)
+            key_layer = apply_rotary_pos_emb_chatglm(key_layer, rotary_pos_emb)
+
+    query_layer = query_layer.permute(1, 2, 0, 3)
+    key_layer = key_layer.permute(1, 2, 0, 3)
+    value_layer = value_layer.permute(1, 2, 0, 3)
+    # query, key, value's shape: [bs, n_head/n_kv_head, seq_len, head_dim]
+    batch_size, _, seq_len, _ = query_layer.shape
+
+    if kv_cache is None:
+        # first token
+        if self.multi_query_attention:
+            key, value = repeat_kv(key_layer, value_layer, n_head)
+        else:
+            key, value = key_layer, value_layer
+
+        if attention_mask is None:
+            context_layer = F.scaled_dot_product_attention(query_layer, key, value, is_causal=True)
+        else:
+            context_layer = F.scaled_dot_product_attention(query_layer, key, value, attention_mask)
+
+        if use_cache:
+            k_cache, v_cache = init_fp8_kv_cache(batch_size,
+                                                 n_kv_head,
+                                                 head_dim,
+                                                 0,
+                                                 seq_len + KV_CACHE_ALLOC_MIN_LENGTH,
+                                                 query_layer.device)
+            k_cache, v_cache = append_fp8_kv_cache(k_cache, v_cache, key_layer, value_layer)
+    else:
+        k_cache, v_cache = kv_cache
+        k_cache = k_cache.permute(1, 2, 0, 3)
+        v_cache = v_cache.permute(1, 2, 0, 3)
+        # k_cache, v_cache's shape: [bs, n_kv_head, seq_len, head_dim]
+
+        kv_seq_len = seq_len + k_cache.size(2)
+        if k_cache.stride(1) < kv_seq_len * k_cache.size(3):
+            k_cache, v_cache = extend_fp8_kv_cache(
+                k_cache, v_cache,
+                kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH,
+                device=query_layer.device,
+            )
+            if query_layer.device.type == 'xpu':
+                torch.xpu.empty_cache()
+        k_cache, v_cache = append_fp8_kv_cache(k_cache, v_cache, key_layer, value_layer)
+
+        if seq_len != 1:
+            key, value = restore_fp8_kv_cache(k_cache, v_cache, query_layer.dtype)
+            key, value = repeat_kv(key, value, n_head)
+            attn = torch.matmul(query_layer, key.transpose(2, 3)) / math.sqrt(head_dim)
+        else:
+            key, value = k_cache, v_cache
+            import linear_q4_0
+            attn = linear_q4_0.query_key_fp8_matmul(query_layer, key) / math.sqrt(head_dim)
+        if attention_mask is not None:
+            attention_mask = ~attention_mask
+            attn_bias = torch.zeros(attention_mask.shape, dtype=query_layer.dtype,
+                                    device=query_layer.device)
+            if attention_mask.dtype == torch.bool:
+                attn_bias.masked_fill_(attention_mask.logical_not(), float("-inf"))
+            else:
+                attn_bias += attention_mask
+            attn += attn_bias
+        attn = F.softmax(attn, dim=-1, dtype=torch.float32)
+        if seq_len != 1:
+            context_layer = torch.matmul(attn.to(value.dtype), value)
+        else:
+            import linear_q4_0
+            context_layer = linear_q4_0.attn_value_fp8_matmul(attn, value.transpose(-1, -2))
+
+    # context_layer's shape: [bs, n_head, seq_len, head_dim] -> [seq_len, bs, n_head * head_dim]
+    context_layer = context_layer.permute(2, 0, 1, 3).contiguous().view(seq_len, batch_size, -1)
+
+    if use_cache:
+        kv_cache = (k_cache.permute(2, 0, 1, 3), v_cache.permute(2, 0, 1, 3))
+    else:
+        kv_cache = None
+
+    output = self.dense(context_layer)
+
+    return output, kv_cache
+
+
 def chatglm2_attention_forward_8eb45c(
         self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True
 ):
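In the quantized forward above, decode (seq_len == 1) multiplies the query against the FP8 cache directly via linear_q4_0.query_key_fp8_matmul, while prefill first dequantizes with restore_fp8_kv_cache. A rough standalone illustration of that store/restore round trip, assuming plain float8_e5m2 casting (PyTorch >= 2.1); the actual kernels live in bigdl's native linear_q4_0 extension:

import torch

k = torch.randn(1, 2, 8, 128, dtype=torch.float16)

k_fp8 = k.to(torch.float8_e5m2)        # roughly what append_fp8_kv_cache stores
k_restored = k_fp8.to(torch.float16)   # roughly what restore_fp8_kv_cache returns
print((k - k_restored).abs().max())    # small, non-zero rounding error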
@@ -354,7 +507,7 @@ def chatglm2_attention_forward_8eb45c(
                                                          save_length,
                                                          self.hidden_size_per_attention_head))
 
-    context_layer = self.core_attention(query_layer, key_layer, value_layer, attention_mask)
+    context_layer = core_attn_forward_8eb45c(query_layer, key_layer, value_layer, attention_mask)
 
     # =================
     # Output. [sq, b, h]
@@ -365,7 +518,7 @@ def chatglm2_attention_forward_8eb45c(
     return output, (cache_key_layer.permute(2, 0, 1, 3), cache_value_layer.permute(2, 0, 1, 3))
 
 
-def core_attn_forward_8eb45c(self, query_layer, key_layer, value_layer, attention_mask):
+def core_attn_forward_8eb45c(query_layer, key_layer, value_layer, attention_mask):
     pytorch_major_version = int(torch.__version__.split('.')[0])
     if pytorch_major_version >= 2:
         query_layer = query_layer.permute(1, 2, 0, 3)
@@ -392,7 +545,7 @@ def core_attn_forward_8eb45c(self, query_layer, key_layer, value_layer, attentio
                              dtype=torch.float32).to(value_layer.dtype)
             context_layer = torch.matmul(attn, value_layer)
         context_layer = context_layer.permute(2, 0, 1, 3)
-        new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
+        new_context_layer_shape = context_layer.size()[:-2] + (-1,)
         context_layer = context_layer.reshape(*new_context_layer_shape)
     else:
         # Raw attention scores
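The last three hunks are the refactor the new dispatcher relies on: core_attn_forward_8eb45c loses its self parameter and is called as a plain function instead of self.core_attention(...), so the final reshape infers the hidden size with (-1,) rather than reading self.hidden_size_per_partition.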