Fix update_kv_cache in Pipeline-Parallel-Serving for glm4-9b model (#11537)
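In brief, as read from the hunks below: update_kv_cache no longer infers the prefill phase by probing self.past_key_values_dict with a batch id; callers now pass an explicit prefill flag. The old probe misfires for glm-4-9b-chat (the 40-layer chatglm config) once partial prefilling stores intermediate KV state under cur_id before prefill has actually finished.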
parent fa81dbefd3
commit a1cede926d

1 changed file with 11 additions and 19 deletions
@@ -509,14 +509,14 @@ class ModelRunner:
             return kv_cache_1
 
-    def update_kv_cache(self, kv_cache, cur_id):
+    def update_kv_cache(self, kv_cache, prefill=False):
         layer_start = self.model.layer_start
         layer_end = self.model.layer_end
         num_layers = self.model.num_layers
 
         if self.model.config.model_type == "chatglm" and self.model.config.num_layers == 40:
             # for glm-4-9b-chat
-            if self.past_key_values_dict.get(cur_id, None) is None:
+            if prefill:
                 value_placeholder = torch.empty_like((kv_cache)[-1][0])
                 past_key_values_placeholder = tuple(
                     (value_placeholder, value_placeholder) for _ in range(layer_start)
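The signature change above is the heart of the fix: the prefill decision moves from a hidden probe inside the method to an explicit flag at the call sites. A condensed, standalone sketch of the two contracts (stub names and state, not the repo's class):

# Condensed sketch of the contract change; stub state, not the repo's class.
past_key_values_dict = {}

def update_kv_cache_old(kv_cache, cur_id):
    # Old contract: infer the phase from hidden cache state.
    prefill = past_key_values_dict.get(cur_id, None) is None
    return prefill  # True -> build leading (key, value) placeholders

def update_kv_cache_new(kv_cache, prefill=False):
    # New contract: the caller, which tracks batch progress, states the phase.
    return prefill

print(update_kv_cache_old((), "b0"))          # True only while "b0" has no entry
print(update_kv_cache_new((), prefill=True))  # explicit, regardless of cache state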
@@ -528,13 +528,10 @@ class ModelRunner:
                 pass
         elif self.model.config.model_type in ["baichuan", "chatglm"] and self.rank > 0:
             value_placeholder = torch.empty_like((kv_cache)[-1][0])
-            kv_cache = tuple((value_placeholder, value_placeholder)) + \
-                tuple(None for _ in range(layer_start)) + \
-                (kv_cache)[layer_start:]
-            # past_key_values_placeholder = tuple(
-            #     (value_placeholder, value_placeholder) for _ in range(layer_start)
-            # ) + (kv_cache)[layer_start:]
-            # kv_cache = past_key_values_placeholder
+            past_key_values_placeholder = tuple(
+                (value_placeholder, value_placeholder) for _ in range(layer_start)
+            ) + (kv_cache)[layer_start:]
+            kv_cache = past_key_values_placeholder
         else:
             pass
 
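One detail worth calling out in this hunk: the deleted construction was malformed Python as well as misaligned. tuple((a, a)) flattens to the pair itself rather than yielding one (key, value) entry per skipped layer, and the None padding shifted every real layer's index. A tiny illustration with stand-in values (the real code builds the placeholder with torch.empty_like):

# Stand-in values; the real code uses tensors from torch.empty_like.
value_placeholder = "pad"
layer_start = 2
rest = (("k2", "v2"), ("k3", "v3"))  # stands in for kv_cache[layer_start:]

# Old: tuple((a, a)) flattens to (a, a) -- two bare entries, not one
# (key, value) pair per skipped layer -- and the None padding pushes
# every real layer two slots out of position.
old = tuple((value_placeholder, value_placeholder)) \
    + tuple(None for _ in range(layer_start)) \
    + rest
print(len(old))  # 6 entries where layer-indexed code expects 4

# New: exactly one (key, value) placeholder pair per layer this rank skips.
new = tuple((value_placeholder, value_placeholder) for _ in range(layer_start)) + rest
print(len(new))  # 4 entries, one per layer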
@@ -590,7 +587,7 @@ class ModelRunner:
                 # torch.xpu.empty_cache()
 
             if cur_batch.prefilled_index == cur_batch.batch_size:
-                tmp_past_key_values = self.update_kv_cache(tmp_past_key_values, cur_id)
+                tmp_past_key_values = self.update_kv_cache(tmp_past_key_values, True)
 
             self.past_key_values_dict[cur_id] = tmp_past_key_values
 
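Why passing cur_id misfired on this path: the assignment just below the changed line writes past_key_values_dict[cur_id] after every partial-prefill chunk, so by the time prefilled_index reaches batch_size the probe inside the old update_kv_cache already saw an entry and skipped the glm-4-9b placeholder setup. A minimal reproduction of the stale probe (assumed flow, condensed from the hunks):

# Minimal reproduction of the stale probe; assumed flow, not the repo's code.
past_key_values_dict = {}

def is_prefill_old(cur_id):
    return past_key_values_dict.get(cur_id, None) is None

cur_id = "batch-0"
for chunk in range(2):                       # two partial-prefill chunks
    past_key_values_dict[cur_id] = object()  # KV state stored after every chunk

# The final chunk just finished: this is still the prefill step for glm-4-9b,
# but the probe already sees an entry and reports decode.
print(is_prefill_old(cur_id))  # False -> placeholder setup was skipped before the fix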
@@ -604,7 +601,8 @@ class ModelRunner:
                     _pre_output = torch.cat((_pre_output, tmp_output), dim=0)
                 self.partial_output_dict[cur_id] = _pre_output
         else:
-            _past_key_values = self.update_kv_cache(output.past_key_values, cur_id)
+            _prefill = self.past_key_values_dict.get(cur_id, None) is None
+            _past_key_values = self.update_kv_cache(output.past_key_values, prefill=_prefill)
             self.past_key_values_dict[cur_id] = _past_key_values
         torch.xpu.synchronize()
         if not self.pp_config.is_tail:
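On this non-partial path the heuristic is preserved, just hoisted to the call site: the entry for cur_id is written only after update_kv_cache returns, so probing it first still identifies the first step of a batch, while update_kv_cache itself no longer depends on hidden cache state.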
@@ -687,7 +685,6 @@ class ModelRunner:
 
             if (cur_batch is not None) and (not cur_batch.stopped) and (cur_input is None):
                 cur_id = cur_batch.batch_id
-                # cur_batch = self.prepare_batch(cur_batch)
                 if cur_batch.prefilled_index >= cur_batch.batch_size:
                     cur_batch.partial_prefilling = 0
                 if cur_batch.partial_prefilling > 0:
@@ -810,14 +807,9 @@ class ModelRunner:
                     dist.recv(cur_input, src=self.pre_rank)
 
         output, cur_batch = self.model_step(cur_input, cur_batch)
-        # if output is not None and self.rank == self.world_size - 1:
-        #     output = torch.argmax(output[:, -1:, :], dim=-1)
-
-        if output is not None:
-            # dist.send(output, dst=self.next_rank)
-            self.send_buff = output
-        else:
-            self.send_buff = None
+        self.send_buff = output
 
         if self.rank == 0:
             self.on_going_batches[:-1] = self.on_going_batches[1:]
             self.on_going_batches[self.world_size - 1] = cur_batch
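The last hunk is pure dead-code removal: both arms of the deleted branch assign output to self.send_buff (the else arm assigns None exactly when output is None), so the single assignment is behavior-preserving, and the commented-out argmax and dist.send lines were leftovers.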