Fix gptj kvcache & position id (#10141)

This commit is contained in:
Yina Chen 2024-02-18 10:02:49 +08:00 committed by GitHub
parent 7400401706
commit 1508d6b089

View file

@@ -468,7 +468,7 @@ def speculative_generate(self,
position_ids = torch.Tensor([[past_key_value_len + step_draft]]).long()
forward_args["position_ids"] = position_ids
elif self.config.model_type == "gptj":
-                past_length = draft_past_key_values[0][0].size(1)
+                past_length = draft_past_key_values[0][0].size(2)
position_ids = torch.Tensor([[past_length]]).long().to(self.device)
forward_args["position_ids"] = position_ids
draft_output = draft_model(**forward_args)
@@ -563,7 +563,7 @@ def speculative_generate(self,
position_ids = position_ids.unsqueeze(0).repeat(1, 1) + past_key_value_len
forward_args["position_ids"] = position_ids
elif self.config.model_type == "gptj":
-            past_length = past_key_values[0][0].size(1)
+            past_length = past_key_values[0][0].size(2)
input_len = drafted_input_ids.shape[1]
position_ids = torch.arange(past_length, input_len + past_length,
dtype=torch.long, device=drafted_input_ids.device)
@@ -644,7 +644,7 @@ def speculative_generate(self,
past_key_values = [[tmp, key_cache, value_cache, beam_idx]
for _, key_cache, value_cache, beam_idx in past_key_values]
else:
-                if self.config.model_type in ["qwen", "gptj"]:
+                if self.config.model_type in ["qwen"]:
past_key_values = [
(k[:, :-(max_of_max_matched - max_matched), :],
v[:, :-(max_of_max_matched - max_matched), :])
@@ -657,7 +657,7 @@ def speculative_generate(self,
v[:-(max_of_max_matched - max_matched), :, :, :])
for k, v in past_key_values
]
-                elif self.config.model_type == "baichuan":
+                elif self.config.model_type in ["baichuan", "gptj"]:
past_key_values = [
(k[:, :, :-(max_of_max_matched - max_matched), :],
v[:, :, :-(max_of_max_matched - max_matched), :])