LLM: Refine Pipeline Parallel FastAPI (#11587)

Refine Pipeline Parallel FastAPI
This commit is contained in:
Xiangyu Tian 2024-07-22 15:52:05 +08:00 committed by GitHub
parent 4d56ef5646
commit 060792a648
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -462,11 +462,15 @@ class BatchTask(BaseModel):
partial_prefilling: int
def make_attention_mask(prompt_lengths):
def make_attention_mask(prompt_lengths, device):
max_length = max(prompt_lengths)
attention_mask = torch.zeros((len(prompt_lengths), max_length), dtype=torch.int64)
for i, length in enumerate(prompt_lengths):
attention_mask[i, max_length - length:] = 1
batch_size = len(prompt_lengths)
range_tensor = torch.arange(max_length, device=device).expand(batch_size, max_length)
prompt_lengths_tensor = torch.tensor(prompt_lengths, device=device).unsqueeze(1)
attention_mask = range_tensor >= max_length - prompt_lengths_tensor
attention_mask = attention_mask.to(torch.int64)
return attention_mask
@ -501,6 +505,8 @@ class PPModelWorker:
self.print_len = {}
self.is_finish = {}
self.model_name = checkpoint
self.device = f"xpu:{self.rank}"
# self.layer_start = 0
# self.layer_end = 0
@ -543,7 +549,7 @@ class PPModelWorker:
return cur_batch
def cat_kv_cache(self, model_type, kv_cache_1, kv_cache_2):
if model_type in ["baichuan", "chatglm"]:
if model_type in ["baichuan", "chatglm", "mixtral"]:
result = []
for sub_tuple1, sub_tuple2 in zip(kv_cache_1, kv_cache_2):
if sub_tuple1 is None:
@ -613,7 +619,7 @@ class PPModelWorker:
# logger.info(f"{self.rank} {cur_batch} {input.shape}")
cur_id = cur_batch.batch_id
_past_key_values = self.past_key_values_dict.get(cur_id, None)
attention_mask = make_attention_mask(cur_batch.prompt_lengths).to(input.device)
attention_mask = make_attention_mask(cur_batch.prompt_lengths, input.device)
if self.rank == 0:
input_ids = input
@ -637,7 +643,7 @@ class PPModelWorker:
tmp_past_key_values = _past_key_values
_past_key_values = None
# torch.xpu.empty_cache()
torch.xpu.empty_cache()
output = self.model(input_ids=input_ids,
inputs_embeds=inputs_embeds,
past_key_values=_past_key_values,
@ -661,7 +667,7 @@ class PPModelWorker:
if self.pp_config.is_tail:
_pre_output = self.partial_output_dict.get(cur_id, None)
tmp_output = output.logits.to(self.dtype)
tmp_output = output.logits
tmp_output = torch.argmax(tmp_output[:, -1:, :], dim=-1)
if _pre_output is None:
_pre_output = tmp_output
@ -674,7 +680,10 @@ class PPModelWorker:
self.past_key_values_dict[cur_id] = _past_key_values
torch.xpu.synchronize()
if not self.pp_config.is_tail:
return output[0].to(self.dtype), cur_batch
_output = output[0]
if _output.dtype != self.dtype:
_output = _output.to(self.dtype)
return _output, cur_batch
else:
if cur_batch.partial_prefilling > 0 and \
cur_batch.prefilled_index == cur_batch.batch_size: