parent 375174af33
commit b7948671de
2 changed files with 48 additions and 18 deletions
@@ -99,6 +99,10 @@ def generate(
 GenerationMixin.generate = generate
 
 
+def tensor2key(key_tensor: torch.LongTensor):
+    return tuple(key_tensor.tolist())
+
+
 # This class is copied from https://github.com/huggingface/transformers/blob/main/src
 # /transformers/generation/candidate_generator.py
 class PromptLookupCandidateGenerator():
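
Note: tensor2key exists because torch tensors are not hashable by value, so an n-gram tensor has to become a plain tuple before it can index a dict. A minimal sketch of the behavior (token values invented for illustration):

import torch

def tensor2key(key_tensor: torch.LongTensor):
    return tuple(key_tensor.tolist())

key = tensor2key(torch.tensor([15043, 29892, 3186]))  # -> (15043, 29892, 3186)
table = {key: 7}  # usable as a dict key, unlike the tensor itself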
@@ -133,9 +137,34 @@ class PromptLookupCandidateGenerator():
         self.max_candidates = 9
         self.min_candidates = 0
+        self.lookup_table = {}
         invalidInputError(self.max_matching_ngram_size > 0 and self.num_output_tokens > 0,
                           "Invalid max_matching_ngram_size or num_output_tokens")
 
+    def init_look_up_table(self,
+                           input_ids: torch.LongTensor):
+        for ngram_size in range(self.max_matching_ngram_size, 0, -1):
+            # Create sliding windows of size ngram_size
+            windows = input_ids.unfold(dimension=1, size=ngram_size, step=1)
+            for idx in range(windows.size(1)):
+                window = tensor2key(windows[0, idx])
+                if window not in self.lookup_table:
+                    self.lookup_table[window] = idx
+
+    def update_look_up_table(self,
+                             new_input_ids: torch.LongTensor):
+        # Maintain a look up table
+        window = tensor2key(new_input_ids[0, -self.max_matching_ngram_size:])
+        for ngram_size in range(self.max_matching_ngram_size):
+            if window[ngram_size:] not in self.lookup_table:
+                self.lookup_table[window[ngram_size:]] = \
+                    new_input_ids.size(1)-self.max_matching_ngram_size+ngram_size
+
+    def get_n_gram_idx(self,
+                       ngram_tensor: torch.LongTensor):
+        key = tensor2key(ngram_tensor)
+        return self.lookup_table[key]
+
     def get_candidates(self,
                        input_ids: torch.LongTensor)-> Tuple[torch.LongTensor,
                                                             Optional[torch.FloatTensor]]:
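
Note: together these three methods replace a per-call sliding-window scan with a dict mapping each n-gram (as a tuple) to the position of its first occurrence. A self-contained toy run of the same indexing logic, with invented token ids:

import torch

max_matching_ngram_size = 3
lookup_table = {}

input_ids = torch.tensor([[1, 2, 3, 1, 2, 3, 4]])
for ngram_size in range(max_matching_ngram_size, 0, -1):
    # Sliding windows of size ngram_size over the sequence
    windows = input_ids.unfold(dimension=1, size=ngram_size, step=1)
    for idx in range(windows.size(1)):
        window = tuple(windows[0, idx].tolist())
        lookup_table.setdefault(window, idx)  # keep the first occurrence only

print(lookup_table[(1, 2, 3)])  # 0: the trigram first appears at position 0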
@@ -156,31 +185,20 @@ class PromptLookupCandidateGenerator():
         input_length = input_ids.size(1)
 
         chosen_ids = None
-        match_found = False
         for ngram_size in range(min(self.max_matching_ngram_size, input_length - 1), 0, -1):
-            # Create sliding windows of size ngram_size
-            windows = input_ids.unfold(dimension=1, size=ngram_size, step=1)
-
             # Convert ngram to a tensor for comparison
             ngram_tensor = input_ids[0, -ngram_size:]
 
-            # Find where the windows match the ngram
-            matches = (windows == ngram_tensor).all(dim=2)
-
-            # Get the indices of matches
-            match_indices = matches.nonzero(as_tuple=True)[1]
+            # # Get the indices of matches
+            idx = self.get_n_gram_idx(ngram_tensor)
 
             # Iterate through match indices to find a valid continuation
-            for idx in match_indices:
-                start_idx = idx + ngram_size
-                end_idx = start_idx + self.num_output_tokens
-                end_idx = min(end_idx, input_length)
+            start_idx = idx + ngram_size
+            end_idx = start_idx + self.num_output_tokens
+            end_idx = min(end_idx, input_length)
 
-                if start_idx < end_idx:
-                    chosen_ids = input_ids[0, start_idx:end_idx]
-                    match_found = True
-                    break
-            if match_found:
+            if start_idx < end_idx:
+                chosen_ids = input_ids[0, start_idx:end_idx]
                 break
 
         if chosen_ids is None or len(chosen_ids) == 0:
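
Note: the rewrite turns candidate search from a scan over every window of the sequence into one dict probe per n-gram size, i.e. roughly O(n) tensor comparisons per step down to O(max_matching_ngram_size) lookups. The probe should always hit, since the trailing n-gram of the sequence is itself one of the indexed windows. A toy trace of the new path (table contents and token ids invented):

import torch

lookup_table = {(1, 2): 0, (2, 3): 1}  # hypothetical prefix table
input_ids = torch.tensor([[1, 2, 3, 1, 2]])
num_output_tokens = 3
input_length = input_ids.size(1)

ngram_size = 2
ngram_tensor = input_ids[0, -ngram_size:]           # tensor([1, 2])
idx = lookup_table[tuple(ngram_tensor.tolist())]    # 0
start_idx = idx + ngram_size                        # 2
end_idx = min(start_idx + num_output_tokens, input_length)  # 5
chosen_ids = input_ids[0, start_idx:end_idx]        # tensor([3, 1, 2])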
@@ -267,6 +285,9 @@ def lookup_generate(self,
         else:
             output_ids = greedy(logits)
         input_ids = torch.cat((input_ids, output_ids), dim=-1)
+
+        candidates_generator.init_look_up_table(input_ids)
+
         past_key_values = output['past_key_values']
         step += 1
         if self.device.type == 'xpu':
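
Note: the table is built exactly once, right after prefill appends the first generated token; later steps extend it incrementally with update_look_up_table (see the later hunk) instead of rebuilding it. A schematic of the intended lifecycle, where verify_and_accept is a hypothetical stand-in for the surrounding generation loop:

candidates_generator.init_look_up_table(input_ids)   # once, after prefill
while not finished:
    candidate_ids, _ = candidates_generator.get_candidates(input_ids)
    input_ids = verify_and_accept(input_ids, candidate_ids)  # hypothetical helper
    candidates_generator.update_look_up_table(input_ids)     # incremental refresh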
@@ -319,9 +340,13 @@ def lookup_generate(self,
             # Drafts start from [1, k]
             # Verified output start from [0, k - 1]
             # including the one generated by the base model
+
             n_matches = ((output_ids[:, :-1] != verify_input_ids[:, 1:])
                          .cumsum(-1) == 0).sum(-1).item()
+
             max_matched = n_matches + 1
+            mot = time.time()
+            self.match_time.append(mot-toc)
 
             max_of_max_matched = output_ids.size(1)
             # Accept number is max_matched, min is 1
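
Note: the cumsum expression counts the length of the matching prefix between drafts and verified output: the first mismatch makes the running sum nonzero from that point on, so the zero positions are exactly the accepted draft tokens. A small worked check with invented ids:

import torch

output_ids = torch.tensor([[5, 6, 7, 8]])        # base-model verified tokens
verify_input_ids = torch.tensor([[4, 5, 6, 9]])  # drafts sit at positions 1..k

mismatch = output_ids[:, :-1] != verify_input_ids[:, 1:]  # [False, False, True]
n_matches = (mismatch.cumsum(-1) == 0).sum(-1).item()     # 2 drafts accepted
max_matched = n_matches + 1  # plus the token the base model itself produced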
@@ -343,9 +368,12 @@ def lookup_generate(self,
                                      accept_rate)
 
             input_ids = torch.cat((input_ids, output_ids), dim=-1)
+            candidates_generator.update_look_up_table(input_ids)
 
             step += output_ids.size(1)
             step_verify += 1
+            pot = time.time()
+            self.post_time.append(pot-mot)
 
         # Stop on eos and remove content after eos
         output_ids_list = output_ids[0].tolist()
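
Note: each update registers at most max_matching_ngram_size suffix n-grams of the extended sequence, so table maintenance stays constant-time per verify step regardless of context length. A toy run of the update rule (pre-existing entries invented):

import torch

max_matching_ngram_size = 3
lookup_table = {(1, 2, 3): 0}  # hypothetical existing entry

new_input_ids = torch.tensor([[1, 2, 3, 4, 5]])
window = tuple(new_input_ids[0, -max_matching_ngram_size:].tolist())  # (3, 4, 5)
for ngram_size in range(max_matching_ngram_size):
    key = window[ngram_size:]
    if key not in lookup_table:
        # start position of this suffix n-gram in the full sequence
        lookup_table[key] = new_input_ids.size(1) - max_matching_ngram_size + ngram_size

# lookup_table now also maps (3, 4, 5) -> 2, (4, 5) -> 3, (5,) -> 4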
@@ -162,6 +162,8 @@ def clear_benchmarks(self):
         self.generate_time = []
         self.draft_time = []
         self.verify_time = []
+        self.match_time = []
+        self.post_time = []
         self.draft_num = []
         self.accept_num = []
         self.n_drafted = 0
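
Note: the two new lists mirror the existing per-stage benchmark fields, so presumably they are read the same way after a run, e.g. averaged. A hypothetical summary, where model is the instrumented object:

avg_match = sum(model.match_time) / max(len(model.match_time), 1)
avg_post = sum(model.post_time) / max(len(model.post_time), 1)
print(f"avg match: {avg_match:.4f}s, avg post: {avg_post:.4f}s")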