[WIP] Add look up table in 1st token stage (#11193)

* lookuptb
This commit is contained in:
Zhao Changmin 2024-06-07 10:51:05 +08:00 committed by GitHub
parent 375174af33
commit b7948671de
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 48 additions and 18 deletions

View file

@ -99,6 +99,10 @@ def generate(
GenerationMixin.generate = generate
def tensor2key(key_tensor: torch.LongTensor):
return tuple(key_tensor.tolist())
# This class is copied from https://github.com/huggingface/transformers/blob/main/src
# /transformers/generation/candidate_generator.py
class PromptLookupCandidateGenerator():
@ -133,9 +137,34 @@ class PromptLookupCandidateGenerator():
self.max_candidates = 9
self.min_candidates = 0
self.lookup_table = {}
invalidInputError(self.max_matching_ngram_size > 0 and self.num_output_tokens > 0,
"Invalid max_matching_ngram_size or num_output_tokens")
def init_look_up_table(self,
input_ids: torch.LongTensor):
for ngram_size in range(self.max_matching_ngram_size, 0, -1):
# Create sliding windows of size ngram_size
windows = input_ids.unfold(dimension=1, size=ngram_size, step=1)
for idx in range(windows.size(1)):
window = tensor2key(windows[0, idx])
if window not in self.lookup_table:
self.lookup_table[window] = idx
def update_look_up_table(self,
new_input_ids: torch.LongTensor):
# Maintain a look up table
window = tensor2key(new_input_ids[0, -self.max_matching_ngram_size:])
for ngram_size in range(self.max_matching_ngram_size):
if window[ngram_size:] not in self.lookup_table:
self.lookup_table[window[ngram_size:]] = \
new_input_ids.size(1)-self.max_matching_ngram_size+ngram_size
def get_n_gram_idx(self,
ngram_tensor: torch.LongTensor):
key = tensor2key(ngram_tensor)
return self.lookup_table[key]
def get_candidates(self,
input_ids: torch.LongTensor)-> Tuple[torch.LongTensor,
Optional[torch.FloatTensor]]:
@ -156,31 +185,20 @@ class PromptLookupCandidateGenerator():
input_length = input_ids.size(1)
chosen_ids = None
match_found = False
for ngram_size in range(min(self.max_matching_ngram_size, input_length - 1), 0, -1):
# Create sliding windows of size ngram_size
windows = input_ids.unfold(dimension=1, size=ngram_size, step=1)
# Convert ngram to a tensor for comparison
ngram_tensor = input_ids[0, -ngram_size:]
# Find where the windows match the ngram
matches = (windows == ngram_tensor).all(dim=2)
# Get the indices of matches
match_indices = matches.nonzero(as_tuple=True)[1]
# # Get the indices of matches
idx = self.get_n_gram_idx(ngram_tensor)
# Iterate through match indices to find a valid continuation
for idx in match_indices:
start_idx = idx + ngram_size
end_idx = start_idx + self.num_output_tokens
end_idx = min(end_idx, input_length)
if start_idx < end_idx:
chosen_ids = input_ids[0, start_idx:end_idx]
match_found = True
break
if match_found:
break
if chosen_ids is None or len(chosen_ids) == 0:
@ -267,6 +285,9 @@ def lookup_generate(self,
else:
output_ids = greedy(logits)
input_ids = torch.cat((input_ids, output_ids), dim=-1)
candidates_generator.init_look_up_table(input_ids)
past_key_values = output['past_key_values']
step += 1
if self.device.type == 'xpu':
@ -319,9 +340,13 @@ def lookup_generate(self,
# Drafts start from [1, k]
# Verified output start from [0, k - 1]
# including the one generated by the base model
n_matches = ((output_ids[:, :-1] != verify_input_ids[:, 1:])
.cumsum(-1) == 0).sum(-1).item()
max_matched = n_matches + 1
mot = time.time()
self.match_time.append(mot-toc)
max_of_max_matched = output_ids.size(1)
# Accept number is max_matched, min is 1
@ -343,9 +368,12 @@ def lookup_generate(self,
accept_rate)
input_ids = torch.cat((input_ids, output_ids), dim=-1)
candidates_generator.update_look_up_table(input_ids)
step += output_ids.size(1)
step_verify += 1
pot = time.time()
self.post_time.append(pot-mot)
# Stop on eos and remove content after eos
output_ids_list = output_ids[0].tolist()

View file

@ -162,6 +162,8 @@ def clear_benchmarks(self):
self.generate_time = []
self.draft_time = []
self.verify_time = []
self.match_time = []
self.post_time = []
self.draft_num = []
self.accept_num = []
self.n_drafted = 0