diff --git a/python/llm/src/ipex_llm/transformers/pipeline_parallel.py b/python/llm/src/ipex_llm/transformers/pipeline_parallel.py
index f292a80c..de5e5f56 100644
--- a/python/llm/src/ipex_llm/transformers/pipeline_parallel.py
+++ b/python/llm/src/ipex_llm/transformers/pipeline_parallel.py
@@ -509,14 +509,14 @@ class ModelRunner:

         return kv_cache_1

-    def update_kv_cache(self, kv_cache, cur_id):
+    def update_kv_cache(self, kv_cache, prefill=False):
         layer_start = self.model.layer_start
         layer_end = self.model.layer_end
         num_layers = self.model.num_layers

         if self.model.config.model_type == "chatglm" and self.model.config.num_layers == 40:
             # for glm-4-9b-chat
-            if self.past_key_values_dict.get(cur_id, None) is None:
+            if prefill:
                 value_placeholder = torch.empty_like((kv_cache)[-1][0])
                 past_key_values_placeholder = tuple(
                     (value_placeholder, value_placeholder) for _ in range(layer_start)
@@ -528,13 +528,10 @@
                 pass
         elif self.model.config.model_type in ["baichuan", "chatglm"] and self.rank > 0:
             value_placeholder = torch.empty_like((kv_cache)[-1][0])
-            kv_cache = tuple((value_placeholder, value_placeholder)) + \
-                tuple(None for _ in range(layer_start)) + \
-                (kv_cache)[layer_start:]
-            # past_key_values_placeholder = tuple(
-            #     (value_placeholder, value_placeholder) for _ in range(layer_start)
-            # ) + (kv_cache)[layer_start:]
-            # kv_cache = past_key_values_placeholder
+            past_key_values_placeholder = tuple(
+                (value_placeholder, value_placeholder) for _ in range(layer_start)
+            ) + (kv_cache)[layer_start:]
+            kv_cache = past_key_values_placeholder
         else:
             pass

@@ -590,7 +587,7 @@
             # torch.xpu.empty_cache()

             if cur_batch.prefilled_index == cur_batch.batch_size:
-                tmp_past_key_values = self.update_kv_cache(tmp_past_key_values, cur_id)
+                tmp_past_key_values = self.update_kv_cache(tmp_past_key_values, True)

             self.past_key_values_dict[cur_id] = tmp_past_key_values

@@ -604,7 +601,8 @@
                 _pre_output = torch.cat((_pre_output, tmp_output), dim=0)
                 self.partial_output_dict[cur_id] = _pre_output
             else:
-                _past_key_values = self.update_kv_cache(output.past_key_values, cur_id)
+                _prefill = self.past_key_values_dict.get(cur_id, None) is None
+                _past_key_values = self.update_kv_cache(output.past_key_values, prefill=_prefill)
                 self.past_key_values_dict[cur_id] = _past_key_values
             torch.xpu.synchronize()
             if not self.pp_config.is_tail:
@@ -687,7 +685,6 @@

         if (cur_batch is not None) and (not cur_batch.stopped) and (cur_input is None):
             cur_id = cur_batch.batch_id
-            # cur_batch = self.prepare_batch(cur_batch)
             if cur_batch.prefilled_index >= cur_batch.batch_size:
                 cur_batch.partial_prefilling = 0
             if cur_batch.partial_prefilling > 0:
@@ -810,14 +807,9 @@
                 dist.recv(cur_input, src=self.pre_rank)

             output, cur_batch = self.model_step(cur_input, cur_batch)
-            # if output is not None and self.rank == self.world_size - 1:
-            #     output = torch.argmax(output[:, -1:, :], dim=-1)
-            if output is not None:
-                # dist.send(output, dst=self.next_rank)
-                self.send_buff = output
-            else:
-                self.send_buff = None
+            self.send_buff = output
+
             if self.rank == 0:
                 self.on_going_batches[:-1] = self.on_going_batches[1:]
                 self.on_going_batches[self.world_size - 1] = cur_batch
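
Taken together, the hunks above replace the `cur_id`-based dictionary lookup inside `update_kv_cache` with an explicit `prefill` flag decided by the caller, and fix the `baichuan`/`chatglm` rank > 0 branch to pad the layers before `layer_start` with real `(placeholder, placeholder)` pairs instead of `None` entries. The sketch below is a minimal standalone illustration of those two patterns only; `pad_kv_cache`, the toy tensor shapes, and the module-level dict are hypothetical names, not the ipex_llm API.

```python
# Minimal sketch (hypothetical names, not the ipex_llm API) of the two
# patterns in this patch: a caller-side prefill decision and placeholder
# padding for layers owned by earlier pipeline ranks.
import torch


def pad_kv_cache(kv_cache, layer_start):
    # kv_cache: tuple of (key, value) tensor pairs, one per model layer.
    # Ranks > 0 hold real caches only for layers >= layer_start; pad the
    # earlier slots with empty placeholder pairs so downstream code can
    # index the tuple uniformly by absolute layer id.
    value_placeholder = torch.empty_like(kv_cache[-1][0])
    return tuple(
        (value_placeholder, value_placeholder) for _ in range(layer_start)
    ) + kv_cache[layer_start:]


past_key_values_dict = {}
cur_id = 0

# Caller-side prefill decision, mirroring the `_prefill` computation added
# in the hunk at line 601: the first step for a batch is the one where no
# cache has been stored under its id yet.
prefill = past_key_values_dict.get(cur_id, None) is None

# Toy 4-layer cache for a rank that owns layers 2-3 of a 2-stage pipeline.
kv = tuple((torch.zeros(1, 2, 8), torch.zeros(1, 2, 8)) for _ in range(4))
past_key_values_dict[cur_id] = pad_kv_cache(kv, layer_start=2)
assert prefill and len(past_key_values_dict[cur_id]) == 4
```

Padding with real tensor pairs rather than `None` presumably matters because consumers of a HuggingFace-style `past_key_values` tuple may index or inspect every layer's entry, which a `None` slot would break.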