LLM: Refine Pipeline Parallel FastAPI (#11587)
parent 4d56ef5646
commit 060792a648

1 changed file with 18 additions and 9 deletions
@@ -462,11 +462,15 @@ class BatchTask(BaseModel):
     partial_prefilling: int


-def make_attention_mask(prompt_lengths):
+def make_attention_mask(prompt_lengths, device):
     max_length = max(prompt_lengths)
-    attention_mask = torch.zeros((len(prompt_lengths), max_length), dtype=torch.int64)
-    for i, length in enumerate(prompt_lengths):
-        attention_mask[i, max_length - length:] = 1
+    batch_size = len(prompt_lengths)
+
+    range_tensor = torch.arange(max_length, device=device).expand(batch_size, max_length)
+    prompt_lengths_tensor = torch.tensor(prompt_lengths, device=device).unsqueeze(1)
+    attention_mask = range_tensor >= max_length - prompt_lengths_tensor
+    attention_mask = attention_mask.to(torch.int64)
+
     return attention_mask

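The refactored helper computes the same left-padded mask with tensor ops and allocates it on the caller's device instead of looping on the CPU. A minimal, standalone sketch of the behaviour (the example lengths and the "cpu" device are illustrative):

    import torch

    def make_attention_mask(prompt_lengths, device):
        # Positions >= (max_length - length) are real tokens; earlier positions are left padding.
        max_length = max(prompt_lengths)
        batch_size = len(prompt_lengths)
        range_tensor = torch.arange(max_length, device=device).expand(batch_size, max_length)
        prompt_lengths_tensor = torch.tensor(prompt_lengths, device=device).unsqueeze(1)
        return (range_tensor >= max_length - prompt_lengths_tensor).to(torch.int64)

    print(make_attention_mask([2, 4], "cpu"))
    # tensor([[0, 0, 1, 1],
    #         [1, 1, 1, 1]])
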
@@ -501,6 +505,8 @@ class PPModelWorker:
         self.print_len = {}
         self.is_finish = {}
         self.model_name = checkpoint
+
+        self.device = f"xpu:{self.rank}"
         # self.layer_start = 0
         # self.layer_end = 0

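Each worker now stores its own accelerator as a device string derived from its pipeline rank. A small sketch of the convention, assuming one XPU device per rank (the rank value is illustrative):

    import torch

    rank = 1                    # illustrative pipeline-parallel rank
    device = f"xpu:{rank}"      # same convention as self.device above
    # Stage tensors can then be placed explicitly, e.g. (requires an XPU-enabled PyTorch/IPEX build):
    # hidden_states = torch.empty(1, 4096, device=device)
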
@@ -543,7 +549,7 @@ class PPModelWorker:
         return cur_batch

     def cat_kv_cache(self, model_type, kv_cache_1, kv_cache_2):
-        if model_type in ["baichuan", "chatglm"]:
+        if model_type in ["baichuan", "chatglm", "mixtral"]:
             result = []
             for sub_tuple1, sub_tuple2 in zip(kv_cache_1, kv_cache_2):
                 if sub_tuple1 is None:

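Adding "mixtral" routes that model through the tuple-style KV-cache concatenation branch. A rough sketch of what concatenating two legacy (key, value) caches along the batch dimension looks like; the helper and shapes below are illustrative assumptions, not the repository's implementation:

    import torch

    def cat_tuple_kv_cache(kv_cache_1, kv_cache_2):
        # Illustrative only: each cache is a tuple of per-layer (key, value) tensors
        # shaped (batch, num_heads, seq_len, head_dim); the two batches are merged along dim 0.
        result = []
        for (k1, v1), (k2, v2) in zip(kv_cache_1, kv_cache_2):
            result.append((torch.cat([k1, k2], dim=0), torch.cat([v1, v2], dim=0)))
        return tuple(result)

    cache_a = ((torch.zeros(1, 2, 3, 4), torch.zeros(1, 2, 3, 4)),)
    cache_b = ((torch.ones(2, 2, 3, 4), torch.ones(2, 2, 3, 4)),)
    merged = cat_tuple_kv_cache(cache_a, cache_b)
    print(merged[0][0].shape)  # torch.Size([3, 2, 3, 4])
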
@@ -613,7 +619,7 @@ class PPModelWorker:
         # logger.info(f"{self.rank} {cur_batch} {input.shape}")
         cur_id = cur_batch.batch_id
         _past_key_values = self.past_key_values_dict.get(cur_id, None)
-        attention_mask = make_attention_mask(cur_batch.prompt_lengths).to(input.device)
+        attention_mask = make_attention_mask(cur_batch.prompt_lengths, input.device)

         if self.rank == 0:
             input_ids = input

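With the extra argument, the mask is created on the same device as the hidden states instead of being built on the CPU and then copied over. A minimal sketch of the difference in allocation pattern (the torch.ones call stands in for the real mask construction, and the tensors here are illustrative):

    import torch

    prompt_lengths = [3, 5]        # illustrative
    input = torch.zeros(2, 5)      # stand-in for the batch tensor; on a worker this lives on an XPU

    # Old pattern: allocate on the default (CPU) device, then copy to input.device.
    mask_old = torch.ones(len(prompt_lengths), max(prompt_lengths), dtype=torch.int64).to(input.device)

    # New pattern: allocate on input.device directly, so no separate host-to-device copy is issued.
    mask_new = torch.ones(len(prompt_lengths), max(prompt_lengths), dtype=torch.int64,
                          device=input.device)

    assert torch.equal(mask_old, mask_new)
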
@@ -637,7 +643,7 @@ class PPModelWorker:
             tmp_past_key_values = _past_key_values
             _past_key_values = None

-        # torch.xpu.empty_cache()
+        torch.xpu.empty_cache()
         output = self.model(input_ids=input_ids,
                             inputs_embeds=inputs_embeds,
                             past_key_values=_past_key_values,

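torch.xpu.empty_cache() asks the XPU caching allocator to release unused cached blocks back to the device before the forward pass, trading a little latency for lower peak memory. A minimal sketch of calling it defensively on machines that may not have an XPU build (the guard is an illustrative assumption, not part of this change):

    import torch

    def maybe_empty_xpu_cache():
        # Only touch the XPU allocator when an XPU-enabled build and device are present.
        if hasattr(torch, "xpu") and torch.xpu.is_available():
            torch.xpu.empty_cache()

    maybe_empty_xpu_cache()
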
@@ -661,7 +667,7 @@ class PPModelWorker:

         if self.pp_config.is_tail:
             _pre_output = self.partial_output_dict.get(cur_id, None)
-            tmp_output = output.logits.to(self.dtype)
+            tmp_output = output.logits
             tmp_output = torch.argmax(tmp_output[:, -1:, :], dim=-1)
             if _pre_output is None:
                 _pre_output = tmp_output

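Dropping .to(self.dtype) avoids casting the full logits tensor: torch.argmax returns int64 token indices regardless of the logits dtype, so the cast bought nothing. A small sketch illustrating this (shapes and dtypes are illustrative):

    import torch

    logits = torch.randn(2, 7, 32000, dtype=torch.float32)  # (batch, seq, vocab), illustrative
    token_ids = torch.argmax(logits[:, -1:, :], dim=-1)     # greedy pick on the last position
    print(token_ids.dtype)   # torch.int64, independent of the logits dtype
    print(token_ids.shape)   # torch.Size([2, 1])
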
@@ -674,7 +680,10 @@ class PPModelWorker:
             self.past_key_values_dict[cur_id] = _past_key_values
         torch.xpu.synchronize()
         if not self.pp_config.is_tail:
-            return output[0].to(self.dtype), cur_batch
+            _output = output[0]
+            if _output.dtype != self.dtype:
+                _output = _output.to(self.dtype)
+            return _output, cur_batch
         else:
             if cur_batch.partial_prefilling > 0 and \
                     cur_batch.prefilled_index == cur_batch.batch_size:

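Splitting the return into an explicit dtype check means the intermediate hidden states are only converted when the stage actually produced a different dtype; in the common case the guard is a no-op and no cast call is issued at all. A minimal sketch of the pattern (names, shapes, and dtypes are illustrative):

    import torch

    self_dtype = torch.float16                              # stand-in for self.dtype
    hidden = torch.randn(2, 5, 4096, dtype=torch.float16)   # stand-in for output[0]

    # Only convert when the stage's output dtype differs from the configured dtype.
    if hidden.dtype != self_dtype:
        hidden = hidden.to(self_dtype)

    print(hidden.dtype)   # torch.float16; no conversion happened in this case
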