Fix gptj kvcache & position id (#10141)
This commit is contained in:
parent
7400401706
commit
1508d6b089
1 changed file with 4 additions and 4 deletions
|
|
@ -468,7 +468,7 @@ def speculative_generate(self,
|
||||||
position_ids = torch.Tensor([[past_key_value_len + step_draft]]).long()
|
position_ids = torch.Tensor([[past_key_value_len + step_draft]]).long()
|
||||||
forward_args["position_ids"] = position_ids
|
forward_args["position_ids"] = position_ids
|
||||||
elif self.config.model_type == "gptj":
|
elif self.config.model_type == "gptj":
|
||||||
past_length = draft_past_key_values[0][0].size(1)
|
past_length = draft_past_key_values[0][0].size(2)
|
||||||
position_ids = torch.Tensor([[past_length]]).long().to(self.device)
|
position_ids = torch.Tensor([[past_length]]).long().to(self.device)
|
||||||
forward_args["position_ids"] = position_ids
|
forward_args["position_ids"] = position_ids
|
||||||
draft_output = draft_model(**forward_args)
|
draft_output = draft_model(**forward_args)
|
||||||
|
|
@ -563,7 +563,7 @@ def speculative_generate(self,
|
||||||
position_ids = position_ids.unsqueeze(0).repeat(1, 1) + past_key_value_len
|
position_ids = position_ids.unsqueeze(0).repeat(1, 1) + past_key_value_len
|
||||||
forward_args["position_ids"] = position_ids
|
forward_args["position_ids"] = position_ids
|
||||||
elif self.config.model_type == "gptj":
|
elif self.config.model_type == "gptj":
|
||||||
past_length = past_key_values[0][0].size(1)
|
past_length = past_key_values[0][0].size(2)
|
||||||
input_len = drafted_input_ids.shape[1]
|
input_len = drafted_input_ids.shape[1]
|
||||||
position_ids = torch.arange(past_length, input_len + past_length,
|
position_ids = torch.arange(past_length, input_len + past_length,
|
||||||
dtype=torch.long, device=drafted_input_ids.device)
|
dtype=torch.long, device=drafted_input_ids.device)
|
||||||
|
|
@ -644,7 +644,7 @@ def speculative_generate(self,
|
||||||
past_key_values = [[tmp, key_cache, value_cache, beam_idx]
|
past_key_values = [[tmp, key_cache, value_cache, beam_idx]
|
||||||
for _, key_cache, value_cache, beam_idx in past_key_values]
|
for _, key_cache, value_cache, beam_idx in past_key_values]
|
||||||
else:
|
else:
|
||||||
if self.config.model_type in ["qwen", "gptj"]:
|
if self.config.model_type in ["qwen"]:
|
||||||
past_key_values = [
|
past_key_values = [
|
||||||
(k[:, :-(max_of_max_matched - max_matched), :],
|
(k[:, :-(max_of_max_matched - max_matched), :],
|
||||||
v[:, :-(max_of_max_matched - max_matched), :])
|
v[:, :-(max_of_max_matched - max_matched), :])
|
||||||
|
|
@ -657,7 +657,7 @@ def speculative_generate(self,
|
||||||
v[:-(max_of_max_matched - max_matched), :, :, :])
|
v[:-(max_of_max_matched - max_matched), :, :, :])
|
||||||
for k, v in past_key_values
|
for k, v in past_key_values
|
||||||
]
|
]
|
||||||
elif self.config.model_type == "baichuan":
|
elif self.config.model_type in ["baichuan", "gptj"]:
|
||||||
past_key_values = [
|
past_key_values = [
|
||||||
(k[:, :, :-(max_of_max_matched - max_matched), :],
|
(k[:, :, :-(max_of_max_matched - max_matched), :],
|
||||||
v[:, :, :-(max_of_max_matched - max_matched), :])
|
v[:, :, :-(max_of_max_matched - max_matched), :])
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue