Fix gptj kvcache & position id (#10141)
This commit is contained in:
parent
7400401706
commit
1508d6b089
1 changed files with 4 additions and 4 deletions
|
|
@ -468,7 +468,7 @@ def speculative_generate(self,
|
|||
position_ids = torch.Tensor([[past_key_value_len + step_draft]]).long()
|
||||
forward_args["position_ids"] = position_ids
|
||||
elif self.config.model_type == "gptj":
|
||||
past_length = draft_past_key_values[0][0].size(1)
|
||||
past_length = draft_past_key_values[0][0].size(2)
|
||||
position_ids = torch.Tensor([[past_length]]).long().to(self.device)
|
||||
forward_args["position_ids"] = position_ids
|
||||
draft_output = draft_model(**forward_args)
|
||||
|
|
@ -563,7 +563,7 @@ def speculative_generate(self,
|
|||
position_ids = position_ids.unsqueeze(0).repeat(1, 1) + past_key_value_len
|
||||
forward_args["position_ids"] = position_ids
|
||||
elif self.config.model_type == "gptj":
|
||||
past_length = past_key_values[0][0].size(1)
|
||||
past_length = past_key_values[0][0].size(2)
|
||||
input_len = drafted_input_ids.shape[1]
|
||||
position_ids = torch.arange(past_length, input_len + past_length,
|
||||
dtype=torch.long, device=drafted_input_ids.device)
|
||||
|
|
@ -644,7 +644,7 @@ def speculative_generate(self,
|
|||
past_key_values = [[tmp, key_cache, value_cache, beam_idx]
|
||||
for _, key_cache, value_cache, beam_idx in past_key_values]
|
||||
else:
|
||||
if self.config.model_type in ["qwen", "gptj"]:
|
||||
if self.config.model_type in ["qwen"]:
|
||||
past_key_values = [
|
||||
(k[:, :-(max_of_max_matched - max_matched), :],
|
||||
v[:, :-(max_of_max_matched - max_matched), :])
|
||||
|
|
@ -657,7 +657,7 @@ def speculative_generate(self,
|
|||
v[:-(max_of_max_matched - max_matched), :, :, :])
|
||||
for k, v in past_key_values
|
||||
]
|
||||
elif self.config.model_type == "baichuan":
|
||||
elif self.config.model_type in ["baichuan", "gptj"]:
|
||||
past_key_values = [
|
||||
(k[:, :, :-(max_of_max_matched - max_matched), :],
|
||||
v[:, :, :-(max_of_max_matched - max_matched), :])
|
||||
|
|
|
|||
Loading…
Reference in a new issue