mixstral fused qkv and rope (#9724)
* mixstral fused qkv and rope * fix and clean * fix style * update * update * fix * update * fix
This commit is contained in:
		
							parent
							
								
									e4f6e43675
								
							
						
					
					
						commit
						e36111e713
					
				
					 4 changed files with 108 additions and 65 deletions
				
			
		| 
						 | 
					@ -39,7 +39,7 @@ import math
 | 
				
			||||||
import torch.nn.functional as F
 | 
					import torch.nn.functional as F
 | 
				
			||||||
from bigdl.llm.utils.common import invalidInputError
 | 
					from bigdl.llm.utils.common import invalidInputError
 | 
				
			||||||
from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
 | 
					from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
 | 
				
			||||||
from bigdl.llm.transformers.models.utils import rotate_half, apply_rotary_pos_emb
 | 
					from bigdl.llm.transformers.models.utils import is_enough_kv_cache_room_4_31, apply_rotary_pos_emb
 | 
				
			||||||
from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
 | 
					from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
 | 
				
			||||||
from bigdl.llm.transformers.low_bit_linear import SYM_INT4
 | 
					from bigdl.llm.transformers.low_bit_linear import SYM_INT4
 | 
				
			||||||
from bigdl.llm.ggml.quantize import ggml_tensor_qtype
 | 
					from bigdl.llm.ggml.quantize import ggml_tensor_qtype
 | 
				
			||||||
| 
						 | 
					@ -111,11 +111,6 @@ def llama_mlp_forward(
 | 
				
			||||||
    return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
 | 
					    return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def is_enough_kv_cache_room(past_key_value):
 | 
					 | 
				
			||||||
    return past_key_value is not None and \
 | 
					 | 
				
			||||||
        past_key_value[0].stride()[1] > past_key_value[0].size(2) * past_key_value[0].size(3)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def should_use_fuse_rope(self, query_states, position_ids):
 | 
					def should_use_fuse_rope(self, query_states, position_ids):
 | 
				
			||||||
    use_fuse_rope = query_states.device.type == "xpu"
 | 
					    use_fuse_rope = query_states.device.type == "xpu"
 | 
				
			||||||
    use_fuse_rope = use_fuse_rope and not (self.training and query_states.requires_grad)
 | 
					    use_fuse_rope = use_fuse_rope and not (self.training and query_states.requires_grad)
 | 
				
			||||||
| 
						 | 
					@ -149,7 +144,7 @@ def llama_attention_forward_4_31(
 | 
				
			||||||
        attention_dtype = original_dtype
 | 
					        attention_dtype = original_dtype
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
 | 
					    use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
 | 
				
			||||||
    enough_kv_room = is_enough_kv_cache_room(past_key_value)
 | 
					    enough_kv_room = is_enough_kv_cache_room_4_31(past_key_value)
 | 
				
			||||||
    is_q4_0 = self.q_proj.qtype == SYM_INT4
 | 
					    is_q4_0 = self.q_proj.qtype == SYM_INT4
 | 
				
			||||||
    no_tp = not self.config.pretraining_tp > 1
 | 
					    no_tp = not self.config.pretraining_tp > 1
 | 
				
			||||||
    decoding_fast_path = (no_tp and is_q4_0 and use_fuse_rope and
 | 
					    decoding_fast_path = (no_tp and is_q4_0 and use_fuse_rope and
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -44,7 +44,7 @@ from bigdl.llm.utils.common import invalidInputError
 | 
				
			||||||
from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
 | 
					from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
 | 
				
			||||||
from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb,\
 | 
					from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb,\
 | 
				
			||||||
    apply_rotary_pos_emb_no_cache_xpu
 | 
					    apply_rotary_pos_emb_no_cache_xpu
 | 
				
			||||||
from bigdl.llm.transformers.models.llama import is_enough_kv_cache_room
 | 
					from bigdl.llm.transformers.models.utils import is_enough_kv_cache_room_4_31
 | 
				
			||||||
from bigdl.llm.transformers.low_bit_linear import SYM_INT4
 | 
					from bigdl.llm.transformers.low_bit_linear import SYM_INT4
 | 
				
			||||||
 | 
					
 | 
				
			||||||
KV_CACHE_ALLOC_BLOCK_LENGTH = 256
 | 
					KV_CACHE_ALLOC_BLOCK_LENGTH = 256
 | 
				
			||||||
| 
						 | 
					@ -89,7 +89,7 @@ def mistral_attention_forward(
 | 
				
			||||||
    device = hidden_states.device
 | 
					    device = hidden_states.device
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
 | 
					    use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
 | 
				
			||||||
    enough_kv_room = is_enough_kv_cache_room(past_key_value)
 | 
					    enough_kv_room = is_enough_kv_cache_room_4_31(past_key_value)
 | 
				
			||||||
    decoding_fast_path = use_decoding_fast_path(self.q_proj.qtype,
 | 
					    decoding_fast_path = use_decoding_fast_path(self.q_proj.qtype,
 | 
				
			||||||
                                                use_fuse_rope,
 | 
					                                                use_fuse_rope,
 | 
				
			||||||
                                                enough_kv_room,
 | 
					                                                enough_kv_room,
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -47,7 +47,8 @@ from bigdl.llm.ggml.quantize import ggml_tensor_qtype
 | 
				
			||||||
from bigdl.llm.utils.common import invalidInputError
 | 
					from bigdl.llm.utils.common import invalidInputError
 | 
				
			||||||
from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
 | 
					from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
 | 
				
			||||||
from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb,\
 | 
					from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb,\
 | 
				
			||||||
    apply_rotary_pos_emb_no_cache_xpu
 | 
					    apply_rotary_pos_emb_no_cache_xpu, is_enough_kv_cache_room_4_36
 | 
				
			||||||
 | 
					from bigdl.llm.transformers.models.mistral import should_use_fuse_rope, use_decoding_fast_path
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
KV_CACHE_ALLOC_BLOCK_LENGTH = 256
 | 
					KV_CACHE_ALLOC_BLOCK_LENGTH = 256
 | 
				
			||||||
| 
						 | 
					@ -142,69 +143,103 @@ def mixtral_attention_forward(
 | 
				
			||||||
    bsz, q_len, _ = hidden_states.size()
 | 
					    bsz, q_len, _ = hidden_states.size()
 | 
				
			||||||
    device = hidden_states.device
 | 
					    device = hidden_states.device
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    query_states = self.q_proj(hidden_states)
 | 
					    use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
 | 
				
			||||||
    key_states = self.k_proj(hidden_states)
 | 
					    enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx)
 | 
				
			||||||
    value_states = self.v_proj(hidden_states)
 | 
					    decoding_fast_path = use_decoding_fast_path(self.q_proj.qtype,
 | 
				
			||||||
 | 
					                                                use_fuse_rope,
 | 
				
			||||||
 | 
					                                                enough_kv_room,
 | 
				
			||||||
 | 
					                                                bsz * q_len)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
 | 
					    if decoding_fast_path:
 | 
				
			||||||
    key_states = key_states.view(bsz, q_len,
 | 
					        hidden_states = hidden_states.view(1, -1)
 | 
				
			||||||
                                 self.num_key_value_heads, self.head_dim).transpose(1, 2)
 | 
					        cache_k = past_key_value.key_cache[self.layer_idx]
 | 
				
			||||||
    value_states = value_states.view(bsz, q_len,
 | 
					        cache_v = past_key_value.value_cache[self.layer_idx]
 | 
				
			||||||
                                     self.num_key_value_heads, self.head_dim).transpose(1, 2)
 | 
					        kv_seq_len = cache_k.shape[-2]
 | 
				
			||||||
 | 
					        import linear_q4_0
 | 
				
			||||||
    kv_seq_len = key_states.shape[-2]
 | 
					        query_states, key_states, value_states = linear_q4_0.forward_qkv(hidden_states,
 | 
				
			||||||
    if past_key_value is not None:
 | 
					                                                                         self.q_proj.weight,
 | 
				
			||||||
        if self.layer_idx is None:
 | 
					                                                                         self.k_proj.weight,
 | 
				
			||||||
            invalidInputError(False, "The cache structure has changed since version v4.36. "
 | 
					                                                                         self.v_proj.weight,
 | 
				
			||||||
                                     f"If you are using {self.__class__.__name__} for "
 | 
					                                                                         position_ids,
 | 
				
			||||||
                                     "auto-regressive decodingwith k/v caching, please make sure "
 | 
					                                                                         cache_k, cache_v,
 | 
				
			||||||
                                     "to initialize the attention class with a layer index.")
 | 
					                                                                         self.q_proj.weight.qtype,
 | 
				
			||||||
        kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
 | 
					                                                                         kv_seq_len,
 | 
				
			||||||
 | 
					                                                                         self.head_dim)
 | 
				
			||||||
    if query_states.device.type == "xpu" and not (self.training and query_states.requires_grad):
 | 
					        kv_seq_len += 1
 | 
				
			||||||
        query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
 | 
					        # update past_key_value's seem_tokens and kv caches.
 | 
				
			||||||
                                                                     key_states,
 | 
					 | 
				
			||||||
                                                                     position_ids,
 | 
					 | 
				
			||||||
                                                                     "mixtral")
 | 
					 | 
				
			||||||
    else:
 | 
					 | 
				
			||||||
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
 | 
					 | 
				
			||||||
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
 | 
					 | 
				
			||||||
                                                        cos, sin, position_ids, "mixtral")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    if past_key_value is not None:
 | 
					 | 
				
			||||||
        # update the number of seen tokens
 | 
					 | 
				
			||||||
        if self.layer_idx == 0:
 | 
					        if self.layer_idx == 0:
 | 
				
			||||||
            past_key_value.seen_tokens += key_states.shape[-2]
 | 
					            past_key_value.seen_tokens = kv_seq_len
 | 
				
			||||||
 | 
					        past_key_value.key_cache[self.layer_idx] = key_states
 | 
				
			||||||
 | 
					        past_key_value.value_cache[self.layer_idx] = value_states
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # reuse k, v, self_attention
 | 
					    else:
 | 
				
			||||||
        # update `past_key_value` with `key_states` and `value_states` for layer `layer_idx`
 | 
					        query_states = self.q_proj(hidden_states)
 | 
				
			||||||
        if len(past_key_value.key_cache) <= self.layer_idx:
 | 
					        key_states = self.k_proj(hidden_states)
 | 
				
			||||||
            past_key_value.key_cache.append(key_states)
 | 
					        value_states = self.v_proj(hidden_states)
 | 
				
			||||||
            past_key_value.value_cache.append(value_states)
 | 
					
 | 
				
			||||||
 | 
					        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
 | 
				
			||||||
 | 
					        key_states = key_states.view(bsz, q_len,
 | 
				
			||||||
 | 
					                                     self.num_key_value_heads, self.head_dim).transpose(1, 2)
 | 
				
			||||||
 | 
					        value_states = value_states.view(bsz, q_len,
 | 
				
			||||||
 | 
					                                         self.num_key_value_heads, self.head_dim).transpose(1, 2)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        kv_seq_len = key_states.shape[-2]
 | 
				
			||||||
 | 
					        if past_key_value is not None:
 | 
				
			||||||
 | 
					            if self.layer_idx is None:
 | 
				
			||||||
 | 
					                invalidInputError(False,
 | 
				
			||||||
 | 
					                                  "The cache structure has changed since version v4.36. "
 | 
				
			||||||
 | 
					                                  f"If you are using {self.__class__.__name__} for "
 | 
				
			||||||
 | 
					                                  "auto-regressive decodingwith k/v caching, please make sure "
 | 
				
			||||||
 | 
					                                  "to initialize the attention class with a layer index.")
 | 
				
			||||||
 | 
					            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if use_fuse_rope:
 | 
				
			||||||
 | 
					            query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
 | 
				
			||||||
 | 
					                                                                         key_states,
 | 
				
			||||||
 | 
					                                                                         position_ids,
 | 
				
			||||||
 | 
					                                                                         "mixtral")
 | 
				
			||||||
        else:
 | 
					        else:
 | 
				
			||||||
            cache_k = past_key_value.key_cache[self.layer_idx]
 | 
					            cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
 | 
				
			||||||
            cache_v = past_key_value.value_cache[self.layer_idx]
 | 
					            query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
 | 
				
			||||||
 | 
					                                                            cos, sin, position_ids, "mixtral")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
 | 
					        if past_key_value is not None:
 | 
				
			||||||
                # allocate new
 | 
					            # update the number of seen tokens
 | 
				
			||||||
                new_cache_k, new_cache_v = extend_kv_cache(bsz,
 | 
					            if self.layer_idx == 0:
 | 
				
			||||||
                                                           self.num_key_value_heads,  # Support GQA
 | 
					                past_key_value.seen_tokens += key_states.shape[-2]
 | 
				
			||||||
                                                           self.head_dim,
 | 
					 | 
				
			||||||
                                                           cache_k.size(2),
 | 
					 | 
				
			||||||
                                                           kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH,
 | 
					 | 
				
			||||||
                                                           dtype=cache_k.dtype,
 | 
					 | 
				
			||||||
                                                           device=device)
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
                new_cache_k[:] = cache_k
 | 
					            # reuse k, v, self_attention
 | 
				
			||||||
                new_cache_v[:] = cache_v
 | 
					            # update `past_key_value` with `key_states` and `value_states` for layer `layer_idx`
 | 
				
			||||||
                cache_k = new_cache_k
 | 
					            if len(past_key_value.key_cache) <= self.layer_idx:
 | 
				
			||||||
                cache_v = new_cache_v
 | 
					                past_key_value.key_cache.append(key_states)
 | 
				
			||||||
 | 
					                past_key_value.value_cache.append(value_states)
 | 
				
			||||||
 | 
					            else:
 | 
				
			||||||
 | 
					                cache_k = past_key_value.key_cache[self.layer_idx]
 | 
				
			||||||
 | 
					                cache_v = past_key_value.value_cache[self.layer_idx]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            key_states, value_states = append_kv_cache(cache_k, cache_v, key_states, value_states)
 | 
					                if not enough_kv_room:
 | 
				
			||||||
 | 
					                    # allocate new
 | 
				
			||||||
 | 
					                    new_c_k, new_c_v = extend_kv_cache(bsz,
 | 
				
			||||||
 | 
					                                                       self.num_key_value_heads,  # Support GQA
 | 
				
			||||||
 | 
					                                                       self.head_dim,
 | 
				
			||||||
 | 
					                                                       cache_k.size(2),
 | 
				
			||||||
 | 
					                                                       kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH,
 | 
				
			||||||
 | 
					                                                       dtype=cache_k.dtype,
 | 
				
			||||||
 | 
					                                                       device=device)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            # update past_key_value
 | 
					                    new_c_k[:] = cache_k
 | 
				
			||||||
            past_key_value.key_cache[self.layer_idx] = key_states
 | 
					                    new_c_v[:] = cache_v
 | 
				
			||||||
            past_key_value.value_cache[self.layer_idx] = value_states
 | 
					                    cache_k = new_c_k
 | 
				
			||||||
 | 
					                    cache_v = new_c_v
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                key_states, value_states = append_kv_cache(cache_k,
 | 
				
			||||||
 | 
					                                                           cache_v,
 | 
				
			||||||
 | 
					                                                           key_states,
 | 
				
			||||||
 | 
					                                                           value_states)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                # update past_key_value
 | 
				
			||||||
 | 
					                past_key_value.key_cache[self.layer_idx] = key_states
 | 
				
			||||||
 | 
					                past_key_value.value_cache[self.layer_idx] = value_states
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # repeat k/v heads if n_kv_heads < n_heads
 | 
					    # repeat k/v heads if n_kv_heads < n_heads
 | 
				
			||||||
    key_states = repeat_kv(key_states, self.num_key_value_groups)
 | 
					    key_states = repeat_kv(key_states, self.num_key_value_groups)
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -106,3 +106,16 @@ def apply_rotary_pos_emb_no_cache_xpu(q, k, position_ids, model_family):
 | 
				
			||||||
    else:
 | 
					    else:
 | 
				
			||||||
        invalidInputError(False,
 | 
					        invalidInputError(False,
 | 
				
			||||||
                          f"{model_family} is not supported.")
 | 
					                          f"{model_family} is not supported.")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def is_enough_kv_cache_room_4_36(past_key_value, idx):
 | 
				
			||||||
 | 
					    # to determinate if is enough kv cache room in transformers==4.36
 | 
				
			||||||
 | 
					    return past_key_value is not None and len(past_key_value.key_cache) > idx and \
 | 
				
			||||||
 | 
					        past_key_value.key_cache[idx].stride()[1] > past_key_value.key_cache[idx].size(2) * \
 | 
				
			||||||
 | 
					        past_key_value.key_cache[idx].size(3)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def is_enough_kv_cache_room_4_31(past_key_value):
 | 
				
			||||||
 | 
					    # to determinate if is enough kv cache room in transformers between 4.31 and 4.35
 | 
				
			||||||
 | 
					    return past_key_value is not None and \
 | 
				
			||||||
 | 
					        past_key_value[0].stride()[1] > past_key_value[0].size(2) * past_key_value[0].size(3)
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in a new issue