LLM: Convert draft_model kv_cache from bf16 to fp32 (#9964)

* convert bf16 to fp32

* update

* change when init

* init first and cut off after

* init and exchange

* update python type

* update

* fix bug

* update

* update
Wang, Jian4 2024-01-25 11:20:27 +08:00 committed by GitHub
parent 51aa8b62b2
commit 9bff84e6fd

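Overview, before the diff: the change gives the CPU draft model its own fp32 copy of the target model's bf16 KV cache. The fp32 buffer (`past_key_values1`) is allocated once with headroom for `max_new_tokens`, filled from the bf16 cache, then served to the draft model as per-step slices and updated incrementally. Below is a minimal standalone sketch of that buffer-promotion pattern, assuming the common [batch, heads, seq, head_dim] layout; `init_fp32_cache`, `slice_cache`, and the toy shapes are illustrative names for this sketch, not code from the commit.

import torch

def init_fp32_cache(past_key_values, max_new_tokens, seq_dim=2):
    # Preallocate fp32 K/V buffers with room for max_new_tokens more
    # positions along seq_dim, then copy the bf16 cache into the front.
    fp32_cache = []
    for k, v in past_key_values:
        shape = list(k.shape)
        cur_len = shape[seq_dim]
        shape[seq_dim] += max_new_tokens
        k32 = torch.empty(shape, dtype=torch.float32)
        v32 = torch.empty(shape, dtype=torch.float32)
        k32.narrow(seq_dim, 0, cur_len).copy_(k.to(torch.float32))
        v32.narrow(seq_dim, 0, cur_len).copy_(v.to(torch.float32))
        fp32_cache.append((k32, v32))
    return fp32_cache

def slice_cache(fp32_cache, cur_len, seq_dim=2):
    # Per-step view of the first cur_len cached positions (no copy).
    return [(k.narrow(seq_dim, 0, cur_len), v.narrow(seq_dim, 0, cur_len))
            for k, v in fp32_cache]

# Toy usage: one layer, batch 1, 8 heads, 5 cached tokens, head_dim 64
pkv = [(torch.randn(1, 8, 5, 64, dtype=torch.bfloat16),
        torch.randn(1, 8, 5, 64, dtype=torch.bfloat16))]
cache = init_fp32_cache(pkv, max_new_tokens=16)
draft_pkv = slice_cache(cache, cur_len=5)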

@@ -225,6 +225,7 @@ def speculative_generate(self,
     draft_generate_ids = torch.empty([input_ids.size(0), draft_gen_length],
                                      dtype=torch.long, device=self.device)
     past_key_values = None
+    past_key_values1 = []
     tmp_matchness = 0
     e2e_tic = 0.0
@@ -271,7 +272,71 @@ def speculative_generate(self,
             else:
                 draft_current_input_ids = current_input_ids
             # Target model KV cache to draft model
-            draft_past_key_values = past_key_values
+            # Step 1: init the draft model's fp32 KV cache (past_key_values1) from the bf16 target cache
+            if self.device.type == 'cpu' and step == 1:
+                for i in range(len(past_key_values)):
+                    len0 = past_key_values[i][0].size(0)
+                    len1 = past_key_values[i][0].size(1)
+                    len2 = past_key_values[i][0].size(2)
+                    len3 = past_key_values[i][0].size(3)
+                    if self.config.model_type == "qwen":
+                        k0 = torch.ones(len0, len2, len1 + max_new_tokens, len3,
+                                        dtype=torch.float32)
+                        v0 = torch.ones(len0, len2, len1 + max_new_tokens, len3,
+                                        dtype=torch.float32)
+                        k0 = k0.transpose(1, 2)
+                        v0 = v0.transpose(1, 2)
+                        past_key_values1.append((k0, v0))
+                        past_key_values1[i][0][:, :len1, :, :] = past_key_values[i][0].to(
+                            torch.float32)
+                        past_key_values1[i][1][:, :len1, :, :] = past_key_values[i][1].to(
+                            torch.float32)
+                    elif self.config.model_type == "chatglm":
+                        k0 = torch.ones(len1, len2, len0 + max_new_tokens, len3,
+                                        dtype=torch.float32)
+                        v0 = torch.ones(len1, len2, len0 + max_new_tokens, len3,
+                                        dtype=torch.float32)
+                        k0 = k0.permute(2, 0, 1, 3)
+                        v0 = v0.permute(2, 0, 1, 3)
+                        past_key_values1.append((k0, v0))
+                        past_key_values1[i][0][:len0, :, :, :] = past_key_values[i][0].to(
+                            torch.float32)
+                        past_key_values1[i][1][:len0, :, :, :] = past_key_values[i][1].to(
+                            torch.float32)
+                    else:
+                        k0 = torch.ones(len0, len1, len2 + max_new_tokens, len3,
+                                        dtype=torch.float32)
+                        v0 = torch.ones(len0, len1, len2 + max_new_tokens, len3,
+                                        dtype=torch.float32)
+                        past_key_values1.append((k0, v0))
+                        past_key_values1[i][0][:, :, :len2, :] = past_key_values[i][0].to(
+                            torch.float32)
+                        past_key_values1[i][1][:, :, :len2, :] = past_key_values[i][1].to(
+                            torch.float32)
+            # Each iteration, slice the first cur_len cached positions out of past_key_values1 for the draft model
+            if self.device.type == 'cpu':
+                tmp_past_key_values = []
+                for i in range(len(past_key_values)):
+                    if self.config.model_type == "qwen":
+                        len1 = past_key_values[0][0].size(1)
+                        k0 = past_key_values1[i][0][:, :len1, :, :]
+                        v0 = past_key_values1[i][1][:, :len1, :, :]
+                        tmp_past_key_values.append((k0, v0))
+                    elif self.config.model_type == "chatglm":
+                        len0 = past_key_values[0][0].size(0)
+                        k0 = past_key_values1[i][0][:len0, :, :, :]
+                        v0 = past_key_values1[i][1][:len0, :, :, :]
+                        tmp_past_key_values.append((k0, v0))
+                    else:
+                        len2 = past_key_values[0][0].size(2)
+                        k0 = past_key_values1[i][0][:, :, :len2, :]
+                        v0 = past_key_values1[i][1][:, :, :len2, :]
+                        tmp_past_key_values.append((k0, v0))
+                draft_past_key_values = tmp_past_key_values
+            else:
+                draft_past_key_values = past_key_values
             draft_generate_ids[:, 0] = current_input_ids
             tic = time.time()
             # Draft model auto-regressively generate k tokens
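The three model_type branches in the hunk above differ only in where the sequence axis sits in each model's KV layout: dim 1 for qwen, dim 0 for chatglm, and dim 2 otherwise, as the slicing implies. A hedged sketch of how a single code path could cover all three; the SEQ_DIM table and draft_cache_view are inferred from the diff, not taken from the commit:

import torch

# Sequence-axis position per KV layout, inferred from the slicing above
SEQ_DIM = {"qwen": 1, "chatglm": 0}  # any other model_type: dim 2

def draft_cache_view(fp32_cache, past_key_values, model_type):
    # Views of the fp32 buffer truncated to the target model's current
    # sequence length (mirrors the tmp_past_key_values loop above).
    dim = SEQ_DIM.get(model_type, 2)
    cur_len = past_key_values[0][0].size(dim)
    return [(k.narrow(dim, 0, cur_len), v.narrow(dim, 0, cur_len))
            for k, v in fp32_cache]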
@@ -392,6 +457,31 @@ def speculative_generate(self,
                      v[:, :, :-(max_of_max_matched - max_matched)]) for k, v in past_key_values
                 ]
+            # Each iteration, write the newly matched KV entries back into past_key_values1 as fp32
+            if self.device.type == 'cpu':
+                for i in range(len(past_key_values)):
+                    if self.config.model_type == "qwen":
+                        size = tmp_past_key_values[i][0].size(1)
+                        size1 = past_key_values[i][0].size(1)
+                        past_key_values1[i][0][:, size:size1, :, :] = \
+                            past_key_values[i][0][:, size:size1, :, :].to(torch.float32)
+                        past_key_values1[i][1][:, size:size1, :, :] = \
+                            past_key_values[i][1][:, size:size1, :, :].to(torch.float32)
+                    elif self.config.model_type == "chatglm":
+                        size = tmp_past_key_values[i][0].size(0)
+                        size1 = past_key_values[i][0].size(0)
+                        past_key_values1[i][0][size:size1, :, :, :] = \
+                            past_key_values[i][0][size:size1, :, :, :].to(torch.float32)
+                        past_key_values1[i][1][size:size1, :, :, :] = \
+                            past_key_values[i][1][size:size1, :, :, :].to(torch.float32)
+                    else:
+                        size = tmp_past_key_values[i][0].size(2)
+                        size1 = past_key_values[i][0].size(2)
+                        past_key_values1[i][0][:, :, size:size1, :] = \
+                            past_key_values[i][0][:, :, size:size1, :].to(torch.float32)
+                        past_key_values1[i][1][:, :, size:size1, :] = \
+                            past_key_values[i][1][:, :, size:size1, :].to(torch.float32)
             generate_ids[:, step:step+output_ids.size(1)] = output_ids
             current_input_ids = output_ids[:, -1:]
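This last hunk converts only the delta: the positions the verification step just accepted, i.e., those between the length the draft cache was served at (tmp_past_key_values) and the target cache's new length, so per-iteration conversion cost stays proportional to the accepted tokens rather than the whole sequence. A minimal sketch of that incremental write-back under the same layout assumption as above; write_back_new_tokens is an illustrative name, not from the commit:

import torch

def write_back_new_tokens(fp32_cache, past_key_values, old_len, seq_dim=2):
    # Copy only the newly accepted positions (old_len..new_len),
    # converting them from bf16 to fp32 inside the preallocated buffer.
    new_len = past_key_values[0][0].size(seq_dim)
    n = new_len - old_len
    if n <= 0:
        return  # nothing new was accepted this round
    for (k32, v32), (k, v) in zip(fp32_cache, past_key_values):
        k32.narrow(seq_dim, old_len, n).copy_(
            k.narrow(seq_dim, old_len, n).to(torch.float32))
        v32.narrow(seq_dim, old_len, n).copy_(
            v.narrow(seq_dim, old_len, n).to(torch.float32))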