Fix update_kv_cache in Pipeline-Parallel-Serving for glm4-9b model (#11537)

This commit is contained in:
Xiangyu Tian 2024-07-09 14:08:04 +08:00 committed by GitHub
parent fa81dbefd3
commit a1cede926d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -509,14 +509,14 @@ class ModelRunner:
return kv_cache_1
def update_kv_cache(self, kv_cache, cur_id):
def update_kv_cache(self, kv_cache, prefill=False):
layer_start = self.model.layer_start
layer_end = self.model.layer_end
num_layers = self.model.num_layers
if self.model.config.model_type == "chatglm" and self.model.config.num_layers == 40:
# for glm-4-9b-chat
if self.past_key_values_dict.get(cur_id, None) is None:
if prefill:
value_placeholder = torch.empty_like((kv_cache)[-1][0])
past_key_values_placeholder = tuple(
(value_placeholder, value_placeholder) for _ in range(layer_start)
@ -528,13 +528,10 @@ class ModelRunner:
pass
elif self.model.config.model_type in ["baichuan", "chatglm"] and self.rank > 0:
value_placeholder = torch.empty_like((kv_cache)[-1][0])
kv_cache = tuple((value_placeholder, value_placeholder)) + \
tuple(None for _ in range(layer_start)) + \
(kv_cache)[layer_start:]
# past_key_values_placeholder = tuple(
# (value_placeholder, value_placeholder) for _ in range(layer_start)
# ) + (kv_cache)[layer_start:]
# kv_cache = past_key_values_placeholder
past_key_values_placeholder = tuple(
(value_placeholder, value_placeholder) for _ in range(layer_start)
) + (kv_cache)[layer_start:]
kv_cache = past_key_values_placeholder
else:
pass
@ -590,7 +587,7 @@ class ModelRunner:
# torch.xpu.empty_cache()
if cur_batch.prefilled_index == cur_batch.batch_size:
tmp_past_key_values = self.update_kv_cache(tmp_past_key_values, cur_id)
tmp_past_key_values = self.update_kv_cache(tmp_past_key_values, True)
self.past_key_values_dict[cur_id] = tmp_past_key_values
@ -604,7 +601,8 @@ class ModelRunner:
_pre_output = torch.cat((_pre_output, tmp_output), dim=0)
self.partial_output_dict[cur_id] = _pre_output
else:
_past_key_values = self.update_kv_cache(output.past_key_values, cur_id)
_prefill = self.past_key_values_dict.get(cur_id, None) is None
_past_key_values = self.update_kv_cache(output.past_key_values, prefill=_prefill)
self.past_key_values_dict[cur_id] = _past_key_values
torch.xpu.synchronize()
if not self.pp_config.is_tail:
@ -687,7 +685,6 @@ class ModelRunner:
if (cur_batch is not None) and (not cur_batch.stopped) and (cur_input is None):
cur_id = cur_batch.batch_id
# cur_batch = self.prepare_batch(cur_batch)
if cur_batch.prefilled_index >= cur_batch.batch_size:
cur_batch.partial_prefilling = 0
if cur_batch.partial_prefilling > 0:
@ -810,14 +807,9 @@ class ModelRunner:
dist.recv(cur_input, src=self.pre_rank)
output, cur_batch = self.model_step(cur_input, cur_batch)
# if output is not None and self.rank == self.world_size - 1:
# output = torch.argmax(output[:, -1:, :], dim=-1)
if output is not None:
# dist.send(output, dst=self.next_rank)
self.send_buff = output
else:
self.send_buff = None
if self.rank == 0:
self.on_going_batches[:-1] = self.on_going_batches[1:]
self.on_going_batches[self.world_size - 1] = cur_batch