LLM: Refine Pipeline Parallel FastAPI (#11587)
parent 4d56ef5646
commit 060792a648

1 changed file with 18 additions and 9 deletions
@@ -462,11 +462,15 @@ class BatchTask(BaseModel):
     partial_prefilling: int
 
 
-def make_attention_mask(prompt_lengths):
+def make_attention_mask(prompt_lengths, device):
     max_length = max(prompt_lengths)
-    attention_mask = torch.zeros((len(prompt_lengths), max_length), dtype=torch.int64)
-    for i, length in enumerate(prompt_lengths):
-        attention_mask[i, max_length - length:] = 1
+    batch_size = len(prompt_lengths)
+
+    range_tensor = torch.arange(max_length, device=device).expand(batch_size, max_length)
+    prompt_lengths_tensor = torch.tensor(prompt_lengths, device=device).unsqueeze(1)
+    attention_mask = range_tensor >= max_length - prompt_lengths_tensor
+    attention_mask = attention_mask.to(torch.int64)
+
     return attention_mask
 
 
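Note: the rewritten helper builds the same left-padded mask as the old loop, but with a single broadcast comparison and directly on the target device. A minimal standalone sketch (hypothetical function names, run on CPU here although the commit targets XPU) that checks the two constructions agree:

    import torch

    def mask_loop(prompt_lengths):
        # old approach: fill ones into the right-aligned (left-padded) positions row by row
        max_length = max(prompt_lengths)
        mask = torch.zeros((len(prompt_lengths), max_length), dtype=torch.int64)
        for i, length in enumerate(prompt_lengths):
            mask[i, max_length - length:] = 1
        return mask

    def mask_vectorized(prompt_lengths, device):
        # new approach: one broadcast comparison against the padding offset
        max_length = max(prompt_lengths)
        batch_size = len(prompt_lengths)
        range_tensor = torch.arange(max_length, device=device).expand(batch_size, max_length)
        lengths = torch.tensor(prompt_lengths, device=device).unsqueeze(1)
        return (range_tensor >= max_length - lengths).to(torch.int64)

    lengths = [3, 5, 1]
    assert torch.equal(mask_loop(lengths), mask_vectorized(lengths, "cpu"))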
@@ -501,6 +505,8 @@ class PPModelWorker:
         self.print_len = {}
         self.is_finish = {}
         self.model_name = checkpoint
+
+        self.device = f"xpu:{self.rank}"
         # self.layer_start = 0
         # self.layer_end = 0
 
@@ -543,7 +549,7 @@ class PPModelWorker:
         return cur_batch
 
     def cat_kv_cache(self, model_type, kv_cache_1, kv_cache_2):
-        if model_type in ["baichuan", "chatglm"]:
+        if model_type in ["baichuan", "chatglm", "mixtral"]:
             result = []
             for sub_tuple1, sub_tuple2 in zip(kv_cache_1, kv_cache_2):
                 if sub_tuple1 is None:
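Note: adding "mixtral" here routes Mixtral through the same branch of cat_kv_cache as Baichuan and ChatGLM. As a rough sketch only (assuming the legacy tuple-of-(key, value)-per-layer cache layout and concatenation along the batch dimension; the real method also handles None entries and other model types):

    import torch

    def cat_legacy_kv_cache(kv_cache_1, kv_cache_2, dim=0):
        # each cache: tuple over layers of (key, value) tensors shaped
        # (batch, num_heads, seq_len, head_dim); merge them layer by layer
        if kv_cache_1 is None:
            return kv_cache_2
        result = []
        for (k1, v1), (k2, v2) in zip(kv_cache_1, kv_cache_2):
            result.append((torch.cat([k1, k2], dim=dim),
                           torch.cat([v1, v2], dim=dim)))
        return tuple(result)

    # toy usage: one layer, two single-sequence caches merged into a batch of two
    layer = (torch.zeros(1, 2, 4, 8), torch.zeros(1, 2, 4, 8))
    merged = cat_legacy_kv_cache((layer,), (layer,))
    print(merged[0][0].shape)  # torch.Size([2, 2, 4, 8])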
@@ -613,7 +619,7 @@ class PPModelWorker:
         # logger.info(f"{self.rank} {cur_batch} {input.shape}")
         cur_id = cur_batch.batch_id
         _past_key_values = self.past_key_values_dict.get(cur_id, None)
-        attention_mask = make_attention_mask(cur_batch.prompt_lengths).to(input.device)
+        attention_mask = make_attention_mask(cur_batch.prompt_lengths, input.device)
 
         if self.rank == 0:
             input_ids = input
@@ -637,7 +643,7 @@ class PPModelWorker:
                 tmp_past_key_values = _past_key_values
                 _past_key_values = None
 
-        # torch.xpu.empty_cache()
+        torch.xpu.empty_cache()
         output = self.model(input_ids=input_ids,
                             inputs_embeds=inputs_embeds,
                             past_key_values=_past_key_values,
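Note: un-commenting torch.xpu.empty_cache() releases blocks held by the XPU caching allocator before each model call, which can lower reserved device memory at the cost of some re-allocation overhead. A minimal guarded usage sketch (the guard only exists so the snippet also runs on machines without an XPU):

    import torch

    if hasattr(torch, "xpu") and torch.xpu.is_available():
        # hand cached, currently unused device memory back to the allocator
        torch.xpu.empty_cache()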
@@ -661,7 +667,7 @@ class PPModelWorker:
 
             if self.pp_config.is_tail:
                 _pre_output = self.partial_output_dict.get(cur_id, None)
-                tmp_output = output.logits.to(self.dtype)
+                tmp_output = output.logits
                 tmp_output = torch.argmax(tmp_output[:, -1:, :], dim=-1)
                 if _pre_output is None:
                     _pre_output = tmp_output
@@ -674,7 +680,10 @@ class PPModelWorker:
             self.past_key_values_dict[cur_id] = _past_key_values
         torch.xpu.synchronize()
         if not self.pp_config.is_tail:
-            return output[0].to(self.dtype), cur_batch
+            _output = output[0]
+            if _output.dtype != self.dtype:
+                _output = _output.to(self.dtype)
+            return _output, cur_batch
         else:
             if cur_batch.partial_prefilling > 0 and \
                cur_batch.prefilled_index == cur_batch.batch_size:
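Note: the last two hunks both trim dtype conversions: the logits go straight into argmax without a cast, and the hidden states handed to the next pipeline stage are cast to self.dtype only when their dtype actually differs. A small standalone sketch of that conditional-cast pattern (made-up shapes, for illustration only):

    import torch

    target_dtype = torch.float16
    hidden_states = torch.randn(2, 8, 4096, dtype=torch.float16)  # already fp16

    # cast only when the dtypes differ; otherwise skip the call entirely
    if hidden_states.dtype != target_dtype:
        hidden_states = hidden_states.to(target_dtype)

    # argmax does not need the logits cast first: the result is int64 token ids either way
    logits = torch.randn(2, 1, 32000)
    next_ids = torch.argmax(logits[:, -1:, :], dim=-1)
    print(next_ids.dtype)  # torch.int64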