Fix update_kv_cache in Pipeline-Parallel-Serving for glm4-9b model (#11537)
parent fa81dbefd3
commit a1cede926d
1 changed file with 11 additions and 19 deletions
@@ -509,14 +509,14 @@ class ModelRunner:
         return kv_cache_1
 
-    def update_kv_cache(self, kv_cache, cur_id):
+    def update_kv_cache(self, kv_cache, prefill=False):
         layer_start = self.model.layer_start
         layer_end = self.model.layer_end
         num_layers = self.model.num_layers
 
         if self.model.config.model_type == "chatglm" and self.model.config.num_layers == 40:
             # for glm-4-9b-chat
-            if self.past_key_values_dict.get(cur_id, None) is None:
+            if prefill:
                 value_placeholder = torch.empty_like((kv_cache)[-1][0])
                 past_key_values_placeholder = tuple(
                     (value_placeholder, value_placeholder) for _ in range(layer_start)
 
@@ -528,13 +528,10 @@ class ModelRunner:
                 pass
         elif self.model.config.model_type in ["baichuan", "chatglm"] and self.rank > 0:
             value_placeholder = torch.empty_like((kv_cache)[-1][0])
-            kv_cache = tuple((value_placeholder, value_placeholder)) + \
-                tuple(None for _ in range(layer_start)) + \
-                (kv_cache)[layer_start:]
-            # past_key_values_placeholder = tuple(
-            #     (value_placeholder, value_placeholder) for _ in range(layer_start)
-            # ) + (kv_cache)[layer_start:]
-            # kv_cache = past_key_values_placeholder
+            past_key_values_placeholder = tuple(
+                (value_placeholder, value_placeholder) for _ in range(layer_start)
+            ) + (kv_cache)[layer_start:]
+            kv_cache = past_key_values_placeholder
         else:
             pass
 
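The two hunks above rework how update_kv_cache pads the cache for pipeline stages that do not own the leading layers. The snippet below is a standalone sketch, not part of the commit: it uses toy CPU tensors and an arbitrary layer_start to show how the first layer_start entries of the cache tuple are replaced by (placeholder, placeholder) pairs, so that indexing the tuple by absolute layer number still lines up on a later pipeline stage.

import torch

# Toy stand-in for a per-layer KV cache: one (key, value) pair per layer.
num_layers, layer_start = 4, 2
kv_cache = tuple(
    (torch.randn(1, 2, 8), torch.randn(1, 2, 8)) for _ in range(num_layers)
)

# Same shape/dtype/device as a real cache entry; its contents are never read.
value_placeholder = torch.empty_like(kv_cache[-1][0])

# Pad the layers owned by earlier pipeline stages with placeholder pairs and
# keep this stage's real entries from layer_start onward, as the new
# baichuan/chatglm branch does.
past_key_values_placeholder = tuple(
    (value_placeholder, value_placeholder) for _ in range(layer_start)
) + kv_cache[layer_start:]

assert len(past_key_values_placeholder) == num_layers

The old baichuan/chatglm branch mixed a flattened placeholder pair with None entries for the leading layers; the new branch builds one (placeholder, placeholder) pair per leading layer, matching the glm-4-9b prefill path above.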
@@ -590,7 +587,7 @@ class ModelRunner:
                 # torch.xpu.empty_cache()
 
                 if cur_batch.prefilled_index == cur_batch.batch_size:
-                    tmp_past_key_values = self.update_kv_cache(tmp_past_key_values, cur_id)
+                    tmp_past_key_values = self.update_kv_cache(tmp_past_key_values, True)
 
                 self.past_key_values_dict[cur_id] = tmp_past_key_values
 
@@ -604,7 +601,8 @@ class ModelRunner:
                         _pre_output = torch.cat((_pre_output, tmp_output), dim=0)
                     self.partial_output_dict[cur_id] = _pre_output
                 else:
-                    _past_key_values = self.update_kv_cache(output.past_key_values, cur_id)
+                    _prefill = self.past_key_values_dict.get(cur_id, None) is None
+                    _past_key_values = self.update_kv_cache(output.past_key_values, prefill=_prefill)
                     self.past_key_values_dict[cur_id] = _past_key_values
             torch.xpu.synchronize()
             if not self.pp_config.is_tail:
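Together with the signature change in the first hunk, the call sites now decide whether a request is in its prefill step and pass an explicit boolean instead of a batch id. Below is a compact sketch of that calling pattern, with a hypothetical cache_step helper standing in for the surrounding ModelRunner code.

# Hypothetical helper mirroring the new call pattern: the caller checks
# whether a cache already exists for this batch id and passes only a
# boolean, instead of handing cur_id to update_kv_cache as before.
def cache_step(runner, output, cur_id, past_key_values_dict):
    prefill = past_key_values_dict.get(cur_id, None) is None
    past_key_values = runner.update_kv_cache(output.past_key_values, prefill=prefill)
    past_key_values_dict[cur_id] = past_key_values
    return past_key_values

Deriving prefill at the call site keeps update_kv_cache itself free of any dependency on the per-batch bookkeeping dictionary.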
@@ -687,7 +685,6 @@ class ModelRunner:
 
         if (cur_batch is not None) and (not cur_batch.stopped) and (cur_input is None):
             cur_id = cur_batch.batch_id
-            # cur_batch = self.prepare_batch(cur_batch)
             if cur_batch.prefilled_index >= cur_batch.batch_size:
                 cur_batch.partial_prefilling = 0
             if cur_batch.partial_prefilling > 0:
@@ -810,14 +807,9 @@ class ModelRunner:
                 dist.recv(cur_input, src=self.pre_rank)
 
             output, cur_batch = self.model_step(cur_input, cur_batch)
-            # if output is not None and self.rank == self.world_size - 1:
-            #     output = torch.argmax(output[:, -1:, :], dim=-1)
 
-            if output is not None:
-                # dist.send(output, dst=self.next_rank)
-                self.send_buff = output
-            else:
-                self.send_buff = None
+            self.send_buff = output
+
             if self.rank == 0:
                 self.on_going_batches[:-1] = self.on_going_batches[1:]
                 self.on_going_batches[self.world_size - 1] = cur_batch