mistral decoding_fast_path and fused mlp (#9714)
* mistral decoding_fast_path and fused mlp * meet code review
This commit is contained in:
		
							parent
							
								
									d157f623b6
								
							
						
					
					
						commit
						6c3e698bf1
					
				
					 2 changed files with 92 additions and 53 deletions
				
			
		| 
						 | 
				
			
			@ -662,6 +662,9 @@ def _optimize_post(model, lightweight_bmm=False):
 | 
			
		|||
            convert_forward(model,
 | 
			
		||||
                            module.MistralRMSNorm,
 | 
			
		||||
                            llama_rms_norm_forward)
 | 
			
		||||
            convert_forward(model,
 | 
			
		||||
                            module.MistralMLP,
 | 
			
		||||
                            llama_mlp_forward)
 | 
			
		||||
    elif model.config.model_type == "Yi":
 | 
			
		||||
        modeling_module_name = model.__class__.__module__
 | 
			
		||||
        module = importlib.import_module(modeling_module_name)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -44,7 +44,8 @@ from bigdl.llm.utils.common import invalidInputError
 | 
			
		|||
from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
 | 
			
		||||
from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb,\
 | 
			
		||||
    apply_rotary_pos_emb_no_cache_xpu
 | 
			
		||||
 | 
			
		||||
from bigdl.llm.transformers.models.llama import is_enough_kv_cache_room
 | 
			
		||||
from bigdl.llm.transformers.low_bit_linear import SYM_INT4
 | 
			
		||||
 | 
			
		||||
KV_CACHE_ALLOC_BLOCK_LENGTH = 256
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -63,6 +64,17 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
 | 
			
		|||
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def should_use_fuse_rope(self, hidden_states, position_ids):
 | 
			
		||||
    use_fuse_rope = hidden_states.device.type == "xpu"
 | 
			
		||||
    use_fuse_rope = use_fuse_rope and not (self.training and hidden_states.requires_grad)
 | 
			
		||||
    use_fuse_rope = use_fuse_rope and position_ids is not None
 | 
			
		||||
    return use_fuse_rope
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def use_decoding_fast_path(q_type, use_fuse_rope, enough_kv_room, bs):
 | 
			
		||||
    return q_type == SYM_INT4 and use_fuse_rope and enough_kv_room and bs == 1
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def mistral_attention_forward(
 | 
			
		||||
    self,
 | 
			
		||||
    hidden_states: torch.Tensor,
 | 
			
		||||
| 
						 | 
				
			
			@ -76,64 +88,88 @@ def mistral_attention_forward(
 | 
			
		|||
    bsz, q_len, _ = hidden_states.size()
 | 
			
		||||
    device = hidden_states.device
 | 
			
		||||
 | 
			
		||||
    query_states = self.q_proj(hidden_states)
 | 
			
		||||
    key_states = self.k_proj(hidden_states)
 | 
			
		||||
    value_states = self.v_proj(hidden_states)
 | 
			
		||||
    use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
 | 
			
		||||
    enough_kv_room = is_enough_kv_cache_room(past_key_value)
 | 
			
		||||
    decoding_fast_path = use_decoding_fast_path(self.q_proj.qtype,
 | 
			
		||||
                                                use_fuse_rope,
 | 
			
		||||
                                                enough_kv_room,
 | 
			
		||||
                                                bsz * q_len)
 | 
			
		||||
 | 
			
		||||
    query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
 | 
			
		||||
    key_states = key_states.view(bsz, q_len,
 | 
			
		||||
                                 self.num_key_value_heads, self.head_dim).transpose(1, 2)
 | 
			
		||||
    value_states = value_states.view(bsz, q_len,
 | 
			
		||||
                                     self.num_key_value_heads, self.head_dim).transpose(1, 2)
 | 
			
		||||
 | 
			
		||||
    kv_seq_len = key_states.shape[-2]
 | 
			
		||||
    if past_key_value is not None:
 | 
			
		||||
        kv_seq_len += past_key_value[0].shape[-2]
 | 
			
		||||
 | 
			
		||||
    if query_states.device.type == "xpu" and not (self.training and query_states.requires_grad):
 | 
			
		||||
        query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
 | 
			
		||||
                                                                     key_states,
 | 
			
		||||
                                                                     position_ids,
 | 
			
		||||
                                                                     "mistral")
 | 
			
		||||
    else:
 | 
			
		||||
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
 | 
			
		||||
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
 | 
			
		||||
                                                        cos, sin, position_ids, "mistral")
 | 
			
		||||
 | 
			
		||||
    if past_key_value is not None:
 | 
			
		||||
        # reuse k, v, self_attention
 | 
			
		||||
    if decoding_fast_path:
 | 
			
		||||
        hidden_states = hidden_states.view(1, -1)
 | 
			
		||||
        kv_seq_len = past_key_value[0].shape[-2]
 | 
			
		||||
        cache_k = past_key_value[0]
 | 
			
		||||
        cache_v = past_key_value[1]
 | 
			
		||||
        if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
 | 
			
		||||
            # allocate new
 | 
			
		||||
            new_cache_k, new_cache_v = extend_kv_cache(bsz,
 | 
			
		||||
                                                       self.num_key_value_heads,  # Support GQA
 | 
			
		||||
                                                       self.head_dim,
 | 
			
		||||
                                                       cache_k.size(2),
 | 
			
		||||
                                                       kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH,
 | 
			
		||||
                                                       dtype=cache_k.dtype,
 | 
			
		||||
                                                       device=device)
 | 
			
		||||
        import linear_q4_0
 | 
			
		||||
        query_states, key_states, value_states = linear_q4_0.forward_qkv(hidden_states,
 | 
			
		||||
                                                                         self.q_proj.weight,
 | 
			
		||||
                                                                         self.k_proj.weight,
 | 
			
		||||
                                                                         self.v_proj.weight,
 | 
			
		||||
                                                                         position_ids,
 | 
			
		||||
                                                                         cache_k, cache_v,
 | 
			
		||||
                                                                         self.q_proj.weight.qtype,
 | 
			
		||||
                                                                         kv_seq_len,
 | 
			
		||||
                                                                         self.head_dim)
 | 
			
		||||
        kv_seq_len += 1
 | 
			
		||||
    else:
 | 
			
		||||
        query_states = self.q_proj(hidden_states)
 | 
			
		||||
        key_states = self.k_proj(hidden_states)
 | 
			
		||||
        value_states = self.v_proj(hidden_states)
 | 
			
		||||
 | 
			
		||||
            new_cache_k[:] = cache_k
 | 
			
		||||
            new_cache_v[:] = cache_v
 | 
			
		||||
            cache_k = new_cache_k
 | 
			
		||||
            cache_v = new_cache_v
 | 
			
		||||
        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
 | 
			
		||||
        key_states = key_states.view(bsz, q_len,
 | 
			
		||||
                                     self.num_key_value_heads, self.head_dim).transpose(1, 2)
 | 
			
		||||
        value_states = value_states.view(bsz, q_len,
 | 
			
		||||
                                         self.num_key_value_heads, self.head_dim).transpose(1, 2)
 | 
			
		||||
 | 
			
		||||
        key_states, value_states = append_kv_cache(cache_k, cache_v, key_states, value_states)
 | 
			
		||||
        kv_seq_len = key_states.shape[-2]
 | 
			
		||||
        if past_key_value is not None:
 | 
			
		||||
            kv_seq_len += past_key_value[0].shape[-2]
 | 
			
		||||
 | 
			
		||||
    elif use_cache:
 | 
			
		||||
        max_cache_length = kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH
 | 
			
		||||
        new_key_states, new_value_states = init_kv_cache(bsz,
 | 
			
		||||
                                                         self.num_key_value_heads,
 | 
			
		||||
                                                         self.head_dim,
 | 
			
		||||
                                                         kv_seq_len,
 | 
			
		||||
                                                         max_cache_length,
 | 
			
		||||
                                                         dtype=key_states.dtype,
 | 
			
		||||
                                                         device=device)
 | 
			
		||||
        new_key_states[:] = key_states
 | 
			
		||||
        new_value_states[:] = value_states
 | 
			
		||||
        key_states = new_key_states
 | 
			
		||||
        value_states = new_value_states
 | 
			
		||||
        if use_fuse_rope:
 | 
			
		||||
            query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
 | 
			
		||||
                                                                         key_states,
 | 
			
		||||
                                                                         position_ids,
 | 
			
		||||
                                                                         "mistral")
 | 
			
		||||
        else:
 | 
			
		||||
            cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
 | 
			
		||||
            query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
 | 
			
		||||
                                                            cos, sin, position_ids, "mistral")
 | 
			
		||||
 | 
			
		||||
        if past_key_value is not None:
 | 
			
		||||
            # reuse k, v, self_attention
 | 
			
		||||
            cache_k = past_key_value[0]
 | 
			
		||||
            cache_v = past_key_value[1]
 | 
			
		||||
            if not enough_kv_room:
 | 
			
		||||
                # allocate new
 | 
			
		||||
                new_cache_k, new_cache_v = extend_kv_cache(bsz,
 | 
			
		||||
                                                           self.num_key_value_heads,  # Support GQA
 | 
			
		||||
                                                           self.head_dim,
 | 
			
		||||
                                                           cache_k.size(2),
 | 
			
		||||
                                                           kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH,
 | 
			
		||||
                                                           dtype=cache_k.dtype,
 | 
			
		||||
                                                           device=device)
 | 
			
		||||
 | 
			
		||||
                new_cache_k[:] = cache_k
 | 
			
		||||
                new_cache_v[:] = cache_v
 | 
			
		||||
                cache_k = new_cache_k
 | 
			
		||||
                cache_v = new_cache_v
 | 
			
		||||
 | 
			
		||||
            key_states, value_states = append_kv_cache(cache_k, cache_v, key_states, value_states)
 | 
			
		||||
 | 
			
		||||
        elif use_cache:
 | 
			
		||||
            max_cache_length = kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH
 | 
			
		||||
            new_key_states, new_value_states = init_kv_cache(bsz,
 | 
			
		||||
                                                             self.num_key_value_heads,
 | 
			
		||||
                                                             self.head_dim,
 | 
			
		||||
                                                             kv_seq_len,
 | 
			
		||||
                                                             max_cache_length,
 | 
			
		||||
                                                             dtype=key_states.dtype,
 | 
			
		||||
                                                             device=device)
 | 
			
		||||
            new_key_states[:] = key_states
 | 
			
		||||
            new_value_states[:] = value_states
 | 
			
		||||
            key_states = new_key_states
 | 
			
		||||
            value_states = new_value_states
 | 
			
		||||
 | 
			
		||||
    past_key_value = (key_states, value_states) if use_cache else None
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in a new issue