LLM: Convert draft_model kv_cache from bf16 to fp32 (#9964)
* convert bf16 to fp32
* update
* change when init
* init first and cut off after
* init and exchange
* update python type
* update
* fix bug
* update
* update
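On CPU, the draft model keeps its KV cache in fp32 while the target model's cache stays bf16, so this patch copies the target cache once into a preallocated fp32 buffer (`past_key_values1`) with `max_new_tokens` of headroom along the sequence dimension, then serves the draft model sliced views of it. Below is a minimal sketch of that idea, assuming the default `(batch, heads, seq, head_dim)` KV layout; `init_fp32_kv` is an illustrative name, not a helper from the patch:

```python
import torch

# Illustrative sketch (not the patch itself): build the fp32 draft-model
# cache from the target model's bf16 cache, with max_new_tokens headroom
# along the sequence dimension so later tokens can be written in place.
def init_fp32_kv(past_key_values, max_new_tokens):
    fp32_cache = []
    for k, v in past_key_values:          # k, v: (batch, heads, seq, head_dim)
        b, h, s, d = k.shape
        k0 = torch.empty(b, h, s + max_new_tokens, d, dtype=torch.float32)
        v0 = torch.empty(b, h, s + max_new_tokens, d, dtype=torch.float32)
        k0[:, :, :s, :] = k.to(torch.float32)   # one-time bf16 -> fp32 copy
        v0[:, :, :s, :] = v.to(torch.float32)
        fp32_cache.append((k0, v0))
    return fp32_cache
```

Qwen and ChatGLM carry the sequence length on a different dimension, which is why the diff below branches on `self.config.model_type`.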
parent 51aa8b62b2
commit 9bff84e6fd

1 changed file with 91 additions and 1 deletion
```diff
@@ -225,6 +225,7 @@ def speculative_generate(self,
     draft_generate_ids = torch.empty([input_ids.size(0), draft_gen_length],
                                      dtype=torch.long, device=self.device)
     past_key_values = None
+    past_key_values1 = []
 
     tmp_matchness = 0
     e2e_tic = 0.0
```
```diff
@@ -271,6 +272,70 @@ def speculative_generate(self,
         else:
             draft_current_input_ids = current_input_ids
             # Target model KV cache to draft model
-            draft_past_key_values = past_key_values
+
+            # init draft_self_past_key_values:past_key_values1 and assign initial fp32 value
+            if self.device.type == 'cpu' and step == 1:
+                for i in range(len(past_key_values)):
+                    len0 = past_key_values[i][0].size(0)
+                    len1 = past_key_values[i][0].size(1)
+                    len2 = past_key_values[i][0].size(2)
+                    len3 = past_key_values[i][0].size(3)
+                    if self.config.model_type == "qwen":
+                        k0 = torch.ones(len0, len2, len1 + max_new_tokens, len3,
+                                        dtype=torch.float32)
+                        v0 = torch.ones(len0, len2, len1 + max_new_tokens, len3,
+                                        dtype=torch.float32)
+                        k0 = k0.transpose(1, 2)
+                        v0 = v0.transpose(1, 2)
+                        past_key_values1.append((k0, v0))
+                        past_key_values1[i][0][:, :len1, :, :] = past_key_values[i][0].to(
+                            torch.float32)
+                        past_key_values1[i][1][:, :len1, :, :] = past_key_values[i][1].to(
+                            torch.float32)
+                    elif self.config.model_type == "chatglm":
+                        k0 = torch.ones(len1, len2, len0 + max_new_tokens, len3,
+                                        dtype=torch.float32)
+                        v0 = torch.ones(len1, len2, len0 + max_new_tokens, len3,
+                                        dtype=torch.float32)
+                        k0 = k0.permute(2, 0, 1, 3)
+                        v0 = v0.permute(2, 0, 1, 3)
+                        past_key_values1.append((k0, v0))
+                        past_key_values1[i][0][:len0, :, :, :] = past_key_values[i][0].to(
+                            torch.float32)
+                        past_key_values1[i][1][:len0, :, :, :] = past_key_values[i][1].to(
+                            torch.float32)
+                    else:
+                        k0 = torch.ones(len0, len1, len2 + max_new_tokens, len3,
+                                        dtype=torch.float32)
+                        v0 = torch.ones(len0, len1, len2 + max_new_tokens, len3,
+                                        dtype=torch.float32)
+                        past_key_values1.append((k0, v0))
+                        past_key_values1[i][0][:, :, :len2, :] = past_key_values[i][0].to(
+                            torch.float32)
+                        past_key_values1[i][1][:, :, :len2, :] = past_key_values[i][1].to(
+                            torch.float32)
+
+            # each iter cut off cur_len kv_cache from past_key_values1
+            if self.device.type == 'cpu':
+                tmp_past_key_values = []
+                for i in range(len(past_key_values)):
+                    if self.config.model_type == "qwen":
+                        len1 = past_key_values[0][0].size(1)
+                        k0 = past_key_values1[i][0][:, :len1, :, :]
+                        v0 = past_key_values1[i][1][:, :len1, :, :]
+                        tmp_past_key_values.append((k0, v0))
+                    elif self.config.model_type == "chatglm":
+                        len0 = past_key_values[0][0].size(0)
+                        k0 = past_key_values1[i][0][:len0, :, :, :]
+                        v0 = past_key_values1[i][1][:len0, :, :, :]
+                        tmp_past_key_values.append((k0, v0))
+                    else:
+                        len2 = past_key_values[0][0].size(2)
+                        k0 = past_key_values1[i][0][:, :, :len2, :]
+                        v0 = past_key_values1[i][1][:, :, :len2, :]
+                        tmp_past_key_values.append((k0, v0))
+                draft_past_key_values = tmp_past_key_values
+            else:
+                draft_past_key_values = past_key_values
             draft_generate_ids[:, 0] = current_input_ids
             tic = time.time()
```
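Each decoding step then cuts a draft-model cache out of the preallocated buffer by slicing, which in PyTorch returns views rather than copies, so no per-step dtype conversion of the whole cache is needed. A minimal sketch of the cut-off step, again assuming the default `(batch, heads, seq, head_dim)` layout (`cut_fp32_kv` is a hypothetical name):

```python
# Hypothetical helper mirroring the cut-off above: take fp32 views whose
# sequence length matches the target model's current bf16 cache.
def cut_fp32_kv(fp32_cache, past_key_values):
    cur_len = past_key_values[0][0].size(2)   # current sequence length
    # basic slicing returns views, so no data is copied here
    return [(k0[:, :, :cur_len, :], v0[:, :, :cur_len, :])
            for k0, v0 in fp32_cache]
```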
```diff
@@ -392,6 +457,31 @@ def speculative_generate(self,
                          v[:, :, :-(max_of_max_matched - max_matched)]) for k, v in past_key_values
                     ]
 
+            # Each iter assign new_matched kv_cache to past_key_values1
+            if self.device.type == 'cpu':
+                for i in range(len(past_key_values)):
+                    if self.config.model_type == "qwen":
+                        size = tmp_past_key_values[i][0].size(1)
+                        size1 = past_key_values[i][0].size(1)
+                        past_key_values1[i][0][:, size:size1, :, :] = \
+                            past_key_values[i][0][:, size:size1, :, :].to(torch.float32)
+                        past_key_values1[i][1][:, size:size1, :, :] = \
+                            past_key_values[i][1][:, size:size1, :, :].to(torch.float32)
+                    elif self.config.model_type == "chatglm":
+                        size = tmp_past_key_values[i][0].size(0)
+                        size1 = past_key_values[i][0].size(0)
+                        past_key_values1[i][0][size:size1, :, :, :] = \
+                            past_key_values[i][0][size:size1, :, :, :].to(torch.float32)
+                        past_key_values1[i][1][size:size1, :, :, :] = \
+                            past_key_values[i][1][size:size1, :, :, :].to(torch.float32)
+                    else:
+                        size = tmp_past_key_values[i][0].size(2)
+                        size1 = past_key_values[i][0].size(2)
+                        past_key_values1[i][0][:, :, size:size1, :] = \
+                            past_key_values[i][0][:, :, size:size1, :].to(torch.float32)
+                        past_key_values1[i][1][:, :, size:size1, :] = \
+                            past_key_values[i][1][:, :, size:size1, :].to(torch.float32)
+
             generate_ids[:, step:step+output_ids.size(1)] = output_ids
             current_input_ids = output_ids[:, -1:]
 
```
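After verification, only the newly matched positions are converted from bf16 and written back into `past_key_values1`, keeping per-step conversion cost proportional to the number of accepted tokens rather than the full sequence length. A sketch of that write-back under the same assumed layout (`update_fp32_kv` is hypothetical):

```python
import torch

# Hypothetical helper mirroring the write-back above: convert only the newly
# matched positions [old_len, new_len) to fp32 instead of the whole cache.
def update_fp32_kv(fp32_cache, past_key_values, old_len):
    new_len = past_key_values[0][0].size(2)
    for (k0, v0), (k, v) in zip(fp32_cache, past_key_values):
        k0[:, :, old_len:new_len, :] = k[:, :, old_len:new_len, :].to(torch.float32)
        v0[:, :, old_len:new_len, :] = v[:, :, old_len:new_len, :].to(torch.float32)
```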