Implement selective batching for vLLM (#9659)
* add control to load hf model * finish initial version of selective_batching * temp * finish * Remove print statement * fix error * Apply yang's optimization * a version that works * We need to check kv_cache passed in, this could be an error. TODO: add fast decoding path * format * temp solution: not batching prefill requests * a version that works for prefill batching * format * a solid version: works normally * a temp version * Solid version: remove redundant functions * fix format * format * solid: add option to enable selective_batching * remove logic for using transformer models * format * format * solid: enable argument VLLM_ENABLE_SELECTIVE_BATCHING * format * finish * format
This commit is contained in:
		
							parent
							
								
									2f36769208
								
							
						
					
					
						commit
						fdf93c9267
					
				
					 4 changed files with 467 additions and 36 deletions
				
			
		| 
						 | 
				
			
			@ -46,6 +46,7 @@ from bigdl.llm.ggml.quantize import ggml_tensor_qtype
 | 
			
		|||
from .utils import logger
 | 
			
		||||
from typing import Union
 | 
			
		||||
import numpy as np
 | 
			
		||||
import os
 | 
			
		||||
from bigdl.llm.utils.common import invalidInputError
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -386,6 +387,8 @@ def convert_forward(m, target_m, new_forward):
 | 
			
		|||
def _optimize_post(model, lightweight_bmm=False):
 | 
			
		||||
    from packaging import version
 | 
			
		||||
    from bigdl.llm.transformers.models.llama import llama_attention_forward_4_31
 | 
			
		||||
    from bigdl.llm.transformers.models.llama import llama_attention_selective_batching_forward_4_31
 | 
			
		||||
    from bigdl.llm.transformers.models.llama import llama_model_selective_batching_forward_4_31
 | 
			
		||||
    from bigdl.llm.transformers.models.llama import llama_rms_norm_forward
 | 
			
		||||
    from bigdl.llm.transformers.models.llama import llama_mlp_forward
 | 
			
		||||
    from transformers.modeling_utils import PreTrainedModel
 | 
			
		||||
| 
						 | 
				
			
			@ -396,6 +399,10 @@ def _optimize_post(model, lightweight_bmm=False):
 | 
			
		|||
                    "supported for further optimizations")
 | 
			
		||||
        return model
 | 
			
		||||
 | 
			
		||||
    vllm_selective_batching = os.getenv("VLLM_ENABLE_SELECTIVE_BATCHING")
 | 
			
		||||
    enable_vllm_se_batching = vllm_selective_batching is not None
 | 
			
		||||
    enable_vllm_se_batching = enable_vllm_se_batching and vllm_selective_batching.lower() == "true"
 | 
			
		||||
 | 
			
		||||
    trans_version = transformers.__version__
 | 
			
		||||
    if version.parse(trans_version) >= version.parse("4.31.0"):
 | 
			
		||||
        convert_forward(
 | 
			
		||||
| 
						 | 
				
			
			@ -409,6 +416,17 @@ def _optimize_post(model, lightweight_bmm=False):
 | 
			
		|||
        convert_forward(model,
 | 
			
		||||
                        transformers.models.llama.modeling_llama.LlamaMLP,
 | 
			
		||||
                        llama_mlp_forward)
 | 
			
		||||
        if enable_vllm_se_batching:
 | 
			
		||||
            convert_forward(
 | 
			
		||||
                model,
 | 
			
		||||
                transformers.models.llama.modeling_llama.LlamaModel,
 | 
			
		||||
                llama_model_selective_batching_forward_4_31,
 | 
			
		||||
            )
 | 
			
		||||
            convert_forward(
 | 
			
		||||
                model,
 | 
			
		||||
                transformers.models.llama.modeling_llama.LlamaAttention,
 | 
			
		||||
                llama_attention_selective_batching_forward_4_31,
 | 
			
		||||
            )
 | 
			
		||||
    else:
 | 
			
		||||
        # todo implement 4.28.0 ~ 4.30.2
 | 
			
		||||
        pass
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -34,15 +34,17 @@
 | 
			
		|||
import torch
 | 
			
		||||
import importlib
 | 
			
		||||
import torch.nn as nn
 | 
			
		||||
from typing import Optional, Tuple
 | 
			
		||||
from typing import Optional, Tuple, Union, List
 | 
			
		||||
import math
 | 
			
		||||
import torch.nn.functional as F
 | 
			
		||||
from bigdl.llm.utils.common import invalidInputError
 | 
			
		||||
from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
 | 
			
		||||
from bigdl.llm.transformers.models.utils import is_enough_kv_cache_room_4_31, apply_rotary_pos_emb
 | 
			
		||||
from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
 | 
			
		||||
from transformers.modeling_outputs import BaseModelOutputWithPast
 | 
			
		||||
from bigdl.llm.transformers.low_bit_linear import SYM_INT4
 | 
			
		||||
from bigdl.llm.ggml.quantize import ggml_tensor_qtype
 | 
			
		||||
from bigdl.llm.utils.common import invalidInputError
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
 | 
			
		||||
| 
						 | 
				
			
			@ -191,7 +193,6 @@ def llama_attention_forward_4_31(
 | 
			
		|||
            value_states = [F.linear(hidden_states, value_slices[i])
 | 
			
		||||
                            for i in range(self.config.pretraining_tp)]
 | 
			
		||||
            value_states = torch.cat(value_states, dim=-1)
 | 
			
		||||
 | 
			
		||||
        else:
 | 
			
		||||
            query_states = self.q_proj(hidden_states)
 | 
			
		||||
            key_states = self.k_proj(hidden_states)
 | 
			
		||||
| 
						 | 
				
			
			@ -305,6 +306,167 @@ def llama_attention_forward_4_31(
 | 
			
		|||
    return attn_output.to(original_dtype), attn_weights, past_key_value
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def llama_attention_selective_batching_forward_4_31(
 | 
			
		||||
    self,
 | 
			
		||||
    hidden_states: torch.Tensor,
 | 
			
		||||
    attention_mask: Optional[torch.Tensor] = None,
 | 
			
		||||
    position_ids: Optional[torch.LongTensor] = None,
 | 
			
		||||
    past_key_value: Optional[Tuple[torch.Tensor]] = None,
 | 
			
		||||
    output_attentions: bool = False,
 | 
			
		||||
    use_cache: bool = False,
 | 
			
		||||
    padding_mask: Optional[torch.LongTensor] = None,
 | 
			
		||||
    **kwargs,
 | 
			
		||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
 | 
			
		||||
    bsz, q_len, _ = hidden_states.size()
 | 
			
		||||
    device = hidden_states.device
 | 
			
		||||
    # for flash attention
 | 
			
		||||
    original_dtype = hidden_states.dtype
 | 
			
		||||
    # TODO: consider this later - flash attention
 | 
			
		||||
    # if not self.training and not hidden_states.requires_grad:
 | 
			
		||||
    #     fsdp_flag = check_flash_attention_available(hidden_states)
 | 
			
		||||
    # else:
 | 
			
		||||
    #     fsdp_flag = False
 | 
			
		||||
    # if fsdp_flag and q_len > 1:
 | 
			
		||||
    #     attention_dtype = torch.float16  # use fp16 for flash attention
 | 
			
		||||
    # else:
 | 
			
		||||
    #     attention_dtype = original_dtype
 | 
			
		||||
 | 
			
		||||
    attention_dtype = original_dtype
 | 
			
		||||
 | 
			
		||||
    # TODO: decoding fast path
 | 
			
		||||
    # use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
 | 
			
		||||
    # enough_kv_room = is_enough_kv_cache_room(past_key_value[0])
 | 
			
		||||
    # is_q4_0 = self.q_proj.qtype == SYM_INT4
 | 
			
		||||
    # no_tp = not self.config.pretraining_tp > 1
 | 
			
		||||
    # decoding_fast_path = (no_tp and is_q4_0 and use_fuse_rope and
 | 
			
		||||
    #                       enough_kv_room and bsz * q_len == 1)
 | 
			
		||||
 | 
			
		||||
    # single batch decoding fast path
 | 
			
		||||
    # forward_qkv takes will perform QKV projection, rotary position embedding
 | 
			
		||||
    # and save the key/value states to cache, then return query states and the
 | 
			
		||||
    # extended key/value cache
 | 
			
		||||
    # if decoding_fast_path:
 | 
			
		||||
    #     hidden_states = hidden_states.view(1, -1)
 | 
			
		||||
    #     kv_seq_len = past_key_value[0].shape[-2]
 | 
			
		||||
    #     cache_k = past_key_value[0]
 | 
			
		||||
    #     cache_v = past_key_value[1]
 | 
			
		||||
    #     import linear_q4_0
 | 
			
		||||
    #     query_states, key_states, value_states = linear_q4_0.forward_qkv(hidden_states,
 | 
			
		||||
    #                                                                      self.q_proj.weight,
 | 
			
		||||
    #                                                                      self.k_proj.weight,
 | 
			
		||||
    #                                                                      self.v_proj.weight,
 | 
			
		||||
    #                                                                      position_ids,
 | 
			
		||||
    #                                                                      cache_k, cache_v,
 | 
			
		||||
    #                                                                      self.q_proj.weight.qtype,
 | 
			
		||||
    #                                                                      kv_seq_len,
 | 
			
		||||
    #                                                                      self.head_dim)
 | 
			
		||||
    #     kv_seq_len += 1
 | 
			
		||||
 | 
			
		||||
    # else:
 | 
			
		||||
    if self.config.pretraining_tp > 1:
 | 
			
		||||
        invalidInputError(False, f"vLLM: config.pretraining_tp > 1 not supported yet")
 | 
			
		||||
    else:
 | 
			
		||||
        query_states = self.q_proj(hidden_states)
 | 
			
		||||
        key_states = self.k_proj(hidden_states)
 | 
			
		||||
        value_states = self.v_proj(hidden_states)
 | 
			
		||||
 | 
			
		||||
    query_states = query_states.view(bsz, q_len,
 | 
			
		||||
                                     self.num_heads, self.head_dim).transpose(1, 2)
 | 
			
		||||
    key_states = key_states.view(bsz, q_len,
 | 
			
		||||
                                 self.num_key_value_heads, self.head_dim).transpose(1, 2)
 | 
			
		||||
    value_states = value_states.view(bsz, q_len,
 | 
			
		||||
                                     self.num_key_value_heads, self.head_dim).transpose(1, 2)
 | 
			
		||||
 | 
			
		||||
    kv_seq_len = key_states.shape[-2]
 | 
			
		||||
    if past_key_value is not None:
 | 
			
		||||
        kv_seq_len += max(kv_pair[0].shape[-2] for kv_pair in past_key_value)
 | 
			
		||||
 | 
			
		||||
    # TODO: fuse_rope
 | 
			
		||||
    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
 | 
			
		||||
    query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
 | 
			
		||||
                                                    cos, sin, position_ids, "llama")
 | 
			
		||||
 | 
			
		||||
    updated_past_key_values = []
 | 
			
		||||
    if past_key_value is not None:
 | 
			
		||||
        batched_attention_output = []
 | 
			
		||||
        # print(f"type of attention_mask is {type(attention_mask)}")
 | 
			
		||||
        for batch in range(bsz):
 | 
			
		||||
            past_k, past_v = past_key_value[batch]
 | 
			
		||||
            current_kv_len = past_k.shape[-2] + 1
 | 
			
		||||
 | 
			
		||||
            current_key_states = torch.cat([past_k,
 | 
			
		||||
                                            key_states[batch: batch + 1, :, :, :]], dim=2)
 | 
			
		||||
            current_value_states = torch.cat([past_v,
 | 
			
		||||
                                              value_states[batch: batch + 1, :, :, :]], dim=2)
 | 
			
		||||
 | 
			
		||||
            updated_past_key_values.append((current_key_states, current_value_states))
 | 
			
		||||
 | 
			
		||||
            current_key_states = repeat_kv(current_key_states, self.num_key_value_groups)
 | 
			
		||||
            current_value_states = repeat_kv(current_value_states, self.num_key_value_groups)
 | 
			
		||||
 | 
			
		||||
            current_query_states = query_states[batch: batch + 1, :, :, :]
 | 
			
		||||
            attn_output, attn_weights = native_sdp(current_query_states,
 | 
			
		||||
                                                   current_key_states,
 | 
			
		||||
                                                   current_value_states,
 | 
			
		||||
                                                   attention_mask[batch],
 | 
			
		||||
                                                   1,
 | 
			
		||||
                                                   1,
 | 
			
		||||
                                                   current_kv_len,
 | 
			
		||||
                                                   self.head_dim,
 | 
			
		||||
                                                   self.num_heads)
 | 
			
		||||
            if attn_output.size() != (1, self.num_heads, 1, self.head_dim):
 | 
			
		||||
                invalidInputError(False,
 | 
			
		||||
                                  f"`attn_output` should be of size "
 | 
			
		||||
                                  f"{(1, self.num_heads, 1, self.head_dim)}, but is"
 | 
			
		||||
                                  f" {attn_output.size()}")
 | 
			
		||||
            batched_attention_output.append(attn_output)
 | 
			
		||||
        # For loop ends
 | 
			
		||||
        # TODO: handle attention_weights later
 | 
			
		||||
        attn_output = torch.concat(batched_attention_output, dim=0)
 | 
			
		||||
        batched_attention_output.clear()
 | 
			
		||||
        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
 | 
			
		||||
            invalidInputError(False,
 | 
			
		||||
                              f"`attn_output` should be of size "
 | 
			
		||||
                              f"{(bsz, self.num_heads, q_len, self.head_dim)}, but is"
 | 
			
		||||
                              f" {attn_output.size()}")
 | 
			
		||||
        attn_output = attn_output.transpose(1, 2).contiguous()
 | 
			
		||||
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
 | 
			
		||||
        attn_output = self.o_proj(attn_output)
 | 
			
		||||
        return attn_output, None, updated_past_key_values
 | 
			
		||||
 | 
			
		||||
    # TODO: Assume always use_cache
 | 
			
		||||
    # print(f"prefill with batch size {bsz}")
 | 
			
		||||
    for batch in range(bsz):
 | 
			
		||||
        updated_past_key_values.append((key_states[batch: batch + 1, :, :, :],
 | 
			
		||||
                                        value_states[batch: batch+1, :, :, :]))
 | 
			
		||||
 | 
			
		||||
    # repeat k/v heads if n_kv_heads < n_heads
 | 
			
		||||
    key_states = repeat_kv(key_states, self.num_key_value_groups).to(device,
 | 
			
		||||
                                                                     dtype=attention_dtype)
 | 
			
		||||
    value_states = repeat_kv(value_states, self.num_key_value_groups).to(device,
 | 
			
		||||
                                                                         dtype=attention_dtype)
 | 
			
		||||
    attn_output, attn_weights = native_sdp(query_states,
 | 
			
		||||
                                           key_states,
 | 
			
		||||
                                           value_states,
 | 
			
		||||
                                           attention_mask,
 | 
			
		||||
                                           bsz,
 | 
			
		||||
                                           q_len,
 | 
			
		||||
                                           kv_seq_len,
 | 
			
		||||
                                           self.head_dim,
 | 
			
		||||
                                           self.num_heads)
 | 
			
		||||
 | 
			
		||||
    if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
 | 
			
		||||
        invalidInputError(False,
 | 
			
		||||
                          f"`attn_output` should be of size "
 | 
			
		||||
                          f"{(bsz, self.num_heads, q_len, self.head_dim)}, but is"
 | 
			
		||||
                          f" {attn_output.size()}")
 | 
			
		||||
    attn_output = attn_output.transpose(1, 2).contiguous()
 | 
			
		||||
    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
 | 
			
		||||
 | 
			
		||||
    attn_output = self.o_proj(attn_output)
 | 
			
		||||
    return attn_output.to(original_dtype), attn_weights, updated_past_key_values
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def check_flash_attention_available(query):
 | 
			
		||||
    # check whether ipex flash attention can be used
 | 
			
		||||
    if query.device.type != "xpu":
 | 
			
		||||
| 
						 | 
				
			
			@ -371,3 +533,171 @@ def native_sdp(query, key, value, attention_mask,
 | 
			
		|||
                                         dtype=torch.float32).to(value.dtype)
 | 
			
		||||
    attn_output = torch.matmul(attn_weights, value)
 | 
			
		||||
    return attn_output, attn_weights
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def llama_model_selective_batching_forward_4_31(
 | 
			
		||||
    self,
 | 
			
		||||
    input_ids: torch.LongTensor = None,
 | 
			
		||||
    attention_mask: Optional[torch.Tensor] = None,
 | 
			
		||||
    position_ids: Optional[torch.LongTensor] = None,
 | 
			
		||||
    past_key_values: Optional[List[torch.FloatTensor]] = None,
 | 
			
		||||
    inputs_embeds: Optional[torch.FloatTensor] = None,
 | 
			
		||||
    use_cache: Optional[bool] = None,
 | 
			
		||||
    output_attentions: Optional[bool] = None,
 | 
			
		||||
    output_hidden_states: Optional[bool] = None,
 | 
			
		||||
    return_dict: Optional[bool] = None,
 | 
			
		||||
) -> Union[Tuple, BaseModelOutputWithPast]:
 | 
			
		||||
    if output_attentions is not None:
 | 
			
		||||
        output_attentions = output_attentions
 | 
			
		||||
    else:
 | 
			
		||||
        output_attentions = self.config.output_attentions
 | 
			
		||||
    output_hidden_states = (
 | 
			
		||||
        output_hidden_states if output_hidden_states is not None
 | 
			
		||||
        else self.config.output_hidden_states
 | 
			
		||||
    )
 | 
			
		||||
    use_cache = use_cache if use_cache is not None else self.config.use_cache
 | 
			
		||||
 | 
			
		||||
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 | 
			
		||||
 | 
			
		||||
    # retrieve input_ids and inputs_embeds
 | 
			
		||||
    if input_ids is not None and inputs_embeds is not None:
 | 
			
		||||
        invalidInputError(False,
 | 
			
		||||
                          "You cannot specify both decoder_input_ids"
 | 
			
		||||
                          " and decoder_inputs_embeds at the same time")
 | 
			
		||||
    elif input_ids is not None:
 | 
			
		||||
        batch_size, seq_length = input_ids.shape
 | 
			
		||||
    elif inputs_embeds is not None:
 | 
			
		||||
        batch_size, seq_length, _ = inputs_embeds.shape
 | 
			
		||||
    else:
 | 
			
		||||
        invalidInputError(False,
 | 
			
		||||
                          "You have to specify either "
 | 
			
		||||
                          "decoder_input_ids or decoder_inputs_embeds")
 | 
			
		||||
 | 
			
		||||
    # seq_length_with_past = seq_length
 | 
			
		||||
    past_key_values_length = 0
 | 
			
		||||
 | 
			
		||||
    # The original position_ids in the format of [1, 1]
 | 
			
		||||
    # However, this only applies when kv_len is the same for all the sequences
 | 
			
		||||
    # We should set it to format of [batch, position_id]
 | 
			
		||||
    # TODO: validate correctness
 | 
			
		||||
    device = input_ids.device if input_ids is not None else inputs_embeds.device
 | 
			
		||||
    if position_ids is None:
 | 
			
		||||
        invalidInputError("vLLM: position_ids should never be None")
 | 
			
		||||
    else:
 | 
			
		||||
        # print(f"Original position_ids is {position_ids}")
 | 
			
		||||
        position_ids = position_ids.view(-1, seq_length)
 | 
			
		||||
        # print(f"after position_ids is {position_ids}")
 | 
			
		||||
    # if past_key_values is None:
 | 
			
		||||
    #     # For prefill
 | 
			
		||||
    #     position_ids = torch.arange(0, seq_length, dtype=torch.long, device=device)
 | 
			
		||||
    #     position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
 | 
			
		||||
    # else:
 | 
			
		||||
    #     past_key_values_length = []
 | 
			
		||||
    #     for sequence_kv in past_key_values[0]:
 | 
			
		||||
    #         key = sequence_kv[0]
 | 
			
		||||
    #         past_key_values_length.append(key.shape[-2])
 | 
			
		||||
    #     position_ids = torch.tensor(past_key_values_length, dtype=torch.long, device=device)
 | 
			
		||||
    #     position_ids = position_ids.unsqueeze(0).view(-1, 1)
 | 
			
		||||
 | 
			
		||||
    if past_key_values is not None:
 | 
			
		||||
        # past_key_values in the format of num_layers x num_seqs x 2
 | 
			
		||||
        # TODO: this may be incorrect
 | 
			
		||||
        past_key_values_length = past_key_values[0][0][0].shape[2]
 | 
			
		||||
        # seq_length_with_past = seq_length_with_past + past_key_values_length
 | 
			
		||||
 | 
			
		||||
    # if position_ids is None:
 | 
			
		||||
    #     device = input_ids.device if input_ids is not None else inputs_embeds.device
 | 
			
		||||
    #     # [start, end)
 | 
			
		||||
    #     position_ids = torch.arange(
 | 
			
		||||
    #         past_key_values_length, seq_length +
 | 
			
		||||
    #         past_key_values_length, dtype=torch.long, device=device
 | 
			
		||||
    #     )
 | 
			
		||||
    #     position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
 | 
			
		||||
    # else:
 | 
			
		||||
    #     position_ids = position_ids.view(-1, seq_length).long()
 | 
			
		||||
 | 
			
		||||
    if inputs_embeds is None:
 | 
			
		||||
        inputs_embeds = self.embed_tokens(input_ids)
 | 
			
		||||
    # embed positions
 | 
			
		||||
    if attention_mask is None:
 | 
			
		||||
        invalidInputError(False, "attention_mask should never be None")
 | 
			
		||||
    # print(f"attention_mask before expanding: {attention_mask}")
 | 
			
		||||
    if past_key_values is None:
 | 
			
		||||
        attention_mask = self._prepare_decoder_attention_mask(
 | 
			
		||||
            attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
 | 
			
		||||
        )
 | 
			
		||||
    else:
 | 
			
		||||
        i = 0
 | 
			
		||||
        for attn_mask in attention_mask:
 | 
			
		||||
            past_key_value_length = past_key_values[0][i][0].shape[2]
 | 
			
		||||
            new_mask = self._prepare_decoder_attention_mask(
 | 
			
		||||
                attn_mask, (1, seq_length), inputs_embeds, past_key_value_length
 | 
			
		||||
            )
 | 
			
		||||
            attention_mask[i] = new_mask
 | 
			
		||||
            i += 1
 | 
			
		||||
 | 
			
		||||
    hidden_states = inputs_embeds
 | 
			
		||||
 | 
			
		||||
    if self.gradient_checkpointing and self.training:
 | 
			
		||||
        invalidInputError(False, "gradient_checkpointing is not supported")
 | 
			
		||||
 | 
			
		||||
    # decoder layers
 | 
			
		||||
    all_hidden_states = () if output_hidden_states else None
 | 
			
		||||
    all_self_attns = () if output_attentions else None
 | 
			
		||||
    next_decoder_cache = () if use_cache else None
 | 
			
		||||
 | 
			
		||||
    for idx, decoder_layer in enumerate(self.layers):
 | 
			
		||||
        if output_hidden_states:
 | 
			
		||||
            all_hidden_states += (hidden_states,)
 | 
			
		||||
 | 
			
		||||
        past_key_value = past_key_values[idx] if past_key_values is not None else None
 | 
			
		||||
 | 
			
		||||
        if self.gradient_checkpointing and self.training:
 | 
			
		||||
 | 
			
		||||
            def create_custom_forward(module):
 | 
			
		||||
                def custom_forward(*inputs):
 | 
			
		||||
                    # None for past_key_value
 | 
			
		||||
                    return module(*inputs, output_attentions, None)
 | 
			
		||||
 | 
			
		||||
                return custom_forward
 | 
			
		||||
 | 
			
		||||
            layer_outputs = torch.utils.checkpoint.checkpoint(
 | 
			
		||||
                create_custom_forward(decoder_layer),
 | 
			
		||||
                hidden_states,
 | 
			
		||||
                attention_mask,
 | 
			
		||||
                position_ids,
 | 
			
		||||
                None,
 | 
			
		||||
            )
 | 
			
		||||
        else:
 | 
			
		||||
            layer_outputs = decoder_layer(
 | 
			
		||||
                hidden_states,
 | 
			
		||||
                attention_mask=attention_mask,
 | 
			
		||||
                position_ids=position_ids,
 | 
			
		||||
                past_key_value=past_key_value,
 | 
			
		||||
                output_attentions=output_attentions,
 | 
			
		||||
                use_cache=use_cache,
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        hidden_states = layer_outputs[0]
 | 
			
		||||
 | 
			
		||||
        if use_cache:
 | 
			
		||||
            next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
 | 
			
		||||
 | 
			
		||||
        if output_attentions:
 | 
			
		||||
            all_self_attns += (layer_outputs[1],)
 | 
			
		||||
 | 
			
		||||
    hidden_states = self.norm(hidden_states)
 | 
			
		||||
 | 
			
		||||
    # add hidden states from the last decoder layer
 | 
			
		||||
    if output_hidden_states:
 | 
			
		||||
        all_hidden_states += (hidden_states,)
 | 
			
		||||
 | 
			
		||||
    next_cache = next_decoder_cache if use_cache else None
 | 
			
		||||
    if not return_dict:
 | 
			
		||||
        return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)  # noqa
 | 
			
		||||
    return BaseModelOutputWithPast(
 | 
			
		||||
        last_hidden_state=hidden_states,
 | 
			
		||||
        past_key_values=next_cache,
 | 
			
		||||
        hidden_states=all_hidden_states,
 | 
			
		||||
        attentions=all_self_attns,
 | 
			
		||||
    )
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -27,6 +27,7 @@ from bigdl.llm.vllm.logger import init_logger
 | 
			
		|||
import math
 | 
			
		||||
import time
 | 
			
		||||
from bigdl.llm.vllm.model_executor.input_metadata import InputMetadata
 | 
			
		||||
import os
 | 
			
		||||
from transformers.generation.logits_process import (
 | 
			
		||||
    LogitsProcessorList,
 | 
			
		||||
    RepetitionPenaltyLogitsProcessor,
 | 
			
		||||
| 
						 | 
				
			
			@ -50,6 +51,10 @@ def _get_attention_mask_for_prompts(
 | 
			
		|||
    ]
 | 
			
		||||
    return attention_mask
 | 
			
		||||
 | 
			
		||||
vllm_selective_batching = os.getenv("VLLM_ENABLE_SELECTIVE_BATCHING")
 | 
			
		||||
enable_vllm_se_batching = vllm_selective_batching is not None
 | 
			
		||||
enable_vllm_se_batching = enable_vllm_se_batching and vllm_selective_batching.lower() == "true"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class BigDLLlamaForCausalLM(BigDLModelForCausalLM):
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -61,12 +66,9 @@ class BigDLLlamaForCausalLM(BigDLModelForCausalLM):
 | 
			
		|||
    ):
 | 
			
		||||
        super().__init__(config, device, max_model_len)
 | 
			
		||||
        self.config = config
 | 
			
		||||
        # TODO(gc): later change this to a switch?
 | 
			
		||||
        if True:
 | 
			
		||||
            from bigdl.llm.transformers import AutoModelForCausalLM
 | 
			
		||||
            from bigdl.llm import optimize_model
 | 
			
		||||
 | 
			
		||||
        # low_bit = 'sym_int4'
 | 
			
		||||
        # Always enable bigdl-llm model
 | 
			
		||||
        from bigdl.llm.transformers import AutoModelForCausalLM
 | 
			
		||||
        from bigdl.llm import optimize_model
 | 
			
		||||
        if device == 'cpu':
 | 
			
		||||
            model = AutoModelForCausalLM.from_pretrained(
 | 
			
		||||
                config._name_or_path,
 | 
			
		||||
| 
						 | 
				
			
			@ -81,7 +83,7 @@ class BigDLLlamaForCausalLM(BigDLModelForCausalLM):
 | 
			
		|||
                import intel_extension_for_pytorch as ipex
 | 
			
		||||
            except ImportError:
 | 
			
		||||
                print("Intel Extension for PyTorch is not installed, \
 | 
			
		||||
                       but is required for xpu inference.")
 | 
			
		||||
                    but is required for xpu inference.")
 | 
			
		||||
 | 
			
		||||
            low_bit = 'sym_int4'
 | 
			
		||||
            model = AutoModelForCausalLM.from_pretrained(
 | 
			
		||||
| 
						 | 
				
			
			@ -93,17 +95,19 @@ class BigDLLlamaForCausalLM(BigDLModelForCausalLM):
 | 
			
		|||
            self.model = model.to('xpu')
 | 
			
		||||
            self.sampler = BigDLSampler(config.vocab_size, device).to('xpu')
 | 
			
		||||
 | 
			
		||||
        if device is None:
 | 
			
		||||
            self.device = torch.device(
 | 
			
		||||
                "cuda" if torch.cuda.is_available() else "cpu")
 | 
			
		||||
        else:
 | 
			
		||||
            self.device = torch.device(device)
 | 
			
		||||
        self.device = torch.device(device)
 | 
			
		||||
        self.dtype = self.model.dtype
 | 
			
		||||
        self.last_seq_ids = []
 | 
			
		||||
        self.tmp_kv_cache = None
 | 
			
		||||
        self.last_kv_cache = None
 | 
			
		||||
        self.pad_token_id = config.pad_token_id
 | 
			
		||||
        self.max_seq_limit = max_model_len
 | 
			
		||||
 | 
			
		||||
    # GC: Note for selective batching
 | 
			
		||||
    # KV_CACHE in the format of num_layers x 2 x (seq_id -> torch.Tensor)
 | 
			
		||||
    # past_key_values in the format of num_layers x len(seq_id) x (2 x torch.Tensor)
 | 
			
		||||
    # If we set num_layers to 9, have 10 sequences in total.
 | 
			
		||||
    # then, for the kv_cache, we get 9 x 2 x 10 = 180 tensors
 | 
			
		||||
    # for past_key_values, we get 9 x 10 x 2 = 180 tensors
 | 
			
		||||
    def forward(
 | 
			
		||||
        self,
 | 
			
		||||
        seq_group_meta_data_lists: List[SequenceGroupMetadata],
 | 
			
		||||
| 
						 | 
				
			
			@ -116,7 +120,7 @@ class BigDLLlamaForCausalLM(BigDLModelForCausalLM):
 | 
			
		|||
        decoder_kv_size = 2
 | 
			
		||||
 | 
			
		||||
        bigdl_input_ids = []
 | 
			
		||||
        bigdl_position_ids = []
 | 
			
		||||
        # bigdl_position_ids = []
 | 
			
		||||
        bigdl_attention_mask = []
 | 
			
		||||
 | 
			
		||||
        cur_seq_ids = []
 | 
			
		||||
| 
						 | 
				
			
			@ -144,8 +148,12 @@ class BigDLLlamaForCausalLM(BigDLModelForCausalLM):
 | 
			
		|||
        # 1. Assemble bigdl_input_ids end
 | 
			
		||||
 | 
			
		||||
        if is_decoding_stage:
 | 
			
		||||
            bigdl_kv_cache = self.prepare_kv_cache(cur_seq_ids, seq_group_meta_data_lists,
 | 
			
		||||
                                                   kv_cache, num_layers, decoder_kv_size)
 | 
			
		||||
            construct_kv_cache_func = self.get_construct_kv_cache_func(enable_vllm_se_batching)
 | 
			
		||||
            bigdl_kv_cache = construct_kv_cache_func(cur_seq_ids,
 | 
			
		||||
                                                     seq_group_meta_data_lists,
 | 
			
		||||
                                                     kv_cache,
 | 
			
		||||
                                                     num_layers,
 | 
			
		||||
                                                     2)
 | 
			
		||||
        else:
 | 
			
		||||
            bigdl_attention_mask = _get_attention_mask_for_prompts(bigdl_input_ids, max_prompt_len)
 | 
			
		||||
            bigdl_input_ids = [
 | 
			
		||||
| 
						 | 
				
			
			@ -153,41 +161,72 @@ class BigDLLlamaForCausalLM(BigDLModelForCausalLM):
 | 
			
		|||
                for input_ids in bigdl_input_ids
 | 
			
		||||
            ]
 | 
			
		||||
 | 
			
		||||
        decoding_attention_mask_list = []
 | 
			
		||||
        decoding_position_ids = []
 | 
			
		||||
        # num_layers x len(seq_id) x (2 x torch.Tensor)
 | 
			
		||||
        if is_decoding_stage:
 | 
			
		||||
            cur_seq_len = bigdl_kv_cache[0][0].size(2)
 | 
			
		||||
            for seq_group_meta_data in seq_group_meta_data_lists:
 | 
			
		||||
                seq_ids = list(seq_group_meta_data.seq_data.keys())
 | 
			
		||||
                seq_id = seq_ids[0]
 | 
			
		||||
                seq_data = seq_group_meta_data.seq_data[seq_id]
 | 
			
		||||
                cur_pos = seq_data.get_len()
 | 
			
		||||
                # bigdl_position_ids.append([cur_pos - 1])
 | 
			
		||||
                cur_attention_mask = [0] * (cur_seq_len - cur_pos + 1) + [1] * (cur_pos)
 | 
			
		||||
                bigdl_attention_mask.append(cur_attention_mask)
 | 
			
		||||
            if enable_vllm_se_batching:
 | 
			
		||||
                batch = 0
 | 
			
		||||
                for seq_group_meta_data in seq_group_meta_data_lists:
 | 
			
		||||
                    # Get current seq_len in kv_cache
 | 
			
		||||
                    current_seq_len = bigdl_kv_cache[0][batch][0].size(2)
 | 
			
		||||
                    batch += 1
 | 
			
		||||
                    seq_ids = list(seq_group_meta_data.seq_data.keys())
 | 
			
		||||
                    seq_data = seq_group_meta_data.seq_data[seq_ids[0]]
 | 
			
		||||
                    cur_pos = seq_data.get_len()
 | 
			
		||||
                    decoding_position_ids.append(cur_pos - 1)
 | 
			
		||||
                    # Total length: current_seq_len + 1
 | 
			
		||||
                    cur_attention_mask = [0] * (current_seq_len - cur_pos + 1) + [1] * (cur_pos)
 | 
			
		||||
                    decoding_attention_mask_list.append(cur_attention_mask)
 | 
			
		||||
            else:
 | 
			
		||||
                cur_seq_len = bigdl_kv_cache[0][0].size(2)
 | 
			
		||||
                for seq_group_meta_data in seq_group_meta_data_lists:
 | 
			
		||||
                    seq_ids = list(seq_group_meta_data.seq_data.keys())
 | 
			
		||||
                    seq_id = seq_ids[0]
 | 
			
		||||
                    seq_data = seq_group_meta_data.seq_data[seq_id]
 | 
			
		||||
                    cur_pos = seq_data.get_len()
 | 
			
		||||
                    # bigdl_position_ids.append([cur_pos - 1])
 | 
			
		||||
                    # decoding_position_ids.append(cur_pos - 1)
 | 
			
		||||
                    cur_attention_mask = [0] * (cur_seq_len - cur_pos + 1) + [1] * (cur_pos)
 | 
			
		||||
                    decoding_attention_mask_list.append(cur_attention_mask)
 | 
			
		||||
 | 
			
		||||
        bigdl_input_ids = torch.tensor(bigdl_input_ids, device=self.device)
 | 
			
		||||
 | 
			
		||||
        if is_decoding_stage:
 | 
			
		||||
            # bigdl_position_ids = torch.tensor(bigdl_position_ids, device=self.device)
 | 
			
		||||
            bigdl_attention_mask = torch.tensor(bigdl_attention_mask, device=self.device)
 | 
			
		||||
            if enable_vllm_se_batching:
 | 
			
		||||
                attention_mask = [torch.tensor(x, device=self.device).unsqueeze(0)
 | 
			
		||||
                                  for x in decoding_attention_mask_list]
 | 
			
		||||
                position_ids = torch.tensor(decoding_position_ids).long().unsqueeze(-1)
 | 
			
		||||
            else:
 | 
			
		||||
                attention_mask = torch.tensor(decoding_attention_mask_list, device=self.device)
 | 
			
		||||
                position_ids = None
 | 
			
		||||
            kwargs = {
 | 
			
		||||
                "input_ids": bigdl_input_ids,
 | 
			
		||||
                # "position_ids": bigdl_position_ids,
 | 
			
		||||
                "attention_mask": bigdl_attention_mask,
 | 
			
		||||
                "position_ids": position_ids,
 | 
			
		||||
                "attention_mask": attention_mask,
 | 
			
		||||
                "past_key_values": bigdl_kv_cache,
 | 
			
		||||
                "use_cache": True,
 | 
			
		||||
                # "return_dict": True,
 | 
			
		||||
            }
 | 
			
		||||
        else:
 | 
			
		||||
            # Prefill stage
 | 
			
		||||
            attention_mask = torch.tensor(bigdl_attention_mask, device=self.device)
 | 
			
		||||
            if enable_vllm_se_batching:
 | 
			
		||||
                position_ids = attention_mask.long().cumsum(-1) - 1
 | 
			
		||||
                position_ids.masked_fill_(attention_mask == 0, 1)
 | 
			
		||||
            else:
 | 
			
		||||
                position_ids = None
 | 
			
		||||
            kwargs = {
 | 
			
		||||
                "input_ids": bigdl_input_ids,
 | 
			
		||||
                "attention_mask": torch.tensor(bigdl_attention_mask, device=self.device),
 | 
			
		||||
                # "position_ids": bigdl_position_ids,
 | 
			
		||||
                "attention_mask": attention_mask,
 | 
			
		||||
                "position_ids": position_ids,
 | 
			
		||||
                "past_key_values": None,
 | 
			
		||||
                "use_cache": True,
 | 
			
		||||
                # "return_dict": True,
 | 
			
		||||
            }
 | 
			
		||||
            # Prefill may need additional space, which forces us to delete the last_kv_cache
 | 
			
		||||
            if self.last_kv_cache:
 | 
			
		||||
                del self.last_kv_cache
 | 
			
		||||
                self.last_kv_cache = None
 | 
			
		||||
        # pdb.set_trace()
 | 
			
		||||
 | 
			
		||||
        if self.device.type == 'xpu':
 | 
			
		||||
| 
						 | 
				
			
			@ -207,8 +246,12 @@ class BigDLLlamaForCausalLM(BigDLModelForCausalLM):
 | 
			
		|||
        # tmp = torch.xpu.memory_stats()
 | 
			
		||||
        # logger.info(f"before: {tmp['allocated_bytes.all.current']}")
 | 
			
		||||
 | 
			
		||||
        self.update_kv_cache(cur_seq_ids,
 | 
			
		||||
                             kv_cache, num_layers, decoder_kv_size)
 | 
			
		||||
        if enable_vllm_se_batching:
 | 
			
		||||
            self.update_kv_cache_selective_batching(
 | 
			
		||||
                cur_seq_ids, kv_cache, num_layers, decoder_kv_size)
 | 
			
		||||
            self.last_kv_cache = None
 | 
			
		||||
        else:
 | 
			
		||||
            self.update_kv_cache(cur_seq_ids, kv_cache, num_layers, decoder_kv_size)
 | 
			
		||||
 | 
			
		||||
        # tmp = torch.xpu.memory_stats()
 | 
			
		||||
        # logger.info(f"after: {tmp['allocated_bytes.all.current']}")
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -137,6 +137,34 @@ class BigDLModelForCausalLM(nn.Module):
 | 
			
		|||
 | 
			
		||||
        return bigdl_kv_cache
 | 
			
		||||
 | 
			
		||||
    def get_construct_kv_cache_func(self, enable_selective_batching):
 | 
			
		||||
        if enable_selective_batching:
 | 
			
		||||
            return self.prepare_kv_cache_selective_batching
 | 
			
		||||
        else:
 | 
			
		||||
            return self.prepare_kv_cache
 | 
			
		||||
 | 
			
		||||
    # This is an implementation for models that KV Cache shape in (batch_size, num_heads,
 | 
			
		||||
    # sequence_length, embed_size_per_head).
 | 
			
		||||
    def prepare_kv_cache_selective_batching(
 | 
			
		||||
        self,
 | 
			
		||||
        cur_seq_ids: List[int],
 | 
			
		||||
        seq_group_meta_data_lists: List[SequenceGroupMetadata],
 | 
			
		||||
        kv_cache: Dict,
 | 
			
		||||
        num_layers: int,
 | 
			
		||||
        kv_cache_size_1: int,
 | 
			
		||||
    ):
 | 
			
		||||
        # Return bigdl_kv_cache in the format of Tuple(List[Tuple(torch.Tensor)])
 | 
			
		||||
        bigdl_kv_cache = []
 | 
			
		||||
        for i in range(num_layers):
 | 
			
		||||
            # Construct a list of tuple(tensor)
 | 
			
		||||
            temp_cache = []
 | 
			
		||||
            for seq_id in cur_seq_ids:
 | 
			
		||||
                key = kv_cache[i][0][seq_id]
 | 
			
		||||
                value = kv_cache[i][1][seq_id]
 | 
			
		||||
                temp_cache.append((key, value))
 | 
			
		||||
            bigdl_kv_cache.append(temp_cache)
 | 
			
		||||
        return bigdl_kv_cache
 | 
			
		||||
 | 
			
		||||
    # This is an implementation for models that KV Cache shape in (batch_size, num_heads,
 | 
			
		||||
    # sequence_length, embed_size_per_head).
 | 
			
		||||
    def update_kv_cache(
 | 
			
		||||
| 
						 | 
				
			
			@ -153,6 +181,18 @@ class BigDLModelForCausalLM(nn.Module):
 | 
			
		|||
                    kv_cache[i][j][seq_id] = self.last_kv_cache[i][j][batch_dim]
 | 
			
		||||
                    batch_dim = batch_dim + 1
 | 
			
		||||
 | 
			
		||||
    def update_kv_cache_selective_batching(
 | 
			
		||||
        self,
 | 
			
		||||
        cur_seq_ids: List[int],
 | 
			
		||||
        kv_cache,
 | 
			
		||||
        layer: int,
 | 
			
		||||
        kv_cache_size_1: int,
 | 
			
		||||
    ) -> None:
 | 
			
		||||
        for i in range(layer):
 | 
			
		||||
            for j in range(len(cur_seq_ids)):
 | 
			
		||||
                kv_cache[i][0][cur_seq_ids[j]] = self.last_kv_cache[i][j][0]
 | 
			
		||||
                kv_cache[i][1][cur_seq_ids[j]] = self.last_kv_cache[i][j][1]
 | 
			
		||||
 | 
			
		||||
    def forward(
 | 
			
		||||
        self,
 | 
			
		||||
        seq_group_meta_data_lists: List[SequenceGroupMetadata],
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in a new issue