LLM: Partial Prefilling for Pipeline Parallel Serving (#11457)

Xiangyu Tian 2024-07-05 13:10:35 +08:00 committed by GitHub
parent 72b4efaad4
commit 7d8bc83415
4 changed files with 251 additions and 92 deletions

@@ -57,8 +57,12 @@ pip install trl==0.8.1
bash run.sh
```
> Note: INT4 optimization is applied to the model by default. You could specify other low bit optimizations (such as 'fp8' and 'fp6') through `--low-bit`. Besides, you could change `NUM_GPUS` to the number of GPUs you have on your machine.
### Command Line Arguments in `run.sh`
> Note: INT4 optimization is applied to the model by default. You can specify other low-bit optimizations (such as 'fp8' and 'fp6') through `--low-bit`. You can also change `NUM_GPUS` to the number of GPUs available on your machine. Other related arguments are listed below:
- `--low-bit`: Sets the low bit optimizations (such as 'sym_int4', 'fp16', 'fp8' and 'fp6') for the model.
- `--max-num-seqs`: Sets the maximum batch size on a single card during pipeline parallel serving.
- `--max-prefilled-seqs`: Sets the maximum number of sequences prefilled at a time (the prefill chunk size). Use `0` to disable partial prefilling and prefill all requests in the batch in a single pass.
### 3. Sample Input and Output

@@ -306,18 +306,21 @@ async def main():
help='The port number on which the server will run.')
parser.add_argument('--max-num-seqs', type=int, default=8,
help='Max num sequences in a batch.')
parser.add_argument('--max-prefilled-seqs', type=int, default=0,
help='Max num sequences in a batch during prefilling.')
args = parser.parse_args()
model_path = args.repo_id_or_model_path
low_bit = args.low_bit
max_num_seqs = args.max_num_seqs
max_prefilled_seqs = args.max_prefilled_seqs
# serialize model initialization so that we do not run out of CPU memory
for i in range(my_size):
if my_rank == i:
logger.info("start model initialization")
global local_model
local_model = ModelRunner(model_path, my_rank, my_size, low_bit, max_num_seqs)
local_model = ModelRunner(model_path, my_rank, my_size, low_bit, max_num_seqs, max_prefilled_seqs)
logger.info("model initialized")
dist.barrier()
# Load tokenizer

@@ -24,11 +24,13 @@ source $basekit_root/setvars.sh --force
source $basekit_root/ccl/latest/env/vars.sh --force
export USE_XETLA=OFF
export SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=2
if [[ $KERNEL_VERSION != *"6.5"* ]]; then
export SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=1
fi
export TORCH_LLM_ALLREDUCE=0
export MODEL_PATH=YOUR_MODEL_PATH
export NUM_GPUS=2
export IPEX_LLM_QUANTIZE_KV_CACHE=1
CCL_ZE_IPC_EXCHANGE=sockets torchrun --standalone --nnodes=1 --nproc-per-node $NUM_GPUS pipeline_serving.py --repo-id-or-model-path $MODEL_PATH --low-bit fp8 --max-num-seqs 4
CCL_ZE_IPC_EXCHANGE=sockets torchrun --standalone --nnodes=1 --nproc-per-node $NUM_GPUS pipeline_serving.py --repo-id-or-model-path $MODEL_PATH --low-bit fp8 --max-num-seqs 4 --max-prefilled-seqs 0

@@ -163,6 +163,9 @@ def pipeline_parallel(model, pipeline_parallel_stages):
model._modules['lm_head'] = DummyLayer()
model.pipeline_parallel_stages = pipeline_parallel_stages
model.layer_start = layer_start
model.layer_end = layer_end
model.num_layers = num_layers
model = model.to(f'xpu:{local_rank}')
return model
@@ -364,6 +367,9 @@ class BatchTask(BaseModel):
prompt_lengths: List[int]
stopped: bool
prefilled_index: int
partial_prefilling: int
def make_attention_mask(prompt_lengths):
max_length = max(prompt_lengths)
@@ -375,7 +381,7 @@ def make_attention_mask(prompt_lengths):
class ModelRunner:
"""Implementation for pipeline parallel multi-stage serving."""
def __init__(self, checkpoint, rank, world_size, low_bit, max_num_seqs,
def __init__(self, checkpoint, rank, world_size, low_bit, max_num_seqs, max_prefilled_seqs,
torch_dtype=torch.float16):
self.pp_config = PPConfig(rank, world_size)
self.dtype = torch_dtype
@@ -404,7 +410,11 @@ class ModelRunner:
self.print_len = {}
self.is_finish = {}
self.model_name = checkpoint
self.layer_start = 0
# self.layer_start = 0
# self.layer_end = 0
self.max_prefilled_seqs = max_prefilled_seqs
self.partial_output_dict = {}
def load_model(self, model_path, world_size, low_bit='sym_int4'):
from ipex_llm.transformers import AutoModelForCausalLM, AutoModel
@@ -427,11 +437,90 @@ class ModelRunner:
model = model.eval()
return model
def prepare_batch(self, cur_batch):
if self.rank == 0:
cur_input_start = cur_batch.prefilled_index
if self.max_prefilled_seqs > 0:
if cur_input_start < cur_batch.batch_size:
cur_input_end = cur_input_start + self.max_prefilled_seqs
cur_input_end = min(cur_input_end, cur_batch.batch_size)
cur_batch.partial_prefilling = cur_input_end - cur_input_start
else:
cur_batch.partial_prefilling = 0
return cur_batch
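
The chunking bookkeeping above is easier to see in isolation. Below is a minimal standalone sketch (not part of the commit) of the rank-0 logic in `prepare_batch`: `prefilled_index` counts how many sequences of the batch have already been prefilled, and `partial_prefilling` is set to the size of the next chunk, capped by `max_prefilled_seqs`; the `Batch` dataclass is a simplified stand-in for `BatchTask`.

```python
from dataclasses import dataclass

@dataclass
class Batch:
    batch_size: int
    prefilled_index: int = 0      # sequences already prefilled
    partial_prefilling: int = 0   # size of the next prefill chunk

def prepare_batch(batch: Batch, max_prefilled_seqs: int) -> Batch:
    # Mirror of the rank-0 logic: pick the next prefill chunk, if any remains.
    if max_prefilled_seqs > 0 and batch.prefilled_index < batch.batch_size:
        end = min(batch.prefilled_index + max_prefilled_seqs, batch.batch_size)
        batch.partial_prefilling = end - batch.prefilled_index
    else:
        batch.partial_prefilling = 0
    return batch

# A batch of 5 sequences prefilled at most 2 at a time yields chunks of 2, 2, 1.
b = Batch(batch_size=5)
while b.prefilled_index < b.batch_size:
    b = prepare_batch(b, max_prefilled_seqs=2)
    print(f"prefill rows [{b.prefilled_index}, {b.prefilled_index + b.partial_prefilling})")
    b.prefilled_index += b.partial_prefilling  # in the commit, model_step advances this
```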
def cat_kv_cache(self, model_type, kv_cache_1, kv_cache_2):
if model_type in ["baichuan", "chatglm"]:
result = []
for sub_tuple1, sub_tuple2 in zip(kv_cache_1, kv_cache_2):
if sub_tuple1 is None:
sub_result = [sub_tuple2]
elif sub_tuple2 is None:
sub_result = [sub_tuple1]
else:
sub_result = []
for t1, t2 in zip(sub_tuple1, sub_tuple2):
if t1 is None:
sub_result.append(t2)
elif t2 is None:
sub_result.append(t1)
else:
if model_type == "chatglm" and self.model.config.num_layers != 40:
sub_result.append(torch.cat((t1, t2), dim=1))
else:
sub_result.append(torch.cat((t1, t2), dim=0))
result.append(tuple(sub_result))
return tuple(result)
else:
# num_layers = self.model.layer_end - self.model.layer_start
for layer_idx in range(self.model.num_layers):
kv_cache_1.key_cache[layer_idx] = \
torch.cat([kv_cache_1.key_cache[layer_idx],
kv_cache_2.key_cache[layer_idx]], dim=0)
kv_cache_1.value_cache[layer_idx] = \
torch.cat([kv_cache_1.value_cache[layer_idx],
kv_cache_2.value_cache[layer_idx]], dim=0)
return kv_cache_1
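
For the default (non-baichuan/chatglm) branch, the merge is a per-layer concatenation of the chunks' KV tensors along the batch dimension. A small self-contained sketch, where `SimpleCache` is a toy stand-in for the per-layer cache object and the shapes are illustrative:

```python
import torch

class SimpleCache:
    """Toy stand-in for a per-layer KV cache: one key and one value tensor per layer."""
    def __init__(self, key_cache, value_cache):
        self.key_cache = key_cache      # list of [batch, heads, seq_len, head_dim] tensors
        self.value_cache = value_cache

def cat_kv_cache(cache_1, cache_2, num_layers):
    # Fold the KV cache of a newly prefilled chunk into the accumulated cache
    # by concatenating each layer's tensors along the batch dimension (dim 0).
    for layer in range(num_layers):
        cache_1.key_cache[layer] = torch.cat(
            [cache_1.key_cache[layer], cache_2.key_cache[layer]], dim=0)
        cache_1.value_cache[layer] = torch.cat(
            [cache_1.value_cache[layer], cache_2.value_cache[layer]], dim=0)
    return cache_1

num_layers, heads, seq_len, head_dim = 2, 4, 8, 16
chunk_a = SimpleCache([torch.randn(2, heads, seq_len, head_dim) for _ in range(num_layers)],
                      [torch.randn(2, heads, seq_len, head_dim) for _ in range(num_layers)])
chunk_b = SimpleCache([torch.randn(3, heads, seq_len, head_dim) for _ in range(num_layers)],
                      [torch.randn(3, heads, seq_len, head_dim) for _ in range(num_layers)])
merged = cat_kv_cache(chunk_a, chunk_b, num_layers)
print(merged.key_cache[0].shape)  # torch.Size([5, 4, 8, 16]): batch 2 + batch 3
```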
def update_kv_cache(self, kv_cache, cur_id):
layer_start = self.model.layer_start
layer_end = self.model.layer_end
num_layers = self.model.num_layers
if self.model.config.model_type == "chatglm" and self.model.config.num_layers == 40:
# for glm-4-9b-chat
if self.past_key_values_dict.get(cur_id, None) is None:
value_placeholder = torch.empty_like((kv_cache)[-1][0])
past_key_values_placeholder = tuple(
(value_placeholder, value_placeholder) for _ in range(layer_start)
) + (kv_cache)[:layer_end - layer_start] + tuple(
(value_placeholder, value_placeholder) for _ in range(layer_end, num_layers)
)
kv_cache = past_key_values_placeholder
else:
pass
elif self.model.config.model_type in ["baichuan", "chatglm"] and self.rank > 0:
value_placeholder = torch.empty_like((kv_cache)[-1][0])
kv_cache = tuple((value_placeholder, value_placeholder)) + \
tuple(None for _ in range(layer_start)) + \
(kv_cache)[layer_start:]
# past_key_values_placeholder = tuple(
# (value_placeholder, value_placeholder) for _ in range(layer_start)
# ) + (kv_cache)[layer_start:]
# kv_cache = past_key_values_placeholder
else:
pass
return kv_cache
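
The placeholder construction above exists because a pipeline stage only materializes KV tensors for its own layer range; the remaining layers are padded so the object indexes like a full-depth cache on the next forward pass. A hedged, simplified sketch of that padding pattern (tuple-style caches only; the glm-4/baichuan special cases in the commit differ in detail):

```python
import torch

def pad_stage_cache(stage_cache, layer_start, layer_end, num_layers):
    # stage_cache holds (key, value) pairs only for layers [layer_start, layer_end);
    # pad the other layers with placeholder tensors so the result has num_layers entries.
    placeholder = torch.empty_like(stage_cache[-1][0])
    return (
        tuple((placeholder, placeholder) for _ in range(layer_start))
        + tuple(stage_cache)
        + tuple((placeholder, placeholder) for _ in range(layer_end, num_layers))
    )

# Stage 1 of a 2-stage split of a 4-layer model owns layers [2, 4).
own = tuple((torch.randn(1, 2, 8, 16), torch.randn(1, 2, 8, 16)) for _ in range(2))
full = pad_stage_cache(own, layer_start=2, layer_end=4, num_layers=4)
print(len(full))  # 4 entries: placeholders for layers 0-1, real caches for layers 2-3
```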
@torch.no_grad()
def model_step(self, input, cur_batch):
if cur_batch is None or cur_batch.stopped or input is None:
return None
return None, cur_batch
# logger.info(f"{self.rank} {cur_batch} {input.shape}")
cur_id = cur_batch.batch_id
_past_key_values = self.past_key_values_dict.get(cur_id, None)
attention_mask = make_attention_mask(cur_batch.prompt_lengths).to(input.device)
@@ -439,44 +528,71 @@ class ModelRunner:
if self.rank == 0:
input_ids = input
inputs_embeds = None
if cur_batch.partial_prefilling > 0:
cur_input_start = cur_batch.prefilled_index
cur_input_end = cur_input_start + cur_batch.partial_prefilling
input_ids = input_ids[cur_input_start:cur_input_end]
attention_mask = attention_mask[cur_input_start:cur_input_end]
tmp_past_key_values = _past_key_values
_past_key_values = None
else:
input_ids = None
inputs_embeds = input
torch.xpu.empty_cache()
if cur_batch.partial_prefilling > 0:
cur_input_start = cur_batch.prefilled_index
cur_input_end = cur_input_start + cur_batch.partial_prefilling
attention_mask = attention_mask[cur_input_start:cur_input_end]
tmp_past_key_values = _past_key_values
_past_key_values = None
# torch.xpu.empty_cache()
output = self.model(input_ids=input_ids,
inputs_embeds=inputs_embeds,
past_key_values=_past_key_values,
attention_mask=attention_mask,
use_cache=True,)
if self.model.config.model_type == "chatglm" and self.model.config.num_layers == 40:
# for glm-4-9b-chat
if self.past_key_values_dict.get(cur_id, None) is None:
value_placeholder = torch.empty_like((output.past_key_values)[-1][0])
past_key_values_placeholder = tuple(
(value_placeholder, value_placeholder) for _ in range(layer_start)
) + (output.past_key_values)[: layer_end - layer_start] + tuple(
(value_placeholder, value_placeholder) for _ in range(layer_end, num_layers)
)
_past_key_values = past_key_values_placeholder
if cur_batch.partial_prefilling > 0:
cur_batch.prefilled_index = cur_input_end
if tmp_past_key_values is None:
tmp_past_key_values = output.past_key_values
else:
_past_key_values = output.past_key_values
elif self.model.config.model_type in ["baichuan", "chatglm"] and self.rank > 0:
# for baichuan2 and chatglm3
value_placeholder = torch.empty_like((output.past_key_values)[-1][0])
past_key_values_placeholder = tuple(
(value_placeholder, value_placeholder) for _ in range(layer_start)
) + (output.past_key_values)[layer_start:]
_past_key_values = past_key_values_placeholder
tmp_past_key_values = self.cat_kv_cache(self.model.config.model_type,
tmp_past_key_values,
output.past_key_values)
# torch.xpu.empty_cache()
if cur_batch.prefilled_index == cur_batch.batch_size:
tmp_past_key_values = self.update_kv_cache(tmp_past_key_values, cur_id)
self.past_key_values_dict[cur_id] = tmp_past_key_values
if self.pp_config.is_tail:
_pre_output = self.partial_output_dict.get(cur_id, None)
tmp_output = output.logits.to(self.dtype)
tmp_output = torch.argmax(tmp_output[:, -1:, :], dim=-1)
if _pre_output is None:
_pre_output = tmp_output
else:
_pre_output = torch.cat((_pre_output, tmp_output), dim=0)
self.partial_output_dict[cur_id] = _pre_output
else:
_past_key_values = output.past_key_values
self.past_key_values_dict[cur_id] = _past_key_values
_past_key_values = self.update_kv_cache(output.past_key_values, cur_id)
self.past_key_values_dict[cur_id] = _past_key_values
torch.xpu.synchronize()
if not self.pp_config.is_tail:
return output[0].to(self.dtype)
return output[0].to(self.dtype), cur_batch
else:
return output.logits
if cur_batch.partial_prefilling > 0 and \
cur_batch.prefilled_index == cur_batch.batch_size:
_output = self.partial_output_dict.pop(cur_id, None)
cur_batch.partial_prefilling = 0
return _output, cur_batch
else:
_output = torch.argmax(output.logits[:, -1:, :], dim=-1)
return _output, cur_batch
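
On the tail stage, partial prefilling means the first token of each chunk is known before the whole batch is ready, so `model_step` parks chunk outputs in `partial_output_dict` and only hands back a full `[batch_size, 1]` tensor once `prefilled_index` reaches `batch_size`. A minimal sketch of that accumulation, without the model or the distributed plumbing (the dict name follows the commit, everything else is simplified):

```python
import torch

partial_output_dict = {}  # batch_id -> first tokens gathered from finished chunks

def accumulate_tail_output(batch_id, logits, prefilled_index, batch_size):
    # Greedy-pick the next token for the current chunk; return the whole batch's
    # first tokens only once every sequence in the batch has been prefilled.
    next_tokens = torch.argmax(logits[:, -1:, :], dim=-1)            # [chunk, 1]
    prev = partial_output_dict.get(batch_id)
    merged = next_tokens if prev is None else torch.cat((prev, next_tokens), dim=0)
    if prefilled_index == batch_size:                                 # prefill finished
        partial_output_dict.pop(batch_id, None)
        return merged                                                 # [batch_size, 1]
    partial_output_dict[batch_id] = merged
    return None                                                       # keep accumulating

# Batch of 5 prefilled in chunks of 2, 2, 1: only the last call returns tokens.
vocab_size = 32
for chunk, prefilled in ((2, 2), (2, 4), (1, 5)):
    out = accumulate_tail_output("b0", torch.randn(chunk, 7, vocab_size), prefilled, 5)
    print(None if out is None else out.shape)  # None, None, torch.Size([5, 1])
```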
def is_initialized(self):
return True
@@ -504,6 +620,8 @@ class ModelRunner:
input_len=input_ids.size(1),
prompt_lengths=[sum(attention_mask[i, :]) for i in range(input_ids.size(0))],
stopped=False,
prefilled_index=0,
partial_prefilling=0,
)
self.input_ids_dict[new_batch.batch_id] = input_ids
@@ -517,11 +635,15 @@ class ModelRunner:
self.token_times.pop(cur_id, None)
self.past_key_values_dict.pop(cur_id, None)
self.is_finish.pop(cur_id, None)
self.partial_output_dict.pop(cur_id, None)
async def process_step(self, tokenizer, result_dict):
cur_batch = None
if self.rank == 0:
if self.send_buff is not None:
# logger.info(f"send {self.rank} {self.send_buff.shape}")
dist.send(self.send_buff, dst=self.next_rank)
if self.on_going_batches[0] is not None:
@@ -530,6 +652,7 @@ class ModelRunner:
if cur_batch is None:
if not self.waiting_requests.empty():
# wait more requests to be put in self.waiting_requests
await asyncio.sleep(0.01)
cur_batch = await self.add_request(tokenizer)
cur_input = self.input_ids_dict[cur_batch.batch_id]
@@ -539,84 +662,99 @@ class ModelRunner:
if (cur_batch is not None) and (not cur_batch.stopped) and (cur_input is None):
cur_id = cur_batch.batch_id
next_ids = torch.empty((cur_batch.batch_size, 1,), device=f'xpu:{self.rank}',
dtype=torch.int64)
# cur_batch = self.prepare_batch(cur_batch)
if cur_batch.prefilled_index >= cur_batch.batch_size:
cur_batch.partial_prefilling = 0
if cur_batch.partial_prefilling > 0:
next_ids = torch.empty((cur_batch.partial_prefilling, 1,),
device=f'xpu:{self.rank}', dtype=torch.int64)
else:
next_ids = torch.empty((cur_batch.batch_size, 1,),
device=f'xpu:{self.rank}', dtype=torch.int64)
# logger.info(f"recv {self.rank} {next_ids.shape}")
dist.recv(next_ids, src=self.pre_rank)
if self.tokens.get(cur_id, None) is None:
self.tokens[cur_id] = []
if cur_batch.partial_prefilling > 0:
cur_input = self.input_ids_dict[cur_batch.batch_id]
else:
if self.tokens.get(cur_id, None) is None:
self.tokens[cur_id] = []
if len(next_ids.shape) == 1:
next_ids = next_ids.unsqueeze(0)
self.tokens[cur_id].append(next_ids)
self.token_times[cur_id].append(time.perf_counter())
cur_input = next_ids
cur_batch.input_len = 1
cur_batch.prompt_lengths = [x + 1 for x in cur_batch.prompt_lengths]
if len(next_ids.shape) == 1:
next_ids = next_ids.unsqueeze(0)
self.tokens[cur_id].append(next_ids)
self.token_times[cur_id].append(time.perf_counter())
cur_input = next_ids
cur_batch.input_len = 1
cur_batch.prompt_lengths = [x + 1 for x in cur_batch.prompt_lengths]
for index, request_id in enumerate(cur_batch.request_ids):
for index, request_id in enumerate(cur_batch.request_ids):
if not self.is_finish.get(request_id, False):
remain = cur_batch.max_tokens - len(self.tokens[cur_id])
if not self.is_finish.get(request_id, False):
remain = cur_batch.max_tokens - len(self.tokens[cur_id])
if self.streamer.get(request_id, None) is None:
self.streamer[request_id] = asyncio.Queue()
if self.streamer.get(request_id, None) is None:
self.streamer[request_id] = asyncio.Queue()
# Currently ignore eos for benchmark
# if next_ids[index].int() == tokenizer.eos_token_id:
# remain = 0
# self.is_finish[request_id] = True
# Currently ignore eos for benchmark
# if next_ids[index].int() == tokenizer.eos_token_id:
# remain = 0
# self.is_finish[request_id] = True
if self.token_cache.get(request_id, None) is None:
self.token_cache[request_id] = []
self.print_len[request_id] = 0
self.token_cache[request_id].extend(next_ids[index].tolist())
if self.token_cache.get(request_id, None) is None:
self.token_cache[request_id] = []
self.print_len[request_id] = 0
self.token_cache[request_id].extend(next_ids[index].tolist())
text = tokenizer.decode(self.token_cache[request_id])
if text.endswith("\n"):
printable_text = text[self.print_len[request_id]:]
self.token_cache[request_id] = []
self.print_len[request_id] = 0
elif len(text) > 0 and _is_chinese_char(ord(text[-1])):
printable_text = text[self.print_len[request_id]:]
self.print_len[request_id] += len(printable_text)
else:
printable_text = text[self.print_len[request_id]: text.rfind(" ") + 1]
self.print_len[request_id] += len(printable_text)
text = tokenizer.decode(self.token_cache[request_id])
if text.endswith("\n"):
printable_text = text[self.print_len[request_id]:]
self.token_cache[request_id] = []
self.print_len[request_id] = 0
elif len(text) > 0 and _is_chinese_char(ord(text[-1])):
printable_text = text[self.print_len[request_id]:]
self.print_len[request_id] += len(printable_text)
else:
r_index = text.rfind(" ") + 1
printable_text = text[self.print_len[request_id]: r_index]
self.print_len[request_id] += len(printable_text)
if remain > 0:
await self.streamer[request_id].put((remain, printable_text))
else:
printable_text = printable_text + text[self.print_len[request_id]:]
self.token_cache.pop(request_id, None)
self.print_len.pop(request_id, None)
await self.streamer[request_id].put((remain, printable_text))
if remain > 0:
await self.streamer[request_id].put((remain, printable_text))
else:
printable_text = printable_text + text[self.print_len[request_id]:]
self.token_cache.pop(request_id, None)
self.print_len.pop(request_id, None)
await self.streamer[request_id].put((remain, printable_text))
if len(self.tokens[cur_id]) >= cur_batch.max_tokens:
# Finish a batch
outputs = torch.cat(self.tokens[cur_id], dim=1)
outputs = outputs.cpu()
output_strs = tokenizer.batch_decode(outputs, skip_special_tokens=False)
for request_id, output_str in zip(cur_batch.request_ids, output_strs):
with self.dict_lock:
result_dict[request_id] = output_str
if len(self.tokens[cur_id]) >= cur_batch.max_tokens:
# Finish a batch
outputs = torch.cat(self.tokens[cur_id], dim=1)
outputs = outputs.cpu()
output_strs = tokenizer.batch_decode(outputs, skip_special_tokens=False)
for request_id, output_str in zip(cur_batch.request_ids, output_strs):
with self.dict_lock:
result_dict[request_id] = output_str
cur_times = self.token_times[cur_id]
first_token = cur_times[1] - cur_times[0]
next_token = (cur_times[-1] - cur_times[1]) / (len(self.tokens[cur_id]) - 1)
logger.info(f"First token latency: {first_token}, "
f"next token latency: {next_token}")
self.clear_batch(cur_id)
cur_batch.stopped = True
cur_times = self.token_times[cur_id]
first_token = cur_times[1] - cur_times[0]
next_token = (cur_times[-1] - cur_times[1]) / (len(self.tokens[cur_id]) - 1)
logger.info(f"First token latency: {first_token}, "
f"next token latency: {next_token}")
self.clear_batch(cur_id)
cur_batch.stopped = True
else:
if (cur_batch is not None) and cur_batch.stopped:
cur_batch = None
if cur_batch is not None:
cur_batch = self.prepare_batch(cur_batch)
dist.broadcast_object_list([cur_batch], src=0)
else:
if self.send_buff is not None:
# logger.info(f"send {self.rank} {self.send_buff.shape}")
dist.send(self.send_buff, dst=self.next_rank)
batch_list = [None]
@@ -629,14 +767,26 @@ class ModelRunner:
if cur_batch.stopped:
self.clear_batch(cur_batch.batch_id)
else:
cur_batch = self.prepare_batch(cur_batch)
cur_len = cur_batch.input_len
cur_input = torch.empty((cur_batch.batch_size, cur_len, self.hidden_size,),
device=f'xpu:{self.rank}', dtype=self.dtype)
if cur_batch.partial_prefilling:
cur_input = torch.empty(
(cur_batch.partial_prefilling, cur_len, self.hidden_size,),
device=f'xpu:{self.rank}',
dtype=self.dtype,
)
else:
cur_input = torch.empty(
(cur_batch.batch_size, cur_len, self.hidden_size,),
device=f'xpu:{self.rank}',
dtype=self.dtype,
)
# logger.info(f"recv {self.rank} {cur_input.shape}")
dist.recv(cur_input, src=self.pre_rank)
output = self.model_step(cur_input, cur_batch)
if output is not None and self.rank == self.world_size - 1:
output = torch.argmax(output[:, -1:, :], dim=-1)
output, cur_batch = self.model_step(cur_input, cur_batch)
# if output is not None and self.rank == self.world_size - 1:
# output = torch.argmax(output[:, -1:, :], dim=-1)
if output is not None:
# dist.send(output, dst=self.next_rank)
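
For the non-head ranks, the behavioural change in this last hunk is that the hidden-state receive buffer is sized by the active prefill chunk instead of the full batch, and that `model_step` now also returns the (possibly updated) batch. A hedged sketch of the buffer-shape rule, with invented toy dimensions:

```python
import torch

def recv_buffer_shape(batch_size, partial_prefilling, input_len, hidden_size):
    # Shape of the hidden states a non-head stage expects from the previous stage:
    # only the active chunk's rows during partial prefilling, the whole batch otherwise.
    rows = partial_prefilling if partial_prefilling > 0 else batch_size
    return (rows, input_len, hidden_size)

# Chunked prefill: 2 of 8 sequences, full (padded) prompt length travels between stages.
print(recv_buffer_shape(batch_size=8, partial_prefilling=2, input_len=128, hidden_size=4096))
# Decode: the whole batch, one token at a time.
print(recv_buffer_shape(batch_size=8, partial_prefilling=0, input_len=1, hidden_size=4096))

# The buffer itself is then allocated much as in the commit (device omitted here):
cur_input = torch.empty(recv_buffer_shape(8, 2, 128, 4096), dtype=torch.float16)
```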