diff --git a/python/llm/src/bigdl/llm/transformers/models/llama.py b/python/llm/src/bigdl/llm/transformers/models/llama.py
index d3630362..55bb0f26 100644
--- a/python/llm/src/bigdl/llm/transformers/models/llama.py
+++ b/python/llm/src/bigdl/llm/transformers/models/llama.py
@@ -317,6 +317,8 @@ def llama_attention_selective_batching_forward_4_31(
     padding_mask: Optional[torch.LongTensor] = None,
     **kwargs,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+    # Minimize this value to reduce memory allocation.
+    KV_CACHE_ALLOC_BLOCK_LENGTH = 64
     bsz, q_len, _ = hidden_states.size()
     device = hidden_states.device
     # for flash attention
@@ -334,108 +336,141 @@ def llama_attention_selective_batching_forward_4_31(
         attention_dtype = original_dtype

     # TODO: decoding fast path
-    # use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
-    # enough_kv_room = is_enough_kv_cache_room(past_key_value[0])
-    # is_q4_0 = self.q_proj.qtype == SYM_INT4
-    # no_tp = not self.config.pretraining_tp > 1
-    # decoding_fast_path = (no_tp and is_q4_0 and use_fuse_rope and
-    #                       enough_kv_room and bsz * q_len == 1)
+    use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
+    enough_kv_room = past_key_value is not None and is_enough_kv_cache_room_4_31(past_key_value[0])
+    is_q4_0 = self.q_proj.qtype == SYM_INT4
+    no_tp = not self.config.pretraining_tp > 1
+    decoding_fast_path = (no_tp and is_q4_0 and use_fuse_rope and
+                          bsz * q_len == 1)

+    updated_past_key_values = []
     # single batch decoding fast path
     # forward_qkv takes will perform QKV projection, rotary position embedding
     # and save the key/value states to cache, then return query states and the
     # extended key/value cache
-    # if decoding_fast_path:
-    #     hidden_states = hidden_states.view(1, -1)
-    #     kv_seq_len = past_key_value[0].shape[-2]
-    #     cache_k = past_key_value[0]
-    #     cache_v = past_key_value[1]
-    #     import linear_q4_0
-    #     query_states, key_states, value_states = linear_q4_0.forward_qkv(hidden_states,
-    #                                                                      self.q_proj.weight,
-    #                                                                      self.k_proj.weight,
-    #                                                                      self.v_proj.weight,
-    #                                                                      position_ids,
-    #                                                                      cache_k, cache_v,
-    #                                                                      self.q_proj.weight.qtype,
-    #                                                                      kv_seq_len,
-    #                                                                      self.head_dim)
-    #     kv_seq_len += 1
-
-    # else:
-    if self.config.pretraining_tp > 1:
-        invalidInputError(False, f"vLLM: config.pretraining_tp > 1 not supported yet")
+    if decoding_fast_path:
+        past_k = past_key_value[0][0]
+        past_v = past_key_value[0][1]
+        kv_seq_len = past_k.shape[-2]
+        if not enough_kv_room:
+            new_cache_k, new_cache_v = extend_kv_cache(1,
+                                                       self.num_key_value_heads,  # Support GQA
+                                                       self.head_dim,
+                                                       kv_seq_len,
+                                                       kv_seq_len +
+                                                       KV_CACHE_ALLOC_BLOCK_LENGTH,
+                                                       dtype=past_k.dtype,
+                                                       device=device)
+            new_cache_k[:] = past_k
+            new_cache_v[:] = past_v
+            past_k = new_cache_k
+            past_v = new_cache_v
+        hidden_states = hidden_states.view(1, -1)
+        import linear_q4_0
+        query_states, key_states, value_states = linear_q4_0.forward_qkv(hidden_states,
+                                                                         self.q_proj.weight,
+                                                                         self.k_proj.weight,
+                                                                         self.v_proj.weight,
+                                                                         position_ids,
+                                                                         past_k, past_v,
+                                                                         self.q_proj.weight.qtype,
+                                                                         kv_seq_len,
+                                                                         self.head_dim)
+        kv_seq_len += 1
     else:
-        query_states = self.q_proj(hidden_states)
-        key_states = self.k_proj(hidden_states)
-        value_states = self.v_proj(hidden_states)
+        if self.config.pretraining_tp > 1:
+            invalidInputError(False, f"vLLM: config.pretraining_tp > 1 not supported yet")
+        else:
+            query_states = self.q_proj(hidden_states)
+            key_states = self.k_proj(hidden_states)
+            value_states = self.v_proj(hidden_states)

-    query_states = query_states.view(bsz, q_len,
-                                     self.num_heads, self.head_dim).transpose(1, 2)
-    key_states = key_states.view(bsz, q_len,
-                                 self.num_key_value_heads, self.head_dim).transpose(1, 2)
-    value_states = value_states.view(bsz, q_len,
+        query_states = query_states.view(bsz, q_len,
+                                         self.num_heads, self.head_dim).transpose(1, 2)
+        key_states = key_states.view(bsz, q_len,
                                      self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        value_states = value_states.view(bsz, q_len,
+                                         self.num_key_value_heads, self.head_dim).transpose(1, 2)

-    kv_seq_len = key_states.shape[-2]
-    if past_key_value is not None:
-        kv_seq_len += max(kv_pair[0].shape[-2] for kv_pair in past_key_value)
+        kv_seq_len = key_states.shape[-2]
+        if past_key_value is not None:
+            kv_seq_len += max(kv_pair[0].shape[-2] for kv_pair in past_key_value)

-    # TODO: fuse_rope
-    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
-    query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
-                                                    cos, sin, position_ids, "llama")
+        if use_fuse_rope:
+            query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
+                                                                         key_states,
+                                                                         position_ids,
+                                                                         "llama")
+        else:
+            cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+            query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
+                                                            cos, sin, position_ids, "llama")

-    updated_past_key_values = []
-    if past_key_value is not None:
-        batched_attention_output = []
-        # print(f"type of attention_mask is {type(attention_mask)}")
-        for batch in range(bsz):
-            past_k, past_v = past_key_value[batch]
-            current_kv_len = past_k.shape[-2] + 1
+        if past_key_value is not None:
+            batched_attention_output = []
+            # print(f"type of attention_mask is {type(attention_mask)}")
+            for batch in range(bsz):
+                enough_kv_room = is_enough_kv_cache_room_4_31(past_key_value[batch])
+                past_k, past_v = past_key_value[batch]
+                current_kv_len = past_k.shape[-2] + 1
+                if not enough_kv_room:
+                    # allocate new
+                    new_cache_k, new_cache_v = extend_kv_cache(1,
+                                                               self.num_key_value_heads,
+                                                               self.head_dim,
+                                                               past_k.size(2),
+                                                               current_kv_len +
+                                                               KV_CACHE_ALLOC_BLOCK_LENGTH,
+                                                               dtype=past_k.dtype,
+                                                               device=device)
+                    new_cache_k[:] = past_k
+                    new_cache_v[:] = past_v
+                    past_k = new_cache_k
+                    past_v = new_cache_v

-            current_key_states = torch.cat([past_k,
-                                            key_states[batch: batch + 1, :, :, :]], dim=2)
-            current_value_states = torch.cat([past_v,
-                                              value_states[batch: batch + 1, :, :, :]], dim=2)
+                current_key_states = key_states[batch: batch + 1, :, :, :]
+                current_value_states = value_states[batch: batch + 1, :, :, :]
+                current_key_states, current_value_states = append_kv_cache(past_k,
+                                                                           past_v,
+                                                                           current_key_states,
+                                                                           current_value_states)
+                updated_past_key_values.append((current_key_states, current_value_states))

-            updated_past_key_values.append((current_key_states, current_value_states))
+                current_key_states = repeat_kv(current_key_states, self.num_key_value_groups)
+                current_value_states = repeat_kv(current_value_states, self.num_key_value_groups)

-            current_key_states = repeat_kv(current_key_states, self.num_key_value_groups)
-            current_value_states = repeat_kv(current_value_states, self.num_key_value_groups)
-
-            current_query_states = query_states[batch: batch + 1, :, :, :]
-            attn_output, attn_weights = native_sdp(current_query_states,
-                                                   current_key_states,
-                                                   current_value_states,
-                                                   attention_mask[batch],
-                                                   1,
-                                                   1,
-                                                   current_kv_len,
-                                                   self.head_dim,
-                                                   self.num_heads)
-            if attn_output.size() != (1, self.num_heads, 1, self.head_dim):
+                current_query_states = query_states[batch: batch + 1, :, :, :]
+                attn_output, attn_weights = native_sdp(current_query_states,
+                                                       current_key_states,
+                                                       current_value_states,
+                                                       attention_mask[batch],
+                                                       1,
+                                                       1,
+                                                       current_kv_len,
+                                                       self.head_dim,
+                                                       self.num_heads)
+                if attn_output.size() != (1, self.num_heads, 1, self.head_dim):
+                    invalidInputError(False,
+                                      f"`attn_output` should be of size "
+                                      f"{(1, self.num_heads, 1, self.head_dim)}, but is"
+                                      f" {attn_output.size()}")
+                batched_attention_output.append(attn_output)
+            # For loop ends
+            # TODO: handle attention_weights later
+            attn_output = torch.concat(batched_attention_output, dim=0)
+            batched_attention_output.clear()
+            if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
                 invalidInputError(False,
                                   f"`attn_output` should be of size "
-                                  f"{(1, self.num_heads, 1, self.head_dim)}, but is"
+                                  f"{(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                                   f" {attn_output.size()}")
-            batched_attention_output.append(attn_output)
-        # For loop ends
-        # TODO: handle attention_weights later
-        attn_output = torch.concat(batched_attention_output, dim=0)
-        batched_attention_output.clear()
-        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
-            invalidInputError(False,
-                              f"`attn_output` should be of size "
-                              f"{(bsz, self.num_heads, q_len, self.head_dim)}, but is"
-                              f" {attn_output.size()}")
-        attn_output = attn_output.transpose(1, 2).contiguous()
-        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
-        attn_output = self.o_proj(attn_output)
-        return attn_output, None, updated_past_key_values
+            attn_output = attn_output.transpose(1, 2).contiguous()
+            attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+            attn_output = self.o_proj(attn_output)
+            return attn_output, None, updated_past_key_values

-    # TODO: Assume always use_cache
-    # print(f"prefill with batch size {bsz}")
+    # Assume always use_cache
+    # prefill or decoding fast path
     for batch in range(bsz):
         updated_past_key_values.append((key_states[batch: batch + 1, :, :, :],
                                         value_states[batch: batch+1, :, :, :]))
@@ -445,6 +480,10 @@ def llama_attention_selective_batching_forward_4_31(
                                                                      dtype=attention_dtype)
     value_states = repeat_kv(value_states, self.num_key_value_groups).to(device,
                                                                          dtype=attention_dtype)
+    # Can also happen for the decoding fast path
+    if isinstance(attention_mask, list):
+        # For decoding fast path
+        attention_mask = attention_mask[0]
     attn_output, attn_weights = native_sdp(query_states,
                                            key_states,
                                            value_states,
diff --git a/python/llm/src/bigdl/llm/vllm/model_executor/models/bigdl_llama.py b/python/llm/src/bigdl/llm/vllm/model_executor/models/bigdl_llama.py
index eb6fa282..32af5970 100644
--- a/python/llm/src/bigdl/llm/vllm/model_executor/models/bigdl_llama.py
+++ b/python/llm/src/bigdl/llm/vllm/model_executor/models/bigdl_llama.py
@@ -196,7 +196,8 @@ class BigDLLlamaForCausalLM(BigDLModelForCausalLM):
         if enable_vllm_se_batching:
             attention_mask = [torch.tensor(x, device=self.device).unsqueeze(0)
                               for x in decoding_attention_mask_list]
-            position_ids = torch.tensor(decoding_position_ids).long().unsqueeze(-1)
+            position_ids = torch.tensor(decoding_position_ids, device=self.device).long()
+            position_ids = position_ids.unsqueeze(-1)
         else:
             attention_mask = torch.tensor(decoding_attention_mask_list, device=self.device)
             position_ids = None
@@ -214,6 +215,7 @@ class BigDLLlamaForCausalLM(BigDLModelForCausalLM):
         if enable_vllm_se_batching:
             position_ids = attention_mask.long().cumsum(-1) - 1
             position_ids.masked_fill_(attention_mask == 0, 1)
+            position_ids = position_ids.to(self.device)
         else:
             position_ids = None
         kwargs = {
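
The main change in llama.py above replaces per-token torch.cat on the KV cache with a buffer that is pre-allocated in blocks of KV_CACHE_ALLOC_BLOCK_LENGTH and appended to in place. Below is a minimal plain-PyTorch sketch of that pattern; the helper names ensure_room/append_step and the (batch, kv_heads, seq, head_dim) layout are illustrative assumptions for this note, not the BigDL extend_kv_cache/append_kv_cache API the patch actually calls.

import torch

KV_CACHE_ALLOC_BLOCK_LENGTH = 64  # same block size the patch introduces


def ensure_room(cache_k, cache_v, valid_len):
    """Reallocate with spare slots only when one more token no longer fits."""
    if valid_len + 1 <= cache_k.size(2):
        return cache_k, cache_v
    bsz, n_kv_heads, _, head_dim = cache_k.shape
    new_k = torch.empty(bsz, n_kv_heads, valid_len + 1 + KV_CACHE_ALLOC_BLOCK_LENGTH,
                        head_dim, dtype=cache_k.dtype, device=cache_k.device)
    new_v = torch.empty_like(new_k)
    new_k[:, :, :valid_len] = cache_k[:, :, :valid_len]  # copy the valid prefix over
    new_v[:, :, :valid_len] = cache_v[:, :, :valid_len]
    return new_k, new_v


def append_step(cache_k, cache_v, k_new, v_new, valid_len):
    """Write this step's key/value states into the pre-allocated slots in place."""
    step = k_new.size(2)
    cache_k[:, :, valid_len:valid_len + step] = k_new
    cache_v[:, :, valid_len:valid_len + step] = v_new
    return valid_len + step  # attention then reads cache_k[:, :, :valid_len + step]


# Usage: three decoding steps for one sequence; the buffer is reallocated once,
# not rebuilt on every step as torch.cat would do.
cache_k = torch.zeros(1, 32, 0, 128)
cache_v = torch.zeros_like(cache_k)
valid_len = 0
for _ in range(3):
    k_new, v_new = torch.randn(1, 32, 1, 128), torch.randn(1, 32, 1, 128)
    cache_k, cache_v = ensure_room(cache_k, cache_v, valid_len)
    valid_len = append_step(cache_k, cache_v, k_new, v_new, valid_len)

In steady-state decoding this writes one slot per token and pays the copy cost only once every KV_CACHE_ALLOC_BLOCK_LENGTH tokens.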
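
The bigdl_llama.py hunk at line 214 derives position ids from the padding mask instead of receiving them. A short worked example of that cumsum trick, with made-up mask values:

import torch

# 0 marks left padding, 1 marks real tokens (illustrative values only)
attention_mask = torch.tensor([[0, 0, 1, 1, 1],
                               [1, 1, 1, 1, 1]])
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
print(position_ids)
# tensor([[1, 1, 0, 1, 2],
#         [0, 1, 2, 3, 4]])

Padded slots get the dummy value 1; they are masked out of attention anyway, so only the real tokens need correct positions.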