[WIP] Add look up table in 1st token stage (#11193)

* lookuptb
Zhao Changmin 2024-06-07 10:51:05 +08:00 committed by GitHub
parent 375174af33
commit b7948671de
2 changed files with 48 additions and 18 deletions


@@ -99,6 +99,10 @@ def generate(
 GenerationMixin.generate = generate
 
 
+def tensor2key(key_tensor: torch.LongTensor):
+    return tuple(key_tensor.tolist())
+
+
 # This class is copied from https://github.com/huggingface/transformers/blob/main/src
 # /transformers/generation/candidate_generator.py
 class PromptLookupCandidateGenerator():
@@ -133,9 +137,34 @@ class PromptLookupCandidateGenerator():
         self.max_candidates = 9
         self.min_candidates = 0
+        self.lookup_table = {}
         invalidInputError(self.max_matching_ngram_size > 0 and self.num_output_tokens > 0,
                           "Invalid max_matching_ngram_size or num_output_tokens")
 
+    def init_look_up_table(self,
+                           input_ids: torch.LongTensor):
+        for ngram_size in range(self.max_matching_ngram_size, 0, -1):
+            # Create sliding windows of size ngram_size
+            windows = input_ids.unfold(dimension=1, size=ngram_size, step=1)
+            for idx in range(windows.size(1)):
+                window = tensor2key(windows[0, idx])
+                if window not in self.lookup_table:
+                    self.lookup_table[window] = idx
+
+    def update_look_up_table(self,
+                             new_input_ids: torch.LongTensor):
+        # Maintain the lookup table: insert suffix n-grams of the newest window
+        window = tensor2key(new_input_ids[0, -self.max_matching_ngram_size:])
+        for ngram_size in range(self.max_matching_ngram_size):
+            if window[ngram_size:] not in self.lookup_table:
+                self.lookup_table[window[ngram_size:]] = \
+                    new_input_ids.size(1) - self.max_matching_ngram_size + ngram_size
+
+    def get_n_gram_idx(self,
+                       ngram_tensor: torch.LongTensor):
+        key = tensor2key(ngram_tensor)
+        return self.lookup_table[key]
+
     def get_candidates(self,
                        input_ids: torch.LongTensor) -> Tuple[torch.LongTensor,
                                                              Optional[torch.FloatTensor]]:
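
Note: taken together, the three new methods hash every n-gram of the prompt into a plain Python dict keyed by token tuples, so drafting no longer rescans the prompt. A minimal standalone sketch of the same idea; build_ngram_table is an illustrative name here, the commit's own method is init_look_up_table:

    import torch

    def build_ngram_table(input_ids: torch.LongTensor, max_ngram: int = 3) -> dict:
        # Hash every n-gram (n = max_ngram .. 1) to the position of its first
        # occurrence; key tuples of different lengths never collide.
        table = {}
        for n in range(max_ngram, 0, -1):
            windows = input_ids.unfold(dimension=1, size=n, step=1)  # [1, L - n + 1, n]
            for idx in range(windows.size(1)):
                key = tuple(windows[0, idx].tolist())
                if key not in table:
                    table[key] = idx
        return table

    ids = torch.tensor([[5, 7, 9, 5, 7, 2]])
    table = build_ngram_table(ids)
    print(table[(5, 7)])  # 0: (5, 7) first occurs at position 0, so a drafted
                          # continuation would start at index 0 + 2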
@@ -156,31 +185,20 @@ class PromptLookupCandidateGenerator():
         input_length = input_ids.size(1)
 
         chosen_ids = None
-        match_found = False
         for ngram_size in range(min(self.max_matching_ngram_size, input_length - 1), 0, -1):
-            # Create sliding windows of size ngram_size
-            windows = input_ids.unfold(dimension=1, size=ngram_size, step=1)
-
             # Convert ngram to a tensor for comparison
             ngram_tensor = input_ids[0, -ngram_size:]
 
-            # Find where the windows match the ngram
-            matches = (windows == ngram_tensor).all(dim=2)
-
             # Get the indices of matches
-            match_indices = matches.nonzero(as_tuple=True)[1]
+            idx = self.get_n_gram_idx(ngram_tensor)
 
-            # Iterate through match indices to find a valid continuation
-            for idx in match_indices:
-                start_idx = idx + ngram_size
-                end_idx = start_idx + self.num_output_tokens
-                end_idx = min(end_idx, input_length)
-                if start_idx < end_idx:
-                    chosen_ids = input_ids[0, start_idx:end_idx]
-                    match_found = True
-                    break
-            if match_found:
+            start_idx = idx + ngram_size
+            end_idx = start_idx + self.num_output_tokens
+            end_idx = min(end_idx, input_length)
+            if start_idx < end_idx:
+                chosen_ids = input_ids[0, start_idx:end_idx]
                 break
 
         if chosen_ids is None or len(chosen_ids) == 0:
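
Note for reviewers: get_n_gram_idx indexes self.lookup_table directly, so a missing key would raise KeyError instead of falling through to a shorter n-gram the way the old nonzero() scan did. The code appears to rely on the trailing n-gram always having just been inserted by init_look_up_table or update_look_up_table. A hypothetical defensive variant, not part of this commit:

    from typing import Optional

    import torch

    def get_n_gram_idx_safe(lookup_table: dict,
                            ngram_tensor: torch.LongTensor) -> Optional[int]:
        # Return None on a miss so the caller can fall through to a shorter
        # n-gram, matching what the old full scan did when no window matched.
        return lookup_table.get(tuple(ngram_tensor.tolist()))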
@@ -267,6 +285,9 @@ def lookup_generate(self,
             else:
                 output_ids = greedy(logits)
             input_ids = torch.cat((input_ids, output_ids), dim=-1)
+
+            candidates_generator.init_look_up_table(input_ids)
+
             past_key_values = output['past_key_values']
             step += 1
             if self.device.type == 'xpu':
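
Seeding the table here, immediately after the first token, is what the title's "1st token stage" refers to: the window scan that get_candidates previously re-ran on every drafting step is now paid once at prefill. A back-of-envelope sketch of the work saved, with illustrative sizes (prompt length L, max n-gram size N, not measured values):

    L, N, steps = 4096, 3, 256       # illustrative sizes, not measurements
    old_ops = steps * L * N          # old path: re-scan every window each step
    new_ops = L * N + steps * N      # new path: one-time build + per-step lookups
    print(old_ops, new_ops)          # 3145728 vs 13056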
@@ -319,9 +340,13 @@ def lookup_generate(self,
                 # Drafts start from [1, k]
                 # Verified output start from [0, k - 1]
                 # including the one generated by the base model
                 n_matches = ((output_ids[:, :-1] != verify_input_ids[:, 1:])
                              .cumsum(-1) == 0).sum(-1).item()
                 max_matched = n_matches + 1
+
+                mot = time.time()
+                self.match_time.append(mot - toc)
+
                 max_of_max_matched = output_ids.size(1)
                 # Accept number is max_matched, min is 1
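
The new mot timestamp closes an interval opened by toc, which is presumably captured earlier in the loop (outside this hunk), so match_time records the per-step cost of the verification and match bookkeeping. The bracketing pattern, schematically:

    import time

    match_time = []
    toc = time.time()          # assumed: captured just before verification
    # ... verification and match counting happen here ...
    mot = time.time()
    match_time.append(mot - toc)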
@@ -343,9 +368,12 @@ def lookup_generate(self,
                                               accept_rate)
                 input_ids = torch.cat((input_ids, output_ids), dim=-1)
+                candidates_generator.update_look_up_table(input_ids)
                 step += output_ids.size(1)
                 step_verify += 1
+                pot = time.time()
+                self.post_time.append(pot - mot)
 
                 # Stop on eos and remove content after eos
                 output_ids_list = output_ids[0].tolist()
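
One apparent trade-off in update_look_up_table: only suffix n-grams of the newest max_matching_ngram_size tokens are inserted, so when a verify step appends several tokens at once, n-grams spanning further back across the old/new boundary are never keyed, and the stored position can be later than the true first occurrence. The trailing n-grams that get_candidates actually queries are always present, though, so lookups do not miss. A standalone sketch of the maintenance step (the helper name is illustrative):

    import torch

    def update_ngram_table(table: dict, input_ids: torch.LongTensor,
                           max_ngram: int = 3) -> None:
        # Insert every suffix n-gram of the newest max_ngram tokens, keeping
        # the first-seen position, as update_look_up_table does.
        window = tuple(input_ids[0, -max_ngram:].tolist())
        for k in range(max_ngram):
            if window[k:] not in table:
                table[window[k:]] = input_ids.size(1) - max_ngram + k

    table = {}
    ids = torch.tensor([[5, 7, 9]])
    update_ngram_table(table, ids)
    ids = torch.cat((ids, torch.tensor([[5, 7]])), dim=-1)   # verified tokens appended
    update_ngram_table(table, ids)
    print(table[(5, 7)])  # 3: suffix 2-gram of the newest window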


@@ -162,6 +162,8 @@ def clear_benchmarks(self):
         self.generate_time = []
         self.draft_time = []
         self.verify_time = []
+        self.match_time = []
+        self.post_time = []
         self.draft_num = []
         self.accept_num = []
         self.n_drafted = 0
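
The two new lists slot into the existing benchmark reset alongside draft_time and verify_time. A hypothetical reporting helper, not part of this commit, assuming each list holds per-step durations in seconds:

    def summarize_benchmarks(bench) -> None:
        # Average each timer list and report milliseconds per step.
        for name in ("draft_time", "verify_time", "match_time", "post_time"):
            values = getattr(bench, name, [])
            if values:
                avg_ms = sum(values) / len(values) * 1000
                print(f"{name}: {avg_ms:.3f} ms/step over {len(values)} steps")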