LLM: Convert draft_model kv_cache from bf16 to fp32 (#9964)
* convert bf16 to fp32 * update * change when init * init first and cut off after * init and exchange * update python type * update * fix bug * update * update
This commit is contained in:
parent
51aa8b62b2
commit
9bff84e6fd
1 changed files with 91 additions and 1 deletions
|
|
@ -225,6 +225,7 @@ def speculative_generate(self,
|
|||
draft_generate_ids = torch.empty([input_ids.size(0), draft_gen_length],
|
||||
dtype=torch.long, device=self.device)
|
||||
past_key_values = None
|
||||
past_key_values1 = []
|
||||
|
||||
tmp_matchness = 0
|
||||
e2e_tic = 0.0
|
||||
|
|
@ -271,7 +272,71 @@ def speculative_generate(self,
|
|||
else:
|
||||
draft_current_input_ids = current_input_ids
|
||||
# Target model KV cache to draft model
|
||||
draft_past_key_values = past_key_values
|
||||
|
||||
# init draft_self_past_key_values:past_key_values1 and assign initial fp32 value
|
||||
if self.device.type == 'cpu' and step == 1:
|
||||
for i in range(len(past_key_values)):
|
||||
len0 = past_key_values[i][0].size(0)
|
||||
len1 = past_key_values[i][0].size(1)
|
||||
len2 = past_key_values[i][0].size(2)
|
||||
len3 = past_key_values[i][0].size(3)
|
||||
if self.config.model_type == "qwen":
|
||||
k0 = torch.ones(len0, len2, len1 + max_new_tokens, len3,
|
||||
dtype=torch.float32)
|
||||
v0 = torch.ones(len0, len2, len1 + max_new_tokens, len3,
|
||||
dtype=torch.float32)
|
||||
k0 = k0.transpose(1, 2)
|
||||
v0 = v0.transpose(1, 2)
|
||||
past_key_values1.append((k0, v0))
|
||||
past_key_values1[i][0][:, :len1, :, :] = past_key_values[i][0].to(
|
||||
torch.float32)
|
||||
past_key_values1[i][1][:, :len1, :, :] = past_key_values[i][1].to(
|
||||
torch.float32)
|
||||
elif self.config.model_type == "chatglm":
|
||||
k0 = torch.ones(len1, len2, len0 + max_new_tokens, len3,
|
||||
dtype=torch.float32)
|
||||
v0 = torch.ones(len1, len2, len0 + max_new_tokens, len3,
|
||||
dtype=torch.float32)
|
||||
k0 = k0.permute(2, 0, 1, 3)
|
||||
v0 = v0.permute(2, 0, 1, 3)
|
||||
past_key_values1.append((k0, v0))
|
||||
past_key_values1[i][0][:len0, :, :, :] = past_key_values[i][0].to(
|
||||
torch.float32)
|
||||
past_key_values1[i][1][:len0, :, :, :] = past_key_values[i][1].to(
|
||||
torch.float32)
|
||||
else:
|
||||
k0 = torch.ones(len0, len1, len2 + max_new_tokens, len3,
|
||||
dtype=torch.float32)
|
||||
v0 = torch.ones(len0, len1, len2 + max_new_tokens, len3,
|
||||
dtype=torch.float32)
|
||||
past_key_values1.append((k0, v0))
|
||||
past_key_values1[i][0][:, :, :len2, :] = past_key_values[i][0].to(
|
||||
torch.float32)
|
||||
past_key_values1[i][1][:, :, :len2, :] = past_key_values[i][1].to(
|
||||
torch.float32)
|
||||
|
||||
# each iter cut off cur_len kv_cache from past_key_values1
|
||||
if self.device.type == 'cpu':
|
||||
tmp_past_key_values = []
|
||||
for i in range(len(past_key_values)):
|
||||
if self.config.model_type == "qwen":
|
||||
len1 = past_key_values[0][0].size(1)
|
||||
k0 = past_key_values1[i][0][:, :len1, :, :]
|
||||
v0 = past_key_values1[i][1][:, :len1, :, :]
|
||||
tmp_past_key_values.append((k0, v0))
|
||||
elif self.config.model_type == "chatglm":
|
||||
len0 = past_key_values[0][0].size(0)
|
||||
k0 = past_key_values1[i][0][:len0, :, :, :]
|
||||
v0 = past_key_values1[i][1][:len0, :, :, :]
|
||||
tmp_past_key_values.append((k0, v0))
|
||||
else:
|
||||
len2 = past_key_values[0][0].size(2)
|
||||
k0 = past_key_values1[i][0][:, :, :len2, :]
|
||||
v0 = past_key_values1[i][1][:, :, :len2, :]
|
||||
tmp_past_key_values.append((k0, v0))
|
||||
draft_past_key_values = tmp_past_key_values
|
||||
else:
|
||||
draft_past_key_values = past_key_values
|
||||
draft_generate_ids[:, 0] = current_input_ids
|
||||
tic = time.time()
|
||||
# Draft model auto-regressively generate k tokens
|
||||
|
|
@ -392,6 +457,31 @@ def speculative_generate(self,
|
|||
v[:, :, :-(max_of_max_matched - max_matched)]) for k, v in past_key_values
|
||||
]
|
||||
|
||||
# Each iter assign new_matched kv_cache to past_key_values1
|
||||
if self.device.type == 'cpu':
|
||||
for i in range(len(past_key_values)):
|
||||
if self.config.model_type == "qwen":
|
||||
size = tmp_past_key_values[i][0].size(1)
|
||||
size1 = past_key_values[i][0].size(1)
|
||||
past_key_values1[i][0][:, size:size1, :, :] = \
|
||||
past_key_values[i][0][:, size:size1, :, :].to(torch.float32)
|
||||
past_key_values1[i][1][:, size:size1, :, :] = \
|
||||
past_key_values[i][1][:, size:size1, :, :].to(torch.float32)
|
||||
elif self.config.model_type == "chatglm":
|
||||
size = tmp_past_key_values[i][0].size(0)
|
||||
size1 = past_key_values[i][0].size(0)
|
||||
past_key_values1[i][0][size:size1, :, :, :] = \
|
||||
past_key_values[i][0][size:size1, :, :, :].to(torch.float32)
|
||||
past_key_values1[i][1][size:size1, :, :, :] = \
|
||||
past_key_values[i][1][size:size1, :, :, :].to(torch.float32)
|
||||
else:
|
||||
size = tmp_past_key_values[i][0].size(2)
|
||||
size1 = past_key_values[i][0].size(2)
|
||||
past_key_values1[i][0][:, :, size:size1, :] = \
|
||||
past_key_values[i][0][:, :, size:size1, :].to(torch.float32)
|
||||
past_key_values1[i][1][:, :, size:size1, :] = \
|
||||
past_key_values[i][1][:, :, size:size1, :].to(torch.float32)
|
||||
|
||||
generate_ids[:, step:step+output_ids.size(1)] = output_ids
|
||||
current_input_ids = output_ids[:, -1:]
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue