LLM: Partial Prefilling for Pipeline Parallel Serving (#11457)
parent 72b4efaad4
commit 7d8bc83415
4 changed files with 251 additions and 92 deletions
@@ -57,8 +57,12 @@ pip install trl==0.8.1
 bash run.sh
 ```
 
-> Note: INT4 optimization is applied to the model by default. You could specify other low bit optimizations (such as 'fp8' and 'fp6') through `--low-bit`. Besides, you could change `NUM_GPUS` to the number of GPUs you have on your machine.
+### Command Line Arguments in `run.sh`
+> Note: INT4 optimization is applied to the model by default. You could specify other low bit optimizations (such as 'fp8' and 'fp6') through `--low-bit`. Besides, you could change `NUM_GPUS` to the number of GPUs you have on your machine. Other relative settings are listed below:
+
+- `--low-bit`: Sets the low bit optimizations (such as 'sym_int4', 'fp16', 'fp8' and 'fp6') for the model.
+- `--max-num-seqs`: Sets the maximum batch size on a single card during pipeline parallel serving.
+- `--max-prefilled-seqs`: Sets the maximum batch size for prefilled sequences. Use `0` to disable partial prefetching and process all requests in a single batch.
 
 ### 3. Sample Input and Output
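As an editorial aside (not part of the commit), the behaviour of `--max-prefilled-seqs` documented above can be pictured with a small, self-contained sketch: a cursor walks over the batch in chunks of at most `max_prefilled_seqs` sequences, and `0` falls back to prefilling the whole batch in one shot. The names `BatchState` and `plan_prefill_chunks` are hypothetical and only illustrate the scheduling idea.

```python
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class BatchState:
    # Hypothetical stand-in for the server's batch bookkeeping:
    # how many sequences are in the batch and how many are already prefilled.
    batch_size: int
    prefilled_index: int = 0


def plan_prefill_chunks(batch: BatchState, max_prefilled_seqs: int) -> List[Tuple[int, int]]:
    """Return the (start, end) sub-batch ranges a partial prefill would walk through.

    With max_prefilled_seqs == 0 the whole batch is prefilled at once,
    mirroring the documented behaviour of `--max-prefilled-seqs 0`.
    """
    if max_prefilled_seqs <= 0:
        return [(0, batch.batch_size)]

    chunks = []
    while batch.prefilled_index < batch.batch_size:
        start = batch.prefilled_index
        end = min(start + max_prefilled_seqs, batch.batch_size)
        chunks.append((start, end))
        batch.prefilled_index = end
    return chunks


if __name__ == "__main__":
    # A batch of 7 sequences prefilled at most 2 at a time -> 4 prefill steps.
    print(plan_prefill_chunks(BatchState(batch_size=7), max_prefilled_seqs=2))
    # [(0, 2), (2, 4), (4, 6), (6, 7)]
    print(plan_prefill_chunks(BatchState(batch_size=7), max_prefilled_seqs=0))
    # [(0, 7)]
```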
@@ -306,18 +306,21 @@ async def main():
                         help='The port number on which the server will run.')
     parser.add_argument('--max-num-seqs', type=int, default=8,
                         help='Max num sequences in a batch.')
+    parser.add_argument('--max-prefilled-seqs', type=int, default=0,
+                        help='Max num sequences in a batch during prefilling.')
 
     args = parser.parse_args()
     model_path = args.repo_id_or_model_path
     low_bit = args.low_bit
     max_num_seqs = args.max_num_seqs
+    max_prefilled_seqs = args.max_prefilled_seqs
 
     # serialize model initialization so that we do not run out of CPU memory
     for i in range(my_size):
         if my_rank == i:
             logger.info("start model initialization")
             global local_model
-            local_model = ModelRunner(model_path, my_rank, my_size, low_bit, max_num_seqs)
+            local_model = ModelRunner(model_path, my_rank, my_size, low_bit, max_num_seqs, max_prefilled_seqs)
             logger.info("model initialized")
         dist.barrier()
     # Load tokenizer
@@ -24,11 +24,13 @@ source $basekit_root/setvars.sh --force
 source $basekit_root/ccl/latest/env/vars.sh --force
 
 export USE_XETLA=OFF
 export SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=2
 if [[ $KERNEL_VERSION != *"6.5"* ]]; then
     export SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=1
 fi
 export TORCH_LLM_ALLREDUCE=0
 
 export MODEL_PATH=YOUR_MODEL_PATH
 export NUM_GPUS=2
+export IPEX_LLM_QUANTIZE_KV_CACHE=1
 
-CCL_ZE_IPC_EXCHANGE=sockets torchrun --standalone --nnodes=1 --nproc-per-node $NUM_GPUS pipeline_serving.py --repo-id-or-model-path $MODEL_PATH --low-bit fp8 --max-num-seqs 4
+CCL_ZE_IPC_EXCHANGE=sockets torchrun --standalone --nnodes=1 --nproc-per-node $NUM_GPUS pipeline_serving.py --repo-id-or-model-path $MODEL_PATH --low-bit fp8 --max-num-seqs 4 --max-prefilled-seqs 0
@@ -163,6 +163,9 @@ def pipeline_parallel(model, pipeline_parallel_stages):
         model._modules['lm_head'] = DummyLayer()
 
     model.pipeline_parallel_stages = pipeline_parallel_stages
+    model.layer_start = layer_start
+    model.layer_end = layer_end
+    model.num_layers = num_layers
     model = model.to(f'xpu:{local_rank}')
     return model
 
@@ -364,6 +367,9 @@ class BatchTask(BaseModel):
     prompt_lengths: List[int]
     stopped: bool
 
+    prefilled_index: int
+    partial_prefilling: int
+
 
 def make_attention_mask(prompt_lengths):
     max_length = max(prompt_lengths)
@@ -375,7 +381,7 @@ def make_attention_mask(prompt_lengths):
 
 class ModelRunner:
     """Implementation for pipeline parallel multi-stage serving."""
-    def __init__(self, checkpoint, rank, world_size, low_bit, max_num_seqs,
+    def __init__(self, checkpoint, rank, world_size, low_bit, max_num_seqs, max_prefilled_seqs,
                  torch_dtype=torch.float16):
         self.pp_config = PPConfig(rank, world_size)
         self.dtype = torch_dtype
@@ -404,7 +410,11 @@ class ModelRunner:
         self.print_len = {}
         self.is_finish = {}
         self.model_name = checkpoint
-        self.layer_start = 0
+        # self.layer_start = 0
+        # self.layer_end = 0
+
+        self.max_prefilled_seqs = max_prefilled_seqs
+        self.partial_output_dict = {}
 
     def load_model(self, model_path, world_size, low_bit='sym_int4'):
         from ipex_llm.transformers import AutoModelForCausalLM, AutoModel
@@ -427,11 +437,90 @@ class ModelRunner:
         model = model.eval()
         return model
 
+    def prepare_batch(self, cur_batch):
+        if self.rank == 0:
+            cur_input_start = cur_batch.prefilled_index
+            if self.max_prefilled_seqs > 0:
+                if cur_input_start < cur_batch.batch_size:
+                    cur_input_end = cur_input_start + self.max_prefilled_seqs
+                    cur_input_end = min(cur_input_end, cur_batch.batch_size)
+                    cur_batch.partial_prefilling = cur_input_end - cur_input_start
+                else:
+                    cur_batch.partial_prefilling = 0
+
+        return cur_batch
+
+    def cat_kv_cache(self, model_type, kv_cache_1, kv_cache_2):
+        if model_type in ["baichuan", "chatglm"]:
+            result = []
+            for sub_tuple1, sub_tuple2 in zip(kv_cache_1, kv_cache_2):
+                if sub_tuple1 is None:
+                    sub_result = [sub_tuple2]
+                elif sub_tuple2 is None:
+                    sub_result = [sub_tuple1]
+                else:
+                    sub_result = []
+                    for t1, t2 in zip(sub_tuple1, sub_tuple2):
+                        if t1 is None:
+                            sub_result.append(t2)
+                        elif t2 is None:
+                            sub_result.append(t1)
+                        else:
+                            if model_type == "chatglm" and self.model.config.num_layers != 40:
+                                sub_result.append(torch.cat((t1, t2), dim=1))
+                            else:
+                                sub_result.append(torch.cat((t1, t2), dim=0))
+                result.append(tuple(sub_result))
+            return tuple(result)
+        else:
+            # num_layers = self.model.layer_end - self.model.layer_start
+            for layer_idx in range(self.model.num_layers):
+                kv_cache_1.key_cache[layer_idx] = \
+                    torch.cat([kv_cache_1.key_cache[layer_idx],
+                               kv_cache_2.key_cache[layer_idx]], dim=0)
+                kv_cache_1.value_cache[layer_idx] = \
+                    torch.cat([kv_cache_1.value_cache[layer_idx],
+                               kv_cache_2.value_cache[layer_idx]], dim=0)
+
+            return kv_cache_1
+
+    def update_kv_cache(self, kv_cache, cur_id):
+        layer_start = self.model.layer_start
+        layer_end = self.model.layer_end
+        num_layers = self.model.num_layers
+
+        if self.model.config.model_type == "chatglm" and self.model.config.num_layers == 40:
+            # for glm-4-9b-chat
+            if self.past_key_values_dict.get(cur_id, None) is None:
+                value_placeholder = torch.empty_like((kv_cache)[-1][0])
+                past_key_values_placeholder = tuple(
+                    (value_placeholder, value_placeholder) for _ in range(layer_start)
+                ) + (kv_cache)[:layer_end - layer_start] + tuple(
+                    (value_placeholder, value_placeholder) for _ in range(layer_end, num_layers)
+                )
+                kv_cache = past_key_values_placeholder
+            else:
+                pass
+        elif self.model.config.model_type in ["baichuan", "chatglm"] and self.rank > 0:
+            value_placeholder = torch.empty_like((kv_cache)[-1][0])
+            kv_cache = tuple((value_placeholder, value_placeholder)) + \
+                tuple(None for _ in range(layer_start)) + \
+                (kv_cache)[layer_start:]
+            # past_key_values_placeholder = tuple(
+            #     (value_placeholder, value_placeholder) for _ in range(layer_start)
+            # ) + (kv_cache)[layer_start:]
+            # kv_cache = past_key_values_placeholder
+        else:
+            pass
+
+        return kv_cache
+
     @torch.no_grad()
     def model_step(self, input, cur_batch):
         if cur_batch is None or cur_batch.stopped or input is None:
-            return None
+            return None, cur_batch
 
         # logger.info(f"{self.rank} {cur_batch} {input.shape}")
         cur_id = cur_batch.batch_id
         _past_key_values = self.past_key_values_dict.get(cur_id, None)
         attention_mask = make_attention_mask(cur_batch.prompt_lengths).to(input.device)
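As a clarifying aside (not part of the diff), the non-`baichuan`/`chatglm` branch of `cat_kv_cache` above relies on each layer's key/value tensors keeping the batch dimension first, so the caches of two prefilled sub-batches can be merged with a per-layer `torch.cat` along dim 0. A minimal sketch with plain tensors, assuming a hypothetical `[batch, heads, seq, head_dim]` layout and a dict-based cache:

```python
import torch

# Hypothetical cache: 2 layers, tensors shaped [batch, num_heads, seq_len, head_dim].
def make_cache(batch, layers=2, heads=4, seq=16, head_dim=8):
    return {
        "key": [torch.randn(batch, heads, seq, head_dim) for _ in range(layers)],
        "value": [torch.randn(batch, heads, seq, head_dim) for _ in range(layers)],
    }

def cat_caches(cache_1, cache_2):
    # Merge two sub-batch caches layer by layer along the batch dimension (dim 0),
    # analogous to how cat_kv_cache concatenates key_cache/value_cache above.
    for layer_idx in range(len(cache_1["key"])):
        cache_1["key"][layer_idx] = torch.cat(
            [cache_1["key"][layer_idx], cache_2["key"][layer_idx]], dim=0)
        cache_1["value"][layer_idx] = torch.cat(
            [cache_1["value"][layer_idx], cache_2["value"][layer_idx]], dim=0)
    return cache_1

merged = cat_caches(make_cache(batch=2), make_cache(batch=3))
print(merged["key"][0].shape)  # torch.Size([5, 4, 16, 8])
```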
@@ -439,44 +528,71 @@ class ModelRunner:
         if self.rank == 0:
             input_ids = input
             inputs_embeds = None
+
+            if cur_batch.partial_prefilling > 0:
+                cur_input_start = cur_batch.prefilled_index
+                cur_input_end = cur_input_start + cur_batch.partial_prefilling
+                input_ids = input_ids[cur_input_start:cur_input_end]
+                attention_mask = attention_mask[cur_input_start:cur_input_end]
+                tmp_past_key_values = _past_key_values
+                _past_key_values = None
         else:
             input_ids = None
             inputs_embeds = input
 
-        torch.xpu.empty_cache()
+            if cur_batch.partial_prefilling > 0:
+                cur_input_start = cur_batch.prefilled_index
+                cur_input_end = cur_input_start + cur_batch.partial_prefilling
+                attention_mask = attention_mask[cur_input_start:cur_input_end]
+                tmp_past_key_values = _past_key_values
+                _past_key_values = None
+
+        # torch.xpu.empty_cache()
         output = self.model(input_ids=input_ids,
                             inputs_embeds=inputs_embeds,
                             past_key_values=_past_key_values,
                             attention_mask=attention_mask,
                             use_cache=True,)
 
-        if self.model.config.model_type == "chatglm" and self.model.config.num_layers == 40:
-            # for glm-4-9b-chat
-            if self.past_key_values_dict.get(cur_id, None) is None:
-                value_placeholder = torch.empty_like((output.past_key_values)[-1][0])
-                past_key_values_placeholder = tuple(
-                    (value_placeholder, value_placeholder) for _ in range(layer_start)
-                ) + (output.past_key_values)[: layer_end - layer_start] + tuple(
-                    (value_placeholder, value_placeholder) for _ in range(layer_end, num_layers)
-                )
-                _past_key_values = past_key_values_placeholder
-            else:
-                _past_key_values = output.past_key_values
-        elif self.model.config.model_type in ["baichuan", "chatglm"] and self.rank > 0:
-            # for baichuan2 and chatglm3
-            value_placeholder = torch.empty_like((output.past_key_values)[-1][0])
-            past_key_values_placeholder = tuple(
-                (value_placeholder, value_placeholder) for _ in range(layer_start)
-            ) + (output.past_key_values)[layer_start:]
-            _past_key_values = past_key_values_placeholder
+        if cur_batch.partial_prefilling > 0:
+            cur_batch.prefilled_index = cur_input_end
+            if tmp_past_key_values is None:
+                tmp_past_key_values = output.past_key_values
+            else:
+                tmp_past_key_values = self.cat_kv_cache(self.model.config.model_type,
+                                                        tmp_past_key_values,
+                                                        output.past_key_values)
+                # torch.xpu.empty_cache()
+
+            if cur_batch.prefilled_index == cur_batch.batch_size:
+                tmp_past_key_values = self.update_kv_cache(tmp_past_key_values, cur_id)
+
+            self.past_key_values_dict[cur_id] = tmp_past_key_values
+
+            if self.pp_config.is_tail:
+                _pre_output = self.partial_output_dict.get(cur_id, None)
+                tmp_output = output.logits.to(self.dtype)
+                tmp_output = torch.argmax(tmp_output[:, -1:, :], dim=-1)
+                if _pre_output is None:
+                    _pre_output = tmp_output
+                else:
+                    _pre_output = torch.cat((_pre_output, tmp_output), dim=0)
+                self.partial_output_dict[cur_id] = _pre_output
         else:
-            _past_key_values = output.past_key_values
-        self.past_key_values_dict[cur_id] = _past_key_values
+            _past_key_values = self.update_kv_cache(output.past_key_values, cur_id)
+            self.past_key_values_dict[cur_id] = _past_key_values
+
         torch.xpu.synchronize()
         if not self.pp_config.is_tail:
-            return output[0].to(self.dtype)
+            return output[0].to(self.dtype), cur_batch
         else:
-            return output.logits
+            if cur_batch.partial_prefilling > 0 and \
+                    cur_batch.prefilled_index == cur_batch.batch_size:
+                _output = self.partial_output_dict.pop(cur_id, None)
+                cur_batch.partial_prefilling = 0
+                return _output, cur_batch
+            else:
+                _output = torch.argmax(output.logits[:, -1:, :], dim=-1)
+                return _output, cur_batch
 
     def is_initialized(self):
         return True
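On the last pipeline stage, the updated `model_step` above collects one greedy token per prefilled sub-batch in `partial_output_dict` and only releases the combined result once `prefilled_index` reaches the batch size. A rough standalone illustration of that accumulation (hypothetical names, greedy decoding assumed, not the commit's code):

```python
import torch

partial_outputs = {}  # batch_id -> accumulated first tokens, like partial_output_dict

def accumulate_first_tokens(batch_id, logits):
    # logits: [sub_batch, seq_len, vocab]; keep the argmax of the last position,
    # then append it to whatever earlier sub-batches of this batch already produced.
    next_tokens = torch.argmax(logits[:, -1:, :], dim=-1)  # [sub_batch, 1]
    prev = partial_outputs.get(batch_id)
    partial_outputs[batch_id] = next_tokens if prev is None else torch.cat((prev, next_tokens), dim=0)
    return partial_outputs[batch_id]

# Two prefill steps of 2 and 3 sequences end up as 5 first tokens for the batch.
accumulate_first_tokens("batch-0", torch.randn(2, 7, 32000))
print(accumulate_first_tokens("batch-0", torch.randn(3, 7, 32000)).shape)  # torch.Size([5, 1])
```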
@@ -504,6 +620,8 @@ class ModelRunner:
                 input_len=input_ids.size(1),
                 prompt_lengths=[sum(attention_mask[i, :]) for i in range(input_ids.size(0))],
                 stopped=False,
+                prefilled_index=0,
+                partial_prefilling=0,
             )
 
             self.input_ids_dict[new_batch.batch_id] = input_ids
@@ -517,11 +635,15 @@ class ModelRunner:
         self.token_times.pop(cur_id, None)
         self.past_key_values_dict.pop(cur_id, None)
 
+        self.is_finish.pop(cur_id, None)
+        self.partial_output_dict.pop(cur_id, None)
 
     async def process_step(self, tokenizer, result_dict):
         cur_batch = None
 
         if self.rank == 0:
             if self.send_buff is not None:
                 # logger.info(f"send {self.rank} {self.send_buff.shape}")
                 dist.send(self.send_buff, dst=self.next_rank)
 
             if self.on_going_batches[0] is not None:
@@ -530,6 +652,7 @@ class ModelRunner:
 
             if cur_batch is None:
                 if not self.waiting_requests.empty():
+                    # wait more requests to be put in self.waiting_requests
                     await asyncio.sleep(0.01)
                     cur_batch = await self.add_request(tokenizer)
                     cur_input = self.input_ids_dict[cur_batch.batch_id]
@@ -539,84 +662,99 @@ class ModelRunner:
 
             if (cur_batch is not None) and (not cur_batch.stopped) and (cur_input is None):
                 cur_id = cur_batch.batch_id
-                next_ids = torch.empty((cur_batch.batch_size, 1,), device=f'xpu:{self.rank}',
-                                       dtype=torch.int64)
+                # cur_batch = self.prepare_batch(cur_batch)
+                if cur_batch.prefilled_index >= cur_batch.batch_size:
+                    cur_batch.partial_prefilling = 0
+                if cur_batch.partial_prefilling > 0:
+                    next_ids = torch.empty((cur_batch.partial_prefilling, 1,),
+                                           device=f'xpu:{self.rank}', dtype=torch.int64)
+                else:
+                    next_ids = torch.empty((cur_batch.batch_size, 1,),
+                                           device=f'xpu:{self.rank}', dtype=torch.int64)
 
                 # logger.info(f"recv {self.rank} {next_ids.shape}")
                 dist.recv(next_ids, src=self.pre_rank)
 
-                if self.tokens.get(cur_id, None) is None:
-                    self.tokens[cur_id] = []
+                if cur_batch.partial_prefilling > 0:
+                    cur_input = self.input_ids_dict[cur_batch.batch_id]
+                else:
+                    if self.tokens.get(cur_id, None) is None:
+                        self.tokens[cur_id] = []
 
-                if len(next_ids.shape) == 1:
-                    next_ids = next_ids.unsqueeze(0)
-                self.tokens[cur_id].append(next_ids)
-                self.token_times[cur_id].append(time.perf_counter())
-                cur_input = next_ids
-                cur_batch.input_len = 1
-                cur_batch.prompt_lengths = [x + 1 for x in cur_batch.prompt_lengths]
+                    if len(next_ids.shape) == 1:
+                        next_ids = next_ids.unsqueeze(0)
+                    self.tokens[cur_id].append(next_ids)
+                    self.token_times[cur_id].append(time.perf_counter())
+                    cur_input = next_ids
+                    cur_batch.input_len = 1
+                    cur_batch.prompt_lengths = [x + 1 for x in cur_batch.prompt_lengths]
 
-                for index, request_id in enumerate(cur_batch.request_ids):
+                    for index, request_id in enumerate(cur_batch.request_ids):
 
-                    if not self.is_finish.get(request_id, False):
-                        remain = cur_batch.max_tokens - len(self.tokens[cur_id])
+                        if not self.is_finish.get(request_id, False):
+                            remain = cur_batch.max_tokens - len(self.tokens[cur_id])
 
-                        if self.streamer.get(request_id, None) is None:
-                            self.streamer[request_id] = asyncio.Queue()
+                            if self.streamer.get(request_id, None) is None:
+                                self.streamer[request_id] = asyncio.Queue()
 
-                        # Currently ignore eos for benchmark
-                        # if next_ids[index].int() == tokenizer.eos_token_id:
-                        #     remain = 0
-                        #     self.is_finish[request_id] = True
+                            # Currently ignore eos for benchmark
+                            # if next_ids[index].int() == tokenizer.eos_token_id:
+                            #     remain = 0
+                            #     self.is_finish[request_id] = True
 
-                        if self.token_cache.get(request_id, None) is None:
-                            self.token_cache[request_id] = []
-                            self.print_len[request_id] = 0
-                        self.token_cache[request_id].extend(next_ids[index].tolist())
+                            if self.token_cache.get(request_id, None) is None:
+                                self.token_cache[request_id] = []
+                                self.print_len[request_id] = 0
+                            self.token_cache[request_id].extend(next_ids[index].tolist())
 
-                        text = tokenizer.decode(self.token_cache[request_id])
-                        if text.endswith("\n"):
-                            printable_text = text[self.print_len[request_id]:]
-                            self.token_cache[request_id] = []
-                            self.print_len[request_id] = 0
-                        elif len(text) > 0 and _is_chinese_char(ord(text[-1])):
-                            printable_text = text[self.print_len[request_id]:]
-                            self.print_len[request_id] += len(printable_text)
-                        else:
-                            printable_text = text[self.print_len[request_id]: text.rfind(" ") + 1]
-                            self.print_len[request_id] += len(printable_text)
+                            text = tokenizer.decode(self.token_cache[request_id])
+                            if text.endswith("\n"):
+                                printable_text = text[self.print_len[request_id]:]
+                                self.token_cache[request_id] = []
+                                self.print_len[request_id] = 0
+                            elif len(text) > 0 and _is_chinese_char(ord(text[-1])):
+                                printable_text = text[self.print_len[request_id]:]
+                                self.print_len[request_id] += len(printable_text)
+                            else:
+                                r_index = text.rfind(" ") + 1
+                                printable_text = text[self.print_len[request_id]: r_index]
+                                self.print_len[request_id] += len(printable_text)
 
-                        if remain > 0:
-                            await self.streamer[request_id].put((remain, printable_text))
-                        else:
-                            printable_text = printable_text + text[self.print_len[request_id]:]
-                            self.token_cache.pop(request_id, None)
-                            self.print_len.pop(request_id, None)
-                            await self.streamer[request_id].put((remain, printable_text))
+                            if remain > 0:
+                                await self.streamer[request_id].put((remain, printable_text))
+                            else:
+                                printable_text = printable_text + text[self.print_len[request_id]:]
+                                self.token_cache.pop(request_id, None)
+                                self.print_len.pop(request_id, None)
+                                await self.streamer[request_id].put((remain, printable_text))
 
-                if len(self.tokens[cur_id]) >= cur_batch.max_tokens:
-                    # Finish a batch
-                    outputs = torch.cat(self.tokens[cur_id], dim=1)
-                    outputs = outputs.cpu()
-                    output_strs = tokenizer.batch_decode(outputs, skip_special_tokens=False)
-                    for request_id, output_str in zip(cur_batch.request_ids, output_strs):
-                        with self.dict_lock:
-                            result_dict[request_id] = output_str
+                    if len(self.tokens[cur_id]) >= cur_batch.max_tokens:
+                        # Finish a batch
+                        outputs = torch.cat(self.tokens[cur_id], dim=1)
+                        outputs = outputs.cpu()
+                        output_strs = tokenizer.batch_decode(outputs, skip_special_tokens=False)
+                        for request_id, output_str in zip(cur_batch.request_ids, output_strs):
+                            with self.dict_lock:
+                                result_dict[request_id] = output_str
 
-                    cur_times = self.token_times[cur_id]
-                    first_token = cur_times[1] - cur_times[0]
-                    next_token = (cur_times[-1] - cur_times[1]) / (len(self.tokens[cur_id]) - 1)
-                    logger.info(f"First token latency: {first_token}, "
-                                f"next token latency: {next_token}")
-                    self.clear_batch(cur_id)
-                    cur_batch.stopped = True
+                        cur_times = self.token_times[cur_id]
+                        first_token = cur_times[1] - cur_times[0]
+                        next_token = (cur_times[-1] - cur_times[1]) / (len(self.tokens[cur_id]) - 1)
+                        logger.info(f"First token latency: {first_token}, "
+                                    f"next token latency: {next_token}")
+                        self.clear_batch(cur_id)
+                        cur_batch.stopped = True
             else:
                 if (cur_batch is not None) and cur_batch.stopped:
                     cur_batch = None
 
             if cur_batch is not None:
+                cur_batch = self.prepare_batch(cur_batch)
                 dist.broadcast_object_list([cur_batch], src=0)
 
         else:
            if self.send_buff is not None:
                # logger.info(f"send {self.rank} {self.send_buff.shape}")
                dist.send(self.send_buff, dst=self.next_rank)
 
            batch_list = [None]
@@ -629,14 +767,26 @@ class ModelRunner:
                 if cur_batch.stopped:
                     self.clear_batch(cur_batch.batch_id)
                 else:
+                    cur_batch = self.prepare_batch(cur_batch)
                     cur_len = cur_batch.input_len
-                    cur_input = torch.empty((cur_batch.batch_size, cur_len, self.hidden_size,),
-                                            device=f'xpu:{self.rank}', dtype=self.dtype)
+                    if cur_batch.partial_prefilling:
+                        cur_input = torch.empty(
+                            (cur_batch.partial_prefilling, cur_len, self.hidden_size,),
+                            device=f'xpu:{self.rank}',
+                            dtype=self.dtype,
+                        )
+                    else:
+                        cur_input = torch.empty(
+                            (cur_batch.batch_size, cur_len, self.hidden_size,),
+                            device=f'xpu:{self.rank}',
+                            dtype=self.dtype,
+                        )
                     # logger.info(f"recv {self.rank} {cur_input.shape}")
                     dist.recv(cur_input, src=self.pre_rank)
 
-            output = self.model_step(cur_input, cur_batch)
-            if output is not None and self.rank == self.world_size - 1:
-                output = torch.argmax(output[:, -1:, :], dim=-1)
+            output, cur_batch = self.model_step(cur_input, cur_batch)
+            # if output is not None and self.rank == self.world_size - 1:
+            #     output = torch.argmax(output[:, -1:, :], dim=-1)
 
             if output is not None:
                 # dist.send(output, dst=self.next_rank)