Fix update_kv_cache in Pipeline-Parallel-Serving for glm4-9b model (#11537)
This commit is contained in:
parent
fa81dbefd3
commit
a1cede926d
1 changed files with 11 additions and 19 deletions
|
|
@ -509,14 +509,14 @@ class ModelRunner:
|
|||
|
||||
return kv_cache_1
|
||||
|
||||
def update_kv_cache(self, kv_cache, cur_id):
|
||||
def update_kv_cache(self, kv_cache, prefill=False):
|
||||
layer_start = self.model.layer_start
|
||||
layer_end = self.model.layer_end
|
||||
num_layers = self.model.num_layers
|
||||
|
||||
if self.model.config.model_type == "chatglm" and self.model.config.num_layers == 40:
|
||||
# for glm-4-9b-chat
|
||||
if self.past_key_values_dict.get(cur_id, None) is None:
|
||||
if prefill:
|
||||
value_placeholder = torch.empty_like((kv_cache)[-1][0])
|
||||
past_key_values_placeholder = tuple(
|
||||
(value_placeholder, value_placeholder) for _ in range(layer_start)
|
||||
|
|
@ -528,13 +528,10 @@ class ModelRunner:
|
|||
pass
|
||||
elif self.model.config.model_type in ["baichuan", "chatglm"] and self.rank > 0:
|
||||
value_placeholder = torch.empty_like((kv_cache)[-1][0])
|
||||
kv_cache = tuple((value_placeholder, value_placeholder)) + \
|
||||
tuple(None for _ in range(layer_start)) + \
|
||||
(kv_cache)[layer_start:]
|
||||
# past_key_values_placeholder = tuple(
|
||||
# (value_placeholder, value_placeholder) for _ in range(layer_start)
|
||||
# ) + (kv_cache)[layer_start:]
|
||||
# kv_cache = past_key_values_placeholder
|
||||
past_key_values_placeholder = tuple(
|
||||
(value_placeholder, value_placeholder) for _ in range(layer_start)
|
||||
) + (kv_cache)[layer_start:]
|
||||
kv_cache = past_key_values_placeholder
|
||||
else:
|
||||
pass
|
||||
|
||||
|
|
@ -590,7 +587,7 @@ class ModelRunner:
|
|||
# torch.xpu.empty_cache()
|
||||
|
||||
if cur_batch.prefilled_index == cur_batch.batch_size:
|
||||
tmp_past_key_values = self.update_kv_cache(tmp_past_key_values, cur_id)
|
||||
tmp_past_key_values = self.update_kv_cache(tmp_past_key_values, True)
|
||||
|
||||
self.past_key_values_dict[cur_id] = tmp_past_key_values
|
||||
|
||||
|
|
@ -604,7 +601,8 @@ class ModelRunner:
|
|||
_pre_output = torch.cat((_pre_output, tmp_output), dim=0)
|
||||
self.partial_output_dict[cur_id] = _pre_output
|
||||
else:
|
||||
_past_key_values = self.update_kv_cache(output.past_key_values, cur_id)
|
||||
_prefill = self.past_key_values_dict.get(cur_id, None) is None
|
||||
_past_key_values = self.update_kv_cache(output.past_key_values, prefill=_prefill)
|
||||
self.past_key_values_dict[cur_id] = _past_key_values
|
||||
torch.xpu.synchronize()
|
||||
if not self.pp_config.is_tail:
|
||||
|
|
@ -687,7 +685,6 @@ class ModelRunner:
|
|||
|
||||
if (cur_batch is not None) and (not cur_batch.stopped) and (cur_input is None):
|
||||
cur_id = cur_batch.batch_id
|
||||
# cur_batch = self.prepare_batch(cur_batch)
|
||||
if cur_batch.prefilled_index >= cur_batch.batch_size:
|
||||
cur_batch.partial_prefilling = 0
|
||||
if cur_batch.partial_prefilling > 0:
|
||||
|
|
@ -810,14 +807,9 @@ class ModelRunner:
|
|||
dist.recv(cur_input, src=self.pre_rank)
|
||||
|
||||
output, cur_batch = self.model_step(cur_input, cur_batch)
|
||||
# if output is not None and self.rank == self.world_size - 1:
|
||||
# output = torch.argmax(output[:, -1:, :], dim=-1)
|
||||
|
||||
if output is not None:
|
||||
# dist.send(output, dst=self.next_rank)
|
||||
self.send_buff = output
|
||||
else:
|
||||
self.send_buff = None
|
||||
|
||||
if self.rank == 0:
|
||||
self.on_going_batches[:-1] = self.on_going_batches[1:]
|
||||
self.on_going_batches[self.world_size - 1] = cur_batch
|
||||
|
|
|
|||
Loading…
Reference in a new issue