Implement selective batching for vLLM (#9659)
* add control to load hf model
* finish initial version of selective_batching
* temp
* finish
* Remove print statement
* fix error
* Apply yang's optimization
* a version that works
* We need to check kv_cache passed in, this could be an error. TODO: add fast decoding path
* format
* temp solution: not batching prefill requests
* a version that works for prefill batching
* format
* a solid version: works normally
* a temp version
* Solid version: remove redundant functions
* fix format
* format
* solid: add option to enable selective_batching
* remove logic for using transformer models
* format
* format
* solid: enable argument VLLM_ENABLE_SELECTIVE_BATCHING
* format
* finish
* format
This commit is contained in:
parent 2f36769208
commit fdf93c9267

4 changed files with 467 additions and 36 deletions
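The new selective-batching code paths are gated by the VLLM_ENABLE_SELECTIVE_BATCHING environment variable, which is read both in the transformers conversion code and in the vLLM Llama wrapper (see the hunks below). A minimal sketch of how a user would switch the feature on, assuming the flag only needs to be set before the bigdl.llm modules read it:

    import os
    # Assumption: the diff checks the value with `.lower() == "true"`, so any casing of
    # "true" works; it must be set before bigdl.llm reads it at import/conversion time.
    os.environ["VLLM_ENABLE_SELECTIVE_BATCHING"] = "true"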
Changed file 1 of 4:

@@ -46,6 +46,7 @@ from bigdl.llm.ggml.quantize import ggml_tensor_qtype
 from .utils import logger
 from typing import Union
 import numpy as np
+import os
 from bigdl.llm.utils.common import invalidInputError


@@ -386,6 +387,8 @@ def convert_forward(m, target_m, new_forward):
 def _optimize_post(model, lightweight_bmm=False):
     from packaging import version
     from bigdl.llm.transformers.models.llama import llama_attention_forward_4_31
+    from bigdl.llm.transformers.models.llama import llama_attention_selective_batching_forward_4_31
+    from bigdl.llm.transformers.models.llama import llama_model_selective_batching_forward_4_31
     from bigdl.llm.transformers.models.llama import llama_rms_norm_forward
     from bigdl.llm.transformers.models.llama import llama_mlp_forward
     from transformers.modeling_utils import PreTrainedModel
@@ -396,6 +399,10 @@ def _optimize_post(model, lightweight_bmm=False):
                     "supported for further optimizations")
         return model

+    vllm_selective_batching = os.getenv("VLLM_ENABLE_SELECTIVE_BATCHING")
+    enable_vllm_se_batching = vllm_selective_batching is not None
+    enable_vllm_se_batching = enable_vllm_se_batching and vllm_selective_batching.lower() == "true"
+
     trans_version = transformers.__version__
     if version.parse(trans_version) >= version.parse("4.31.0"):
         convert_forward(
@@ -409,6 +416,17 @@ def _optimize_post(model, lightweight_bmm=False):
         convert_forward(model,
                         transformers.models.llama.modeling_llama.LlamaMLP,
                         llama_mlp_forward)
+        if enable_vllm_se_batching:
+            convert_forward(
+                model,
+                transformers.models.llama.modeling_llama.LlamaModel,
+                llama_model_selective_batching_forward_4_31,
+            )
+            convert_forward(
+                model,
+                transformers.models.llama.modeling_llama.LlamaAttention,
+                llama_attention_selective_batching_forward_4_31,
+            )
     else:
         # todo implement 4.28.0 ~ 4.30.2
         pass
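The hunks above rely on convert_forward(m, target_m, new_forward) to swap the selective-batching implementations into the loaded model. The commit does not show convert_forward itself; the following is only a rough sketch of the kind of recursive method rebinding it presumably performs (an assumption for illustration, not the actual BigDL implementation):

    import torch.nn as nn

    def convert_forward_sketch(model: nn.Module, target_cls, new_forward):
        # Walk all submodules and rebind `forward` on every instance of target_cls.
        for module in model.modules():
            if isinstance(module, target_cls):
                module.forward = new_forward.__get__(module, module.__class__)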
Changed file 2 of 4:

@@ -34,15 +34,17 @@
 import torch
 import importlib
 import torch.nn as nn
-from typing import Optional, Tuple
+from typing import Optional, Tuple, Union, List
 import math
 import torch.nn.functional as F
 from bigdl.llm.utils.common import invalidInputError
 from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
 from bigdl.llm.transformers.models.utils import is_enough_kv_cache_room_4_31, apply_rotary_pos_emb
 from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
+from transformers.modeling_outputs import BaseModelOutputWithPast
 from bigdl.llm.transformers.low_bit_linear import SYM_INT4
 from bigdl.llm.ggml.quantize import ggml_tensor_qtype
+from bigdl.llm.utils.common import invalidInputError


 def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
@@ -191,7 +193,6 @@ def llama_attention_forward_4_31(
             value_states = [F.linear(hidden_states, value_slices[i])
                             for i in range(self.config.pretraining_tp)]
             value_states = torch.cat(value_states, dim=-1)
-
         else:
             query_states = self.q_proj(hidden_states)
             key_states = self.k_proj(hidden_states)
@@ -305,6 +306,167 @@ def llama_attention_forward_4_31(
     return attn_output.to(original_dtype), attn_weights, past_key_value


+def llama_attention_selective_batching_forward_4_31(
+    self,
+    hidden_states: torch.Tensor,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_value: Optional[Tuple[torch.Tensor]] = None,
+    output_attentions: bool = False,
+    use_cache: bool = False,
+    padding_mask: Optional[torch.LongTensor] = None,
+    **kwargs,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+    bsz, q_len, _ = hidden_states.size()
+    device = hidden_states.device
+    # for flash attention
+    original_dtype = hidden_states.dtype
+    # TODO: consider this later - flash attention
+    # if not self.training and not hidden_states.requires_grad:
+    #     fsdp_flag = check_flash_attention_available(hidden_states)
+    # else:
+    #     fsdp_flag = False
+    # if fsdp_flag and q_len > 1:
+    #     attention_dtype = torch.float16  # use fp16 for flash attention
+    # else:
+    #     attention_dtype = original_dtype
+
+    attention_dtype = original_dtype
+
+    # TODO: decoding fast path
+    # use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
+    # enough_kv_room = is_enough_kv_cache_room(past_key_value[0])
+    # is_q4_0 = self.q_proj.qtype == SYM_INT4
+    # no_tp = not self.config.pretraining_tp > 1
+    # decoding_fast_path = (no_tp and is_q4_0 and use_fuse_rope and
+    #                       enough_kv_room and bsz * q_len == 1)
+
+    # single batch decoding fast path
+    # forward_qkv takes will perform QKV projection, rotary position embedding
+    # and save the key/value states to cache, then return query states and the
+    # extended key/value cache
+    # if decoding_fast_path:
+    #     hidden_states = hidden_states.view(1, -1)
+    #     kv_seq_len = past_key_value[0].shape[-2]
+    #     cache_k = past_key_value[0]
+    #     cache_v = past_key_value[1]
+    #     import linear_q4_0
+    #     query_states, key_states, value_states = linear_q4_0.forward_qkv(hidden_states,
+    #                                                                      self.q_proj.weight,
+    #                                                                      self.k_proj.weight,
+    #                                                                      self.v_proj.weight,
+    #                                                                      position_ids,
+    #                                                                      cache_k, cache_v,
+    #                                                                      self.q_proj.weight.qtype,
+    #                                                                      kv_seq_len,
+    #                                                                      self.head_dim)
+    #     kv_seq_len += 1
+
+    # else:
+    if self.config.pretraining_tp > 1:
+        invalidInputError(False, f"vLLM: config.pretraining_tp > 1 not supported yet")
+    else:
+        query_states = self.q_proj(hidden_states)
+        key_states = self.k_proj(hidden_states)
+        value_states = self.v_proj(hidden_states)
+
+    query_states = query_states.view(bsz, q_len,
+                                     self.num_heads, self.head_dim).transpose(1, 2)
+    key_states = key_states.view(bsz, q_len,
+                                 self.num_key_value_heads, self.head_dim).transpose(1, 2)
+    value_states = value_states.view(bsz, q_len,
+                                     self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+    kv_seq_len = key_states.shape[-2]
+    if past_key_value is not None:
+        kv_seq_len += max(kv_pair[0].shape[-2] for kv_pair in past_key_value)
+
+    # TODO: fuse_rope
+    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+    query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
+                                                    cos, sin, position_ids, "llama")
+
+    updated_past_key_values = []
+    if past_key_value is not None:
+        batched_attention_output = []
+        # print(f"type of attention_mask is {type(attention_mask)}")
+        for batch in range(bsz):
+            past_k, past_v = past_key_value[batch]
+            current_kv_len = past_k.shape[-2] + 1
+
+            current_key_states = torch.cat([past_k,
+                                            key_states[batch: batch + 1, :, :, :]], dim=2)
+            current_value_states = torch.cat([past_v,
+                                              value_states[batch: batch + 1, :, :, :]], dim=2)
+
+            updated_past_key_values.append((current_key_states, current_value_states))
+
+            current_key_states = repeat_kv(current_key_states, self.num_key_value_groups)
+            current_value_states = repeat_kv(current_value_states, self.num_key_value_groups)
+
+            current_query_states = query_states[batch: batch + 1, :, :, :]
+            attn_output, attn_weights = native_sdp(current_query_states,
+                                                   current_key_states,
+                                                   current_value_states,
+                                                   attention_mask[batch],
+                                                   1,
+                                                   1,
+                                                   current_kv_len,
+                                                   self.head_dim,
+                                                   self.num_heads)
+            if attn_output.size() != (1, self.num_heads, 1, self.head_dim):
+                invalidInputError(False,
+                                  f"`attn_output` should be of size "
+                                  f"{(1, self.num_heads, 1, self.head_dim)}, but is"
+                                  f" {attn_output.size()}")
+            batched_attention_output.append(attn_output)
+        # For loop ends
+        # TODO: handle attention_weights later
+        attn_output = torch.concat(batched_attention_output, dim=0)
+        batched_attention_output.clear()
+        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
+            invalidInputError(False,
+                              f"`attn_output` should be of size "
+                              f"{(bsz, self.num_heads, q_len, self.head_dim)}, but is"
+                              f" {attn_output.size()}")
+        attn_output = attn_output.transpose(1, 2).contiguous()
+        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+        attn_output = self.o_proj(attn_output)
+        return attn_output, None, updated_past_key_values
+
+    # TODO: Assume always use_cache
+    # print(f"prefill with batch size {bsz}")
+    for batch in range(bsz):
+        updated_past_key_values.append((key_states[batch: batch + 1, :, :, :],
+                                        value_states[batch: batch+1, :, :, :]))
+
+    # repeat k/v heads if n_kv_heads < n_heads
+    key_states = repeat_kv(key_states, self.num_key_value_groups).to(device,
+                                                                     dtype=attention_dtype)
+    value_states = repeat_kv(value_states, self.num_key_value_groups).to(device,
+                                                                         dtype=attention_dtype)
+    attn_output, attn_weights = native_sdp(query_states,
+                                           key_states,
+                                           value_states,
+                                           attention_mask,
+                                           bsz,
+                                           q_len,
+                                           kv_seq_len,
+                                           self.head_dim,
+                                           self.num_heads)
+
+    if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
+        invalidInputError(False,
+                          f"`attn_output` should be of size "
+                          f"{(bsz, self.num_heads, q_len, self.head_dim)}, but is"
+                          f" {attn_output.size()}")
+    attn_output = attn_output.transpose(1, 2).contiguous()
+    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+
+    attn_output = self.o_proj(attn_output)
+    return attn_output.to(original_dtype), attn_weights, updated_past_key_values
+
+
 def check_flash_attention_available(query):
     # check whether ipex flash attention can be used
     if query.device.type != "xpu":
@@ -371,3 +533,171 @@ def native_sdp(query, key, value, attention_mask,
                                          dtype=torch.float32).to(value.dtype)
     attn_output = torch.matmul(attn_weights, value)
     return attn_output, attn_weights
+
+
+def llama_model_selective_batching_forward_4_31(
+    self,
+    input_ids: torch.LongTensor = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_values: Optional[List[torch.FloatTensor]] = None,
+    inputs_embeds: Optional[torch.FloatTensor] = None,
+    use_cache: Optional[bool] = None,
+    output_attentions: Optional[bool] = None,
+    output_hidden_states: Optional[bool] = None,
+    return_dict: Optional[bool] = None,
+) -> Union[Tuple, BaseModelOutputWithPast]:
+    if output_attentions is not None:
+        output_attentions = output_attentions
+    else:
+        output_attentions = self.config.output_attentions
+    output_hidden_states = (
+        output_hidden_states if output_hidden_states is not None
+        else self.config.output_hidden_states
+    )
+    use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+    # retrieve input_ids and inputs_embeds
+    if input_ids is not None and inputs_embeds is not None:
+        invalidInputError(False,
+                          "You cannot specify both decoder_input_ids"
+                          " and decoder_inputs_embeds at the same time")
+    elif input_ids is not None:
+        batch_size, seq_length = input_ids.shape
+    elif inputs_embeds is not None:
+        batch_size, seq_length, _ = inputs_embeds.shape
+    else:
+        invalidInputError(False,
+                          "You have to specify either "
+                          "decoder_input_ids or decoder_inputs_embeds")
+
+    # seq_length_with_past = seq_length
+    past_key_values_length = 0
+
+    # The original position_ids in the format of [1, 1]
+    # However, this only applies when kv_len is the same for all the sequences
+    # We should set it to format of [batch, position_id]
+    # TODO: validate correctness
+    device = input_ids.device if input_ids is not None else inputs_embeds.device
+    if position_ids is None:
+        invalidInputError("vLLM: position_ids should never be None")
+    else:
+        # print(f"Original position_ids is {position_ids}")
+        position_ids = position_ids.view(-1, seq_length)
+        # print(f"after position_ids is {position_ids}")
+    # if past_key_values is None:
+    #     # For prefill
+    #     position_ids = torch.arange(0, seq_length, dtype=torch.long, device=device)
+    #     position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
+    # else:
+    #     past_key_values_length = []
+    #     for sequence_kv in past_key_values[0]:
+    #         key = sequence_kv[0]
+    #         past_key_values_length.append(key.shape[-2])
+    #     position_ids = torch.tensor(past_key_values_length, dtype=torch.long, device=device)
+    #     position_ids = position_ids.unsqueeze(0).view(-1, 1)
+
+    if past_key_values is not None:
+        # past_key_values in the format of num_layers x num_seqs x 2
+        # TODO: this may be incorrect
+        past_key_values_length = past_key_values[0][0][0].shape[2]
+        # seq_length_with_past = seq_length_with_past + past_key_values_length
+
+    # if position_ids is None:
+    #     device = input_ids.device if input_ids is not None else inputs_embeds.device
+    #     # [start, end)
+    #     position_ids = torch.arange(
+    #         past_key_values_length, seq_length +
+    #         past_key_values_length, dtype=torch.long, device=device
+    #     )
+    #     position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
+    # else:
+    #     position_ids = position_ids.view(-1, seq_length).long()
+
+    if inputs_embeds is None:
+        inputs_embeds = self.embed_tokens(input_ids)
+    # embed positions
+    if attention_mask is None:
+        invalidInputError(False, "attention_mask should never be None")
+    # print(f"attention_mask before expanding: {attention_mask}")
+    if past_key_values is None:
+        attention_mask = self._prepare_decoder_attention_mask(
+            attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
+        )
+    else:
+        i = 0
+        for attn_mask in attention_mask:
+            past_key_value_length = past_key_values[0][i][0].shape[2]
+            new_mask = self._prepare_decoder_attention_mask(
+                attn_mask, (1, seq_length), inputs_embeds, past_key_value_length
+            )
+            attention_mask[i] = new_mask
+            i += 1
+
+    hidden_states = inputs_embeds
+
+    if self.gradient_checkpointing and self.training:
+        invalidInputError(False, "gradient_checkpointing is not supported")
+
+    # decoder layers
+    all_hidden_states = () if output_hidden_states else None
+    all_self_attns = () if output_attentions else None
+    next_decoder_cache = () if use_cache else None
+
+    for idx, decoder_layer in enumerate(self.layers):
+        if output_hidden_states:
+            all_hidden_states += (hidden_states,)
+
+        past_key_value = past_key_values[idx] if past_key_values is not None else None
+
+        if self.gradient_checkpointing and self.training:
+
+            def create_custom_forward(module):
+                def custom_forward(*inputs):
+                    # None for past_key_value
+                    return module(*inputs, output_attentions, None)
+
+                return custom_forward
+
+            layer_outputs = torch.utils.checkpoint.checkpoint(
+                create_custom_forward(decoder_layer),
+                hidden_states,
+                attention_mask,
+                position_ids,
+                None,
+            )
+        else:
+            layer_outputs = decoder_layer(
+                hidden_states,
+                attention_mask=attention_mask,
+                position_ids=position_ids,
+                past_key_value=past_key_value,
+                output_attentions=output_attentions,
+                use_cache=use_cache,
+            )

+        hidden_states = layer_outputs[0]
+
+        if use_cache:
+            next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
+
+        if output_attentions:
+            all_self_attns += (layer_outputs[1],)
+
+    hidden_states = self.norm(hidden_states)
+
+    # add hidden states from the last decoder layer
+    if output_hidden_states:
+        all_hidden_states += (hidden_states,)
+
+    next_cache = next_decoder_cache if use_cache else None
+    if not return_dict:
+        return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)  # noqa
+    return BaseModelOutputWithPast(
+        last_hidden_state=hidden_states,
+        past_key_values=next_cache,
+        hidden_states=all_hidden_states,
+        attentions=all_self_attns,
+    )
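In the selective-batching attention added above, past_key_value is no longer one padded (key, value) pair for the whole batch but a list with one (key, value) tuple per sequence, so cached lengths can differ across sequences. A toy illustration of that layout and of the decode-stage kv_seq_len computation (head counts and lengths here are assumptions chosen for illustration):

    import torch

    num_heads, head_dim = 32, 128
    past_key_value = [
        (torch.zeros(1, num_heads, 7, head_dim), torch.zeros(1, num_heads, 7, head_dim)),    # sequence 0: 7 cached tokens
        (torch.zeros(1, num_heads, 12, head_dim), torch.zeros(1, num_heads, 12, head_dim)),  # sequence 1: 12 cached tokens
    ]
    # Mirrors the computation above: one new token plus the longest per-sequence cache.
    kv_seq_len = 1 + max(k.shape[-2] for k, _ in past_key_value)  # -> 13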
Changed file 3 of 4:

@@ -27,6 +27,7 @@ from bigdl.llm.vllm.logger import init_logger
 import math
 import time
 from bigdl.llm.vllm.model_executor.input_metadata import InputMetadata
+import os
 from transformers.generation.logits_process import (
     LogitsProcessorList,
     RepetitionPenaltyLogitsProcessor,
@@ -50,6 +51,10 @@ def _get_attention_mask_for_prompts(
     ]
     return attention_mask

+vllm_selective_batching = os.getenv("VLLM_ENABLE_SELECTIVE_BATCHING")
+enable_vllm_se_batching = vllm_selective_batching is not None
+enable_vllm_se_batching = enable_vllm_se_batching and vllm_selective_batching.lower() == "true"
+

 class BigDLLlamaForCausalLM(BigDLModelForCausalLM):

@@ -61,12 +66,9 @@ class BigDLLlamaForCausalLM(BigDLModelForCausalLM):
     ):
         super().__init__(config, device, max_model_len)
         self.config = config
-        # TODO(gc): later change this to a switch?
-        if True:
-            from bigdl.llm.transformers import AutoModelForCausalLM
-            from bigdl.llm import optimize_model
-
-        # low_bit = 'sym_int4'
+        # Always enable bigdl-llm model
+        from bigdl.llm.transformers import AutoModelForCausalLM
+        from bigdl.llm import optimize_model
         if device == 'cpu':
             model = AutoModelForCausalLM.from_pretrained(
                 config._name_or_path,
@@ -81,7 +83,7 @@ class BigDLLlamaForCausalLM(BigDLModelForCausalLM):
                 import intel_extension_for_pytorch as ipex
             except ImportError:
                 print("Intel Extension for PyTorch is not installed, \
-                       but is required for xpu inference.")
+                    but is required for xpu inference.")

             low_bit = 'sym_int4'
             model = AutoModelForCausalLM.from_pretrained(
@@ -93,17 +95,19 @@ class BigDLLlamaForCausalLM(BigDLModelForCausalLM):
             self.model = model.to('xpu')
             self.sampler = BigDLSampler(config.vocab_size, device).to('xpu')

-        if device is None:
-            self.device = torch.device(
-                "cuda" if torch.cuda.is_available() else "cpu")
-        else:
-            self.device = torch.device(device)
+        self.device = torch.device(device)
         self.dtype = self.model.dtype
         self.last_seq_ids = []
-        self.tmp_kv_cache = None
+        self.last_kv_cache = None
         self.pad_token_id = config.pad_token_id
         self.max_seq_limit = max_model_len

+    # GC: Note for selective batching
+    # KV_CACHE in the format of num_layers x 2 x (seq_id -> torch.Tensor)
+    # past_key_values in the format of num_layers x len(seq_id) x (2 x torch.Tensor)
+    # If we set num_layers to 9, have 10 sequences in total.
+    # then, for the kv_cache, we get 9 x 2 x 10 = 180 tensors
+    # for past_key_values, we get 9 x 10 x 2 = 180 tensors
     def forward(
         self,
         seq_group_meta_data_lists: List[SequenceGroupMetadata],
@@ -116,7 +120,7 @@ class BigDLLlamaForCausalLM(BigDLModelForCausalLM):
         decoder_kv_size = 2

         bigdl_input_ids = []
-        bigdl_position_ids = []
+        # bigdl_position_ids = []
         bigdl_attention_mask = []

         cur_seq_ids = []
@@ -144,8 +148,12 @@ class BigDLLlamaForCausalLM(BigDLModelForCausalLM):
         # 1. Assemble bigdl_input_ids end

         if is_decoding_stage:
-            bigdl_kv_cache = self.prepare_kv_cache(cur_seq_ids, seq_group_meta_data_lists,
-                                                   kv_cache, num_layers, decoder_kv_size)
+            construct_kv_cache_func = self.get_construct_kv_cache_func(enable_vllm_se_batching)
+            bigdl_kv_cache = construct_kv_cache_func(cur_seq_ids,
+                                                     seq_group_meta_data_lists,
+                                                     kv_cache,
+                                                     num_layers,
+                                                     2)
         else:
             bigdl_attention_mask = _get_attention_mask_for_prompts(bigdl_input_ids, max_prompt_len)
             bigdl_input_ids = [
@@ -153,41 +161,72 @@ class BigDLLlamaForCausalLM(BigDLModelForCausalLM):
                 for input_ids in bigdl_input_ids
             ]

+        decoding_attention_mask_list = []
+        decoding_position_ids = []
+        # num_layers x len(seq_id) x (2 x torch.Tensor)
         if is_decoding_stage:
-            cur_seq_len = bigdl_kv_cache[0][0].size(2)
-            for seq_group_meta_data in seq_group_meta_data_lists:
-                seq_ids = list(seq_group_meta_data.seq_data.keys())
-                seq_id = seq_ids[0]
-                seq_data = seq_group_meta_data.seq_data[seq_id]
-                cur_pos = seq_data.get_len()
-                # bigdl_position_ids.append([cur_pos - 1])
-                cur_attention_mask = [0] * (cur_seq_len - cur_pos + 1) + [1] * (cur_pos)
-                bigdl_attention_mask.append(cur_attention_mask)
+            if enable_vllm_se_batching:
+                batch = 0
+                for seq_group_meta_data in seq_group_meta_data_lists:
+                    # Get current seq_len in kv_cache
+                    current_seq_len = bigdl_kv_cache[0][batch][0].size(2)
+                    batch += 1
+                    seq_ids = list(seq_group_meta_data.seq_data.keys())
+                    seq_data = seq_group_meta_data.seq_data[seq_ids[0]]
+                    cur_pos = seq_data.get_len()
+                    decoding_position_ids.append(cur_pos - 1)
+                    # Total length: current_seq_len + 1
+                    cur_attention_mask = [0] * (current_seq_len - cur_pos + 1) + [1] * (cur_pos)
+                    decoding_attention_mask_list.append(cur_attention_mask)
+            else:
+                cur_seq_len = bigdl_kv_cache[0][0].size(2)
+                for seq_group_meta_data in seq_group_meta_data_lists:
+                    seq_ids = list(seq_group_meta_data.seq_data.keys())
+                    seq_id = seq_ids[0]
+                    seq_data = seq_group_meta_data.seq_data[seq_id]
+                    cur_pos = seq_data.get_len()
+                    # bigdl_position_ids.append([cur_pos - 1])
+                    # decoding_position_ids.append(cur_pos - 1)
+                    cur_attention_mask = [0] * (cur_seq_len - cur_pos + 1) + [1] * (cur_pos)
+                    decoding_attention_mask_list.append(cur_attention_mask)

         bigdl_input_ids = torch.tensor(bigdl_input_ids, device=self.device)

         if is_decoding_stage:
-            # bigdl_position_ids = torch.tensor(bigdl_position_ids, device=self.device)
-            bigdl_attention_mask = torch.tensor(bigdl_attention_mask, device=self.device)
+            if enable_vllm_se_batching:
+                attention_mask = [torch.tensor(x, device=self.device).unsqueeze(0)
+                                  for x in decoding_attention_mask_list]
+                position_ids = torch.tensor(decoding_position_ids).long().unsqueeze(-1)
+            else:
+                attention_mask = torch.tensor(decoding_attention_mask_list, device=self.device)
+                position_ids = None
             kwargs = {
                 "input_ids": bigdl_input_ids,
-                # "position_ids": bigdl_position_ids,
-                "attention_mask": bigdl_attention_mask,
+                "position_ids": position_ids,
+                "attention_mask": attention_mask,
                 "past_key_values": bigdl_kv_cache,
                 "use_cache": True,
                 # "return_dict": True,
             }
         else:
+            # Prefill stage
+            attention_mask = torch.tensor(bigdl_attention_mask, device=self.device)
+            if enable_vllm_se_batching:
+                position_ids = attention_mask.long().cumsum(-1) - 1
+                position_ids.masked_fill_(attention_mask == 0, 1)
+            else:
+                position_ids = None
             kwargs = {
                 "input_ids": bigdl_input_ids,
-                "attention_mask": torch.tensor(bigdl_attention_mask, device=self.device),
-                # "position_ids": bigdl_position_ids,
+                "attention_mask": attention_mask,
+                "position_ids": position_ids,
                 "past_key_values": None,
                 "use_cache": True,
                 # "return_dict": True,
             }
+            # Prefill may need additional space, which forces us to delete the last_kv_cache
             if self.last_kv_cache:
-                del self.last_kv_cache
+                self.last_kv_cache = None
         # pdb.set_trace()

         if self.device.type == 'xpu':
@@ -207,8 +246,12 @@ class BigDLLlamaForCausalLM(BigDLModelForCausalLM):
         # tmp = torch.xpu.memory_stats()
         # logger.info(f"before: {tmp['allocated_bytes.all.current']}")

-        self.update_kv_cache(cur_seq_ids,
-                             kv_cache, num_layers, decoder_kv_size)
+        if enable_vllm_se_batching:
+            self.update_kv_cache_selective_batching(
+                cur_seq_ids, kv_cache, num_layers, decoder_kv_size)
+            self.last_kv_cache = None
+        else:
+            self.update_kv_cache(cur_seq_ids, kv_cache, num_layers, decoder_kv_size)

         # tmp = torch.xpu.memory_stats()
         # logger.info(f"after: {tmp['allocated_bytes.all.current']}")
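The "GC: Note for selective batching" comment above distinguishes two layouts: kv_cache keeps per-sequence tensors in per-layer dicts keyed by seq_id, while past_key_values groups them per layer as one (key, value) tuple per sequence. A small sketch of the conversion between them, mirroring prepare_kv_cache_selective_batching in the next file (tensor contents replaced by string placeholders for readability):

    num_layers = 2
    cur_seq_ids = [3, 7]
    # kv_cache: num_layers x 2 x {seq_id -> tensor}
    kv_cache = [[{3: "k3", 7: "k7"}, {3: "v3", 7: "v7"}] for _ in range(num_layers)]
    # past_key_values: num_layers x len(seq_ids) x (key, value)
    past_key_values = [[(kv_cache[i][0][sid], kv_cache[i][1][sid]) for sid in cur_seq_ids]
                       for i in range(num_layers)]
    assert past_key_values[0][1] == ("k7", "v7")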
Changed file 4 of 4:

@@ -137,6 +137,34 @@ class BigDLModelForCausalLM(nn.Module):

         return bigdl_kv_cache

+    def get_construct_kv_cache_func(self, enable_selective_batching):
+        if enable_selective_batching:
+            return self.prepare_kv_cache_selective_batching
+        else:
+            return self.prepare_kv_cache
+
+    # This is an implementation for models that KV Cache shape in (batch_size, num_heads,
+    # sequence_length, embed_size_per_head).
+    def prepare_kv_cache_selective_batching(
+        self,
+        cur_seq_ids: List[int],
+        seq_group_meta_data_lists: List[SequenceGroupMetadata],
+        kv_cache: Dict,
+        num_layers: int,
+        kv_cache_size_1: int,
+    ):
+        # Return bigdl_kv_cache in the format of Tuple(List[Tuple(torch.Tensor)])
+        bigdl_kv_cache = []
+        for i in range(num_layers):
+            # Construct a list of tuple(tensor)
+            temp_cache = []
+            for seq_id in cur_seq_ids:
+                key = kv_cache[i][0][seq_id]
+                value = kv_cache[i][1][seq_id]
+                temp_cache.append((key, value))
+            bigdl_kv_cache.append(temp_cache)
+        return bigdl_kv_cache
+
     # This is an implementation for models that KV Cache shape in (batch_size, num_heads,
     # sequence_length, embed_size_per_head).
     def update_kv_cache(
@@ -153,6 +181,18 @@ class BigDLModelForCausalLM(nn.Module):
                     kv_cache[i][j][seq_id] = self.last_kv_cache[i][j][batch_dim]
                     batch_dim = batch_dim + 1

+    def update_kv_cache_selective_batching(
+        self,
+        cur_seq_ids: List[int],
+        kv_cache,
+        layer: int,
+        kv_cache_size_1: int,
+    ) -> None:
+        for i in range(layer):
+            for j in range(len(cur_seq_ids)):
+                kv_cache[i][0][cur_seq_ids[j]] = self.last_kv_cache[i][j][0]
+                kv_cache[i][1][cur_seq_ids[j]] = self.last_kv_cache[i][j][1]
+
     def forward(
         self,
         seq_group_meta_data_lists: List[SequenceGroupMetadata],
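After a decoding step, update_kv_cache_selective_batching above scatters the model output (held in self.last_kv_cache, laid out as num_layers x num_seqs x (key, value)) back into the per-seq_id dicts of kv_cache. A toy walk-through with placeholder values, under the same layout assumptions as the sketch before file 4:

    num_layers = 2
    cur_seq_ids = [3, 7]
    last_kv_cache = [[("k3", "v3"), ("k7", "v7")] for _ in range(num_layers)]
    kv_cache = [[{}, {}] for _ in range(num_layers)]
    for i in range(num_layers):
        for j in range(len(cur_seq_ids)):
            kv_cache[i][0][cur_seq_ids[j]] = last_kv_cache[i][j][0]
            kv_cache[i][1][cur_seq_ids[j]] = last_kv_cache[i][j][1]
    assert kv_cache[0][0] == {3: "k3", 7: "k7"}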