parent c75f3dd874
commit a0c73c26d8

4 changed files with 6 additions and 72 deletions
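Taken together, this is a cleanup of the NPU pipeline-parallel decode path, applied uniformly to the Baichuan, Llama, MiniCPM, and Qwen variants: the `key_value_states` scratch list (no longer read by anything) is dropped from `post_forward`, the per-token `time.perf_counter()` probes and the commented-out timing print are removed, a `deocderlayers = []` initializer is deleted, and rank 0 now wakes the decode workers by broadcasting a pre-allocated `forward_signal` tensor with `async_op=True` instead of building and synchronously broadcasting a fresh control tensor on every step.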
				
			
@@ -415,11 +415,6 @@ class FusedBaichuanLowBitMultiDecoderlayer(torch.nn.Module):
         return outputs
 
     def post_forward(self, past_key_value, new_keys, new_values):
-        key_value_states = []
-        for i in range(self.intra_stages):
-            for j in range(1, len(self.backend_decoders[i].torch_out)):
-                key_value_states.append(self.backend_decoders[i].torch_out[j])
-
         cache_kwargs = {
             "max_seq_len": self.max_seq_len,
             "transpose": self.transpose_value,
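The deleted loop copied every backend decoder's extra `torch_out` tensors into a `key_value_states` list that nothing afterwards reads, so it goes away in all four files; what survives is the `cache_kwargs` dict that parameterizes the KV-cache write. A sketch of the likely remainder of `post_forward`, assuming a transformers-style `Cache.update(key, value, layer_idx, cache_kwargs)` interface and an illustrative `num_layers` attribute (neither is shown in this patch):

# Hedged sketch of the surviving post_forward cache hand-off; the update()
# interface and num_layers are assumptions, not part of this diff.
def post_forward(self, past_key_value, new_keys, new_values):
    cache_kwargs = {
        "max_seq_len": self.max_seq_len,      # fixed cache capacity
        "transpose": self.transpose_value,    # whether values are stored transposed
    }
    for layer_idx in range(self.num_layers):  # assumed: one K/V pair per fused layer
        past_key_value.update(
            new_keys[layer_idx], new_values[layer_idx], layer_idx, cache_kwargs
        )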
@@ -556,7 +551,6 @@ def run_decode(
     head_dim = model.model.layers[layer_start].self_attn.head_dim
     rms_norm_eps = model.config.rms_norm_eps
     intermediate_size = model.config.intermediate_size
-    deocderlayers = []
     layer_weights = []
     input_layer_norm_weights = []
     post_attn_layernorm_weights = []
@@ -610,13 +604,12 @@ def run_decode(
     with torch.inference_mode():
         while True:
 
-            dist.broadcast(control, src=0)
+            dist.broadcast(control, src=0, async_op=False)
             if control.item() == -2:
                 break
             elif control.item() == -1:
                 past_key_values = input_queue.get()
             else:
-                t0 = time.perf_counter()
                 past_key_values_length = past_key_values.get_seq_length()
                 seq_length_with_past = 1 + past_key_values_length
                 position_ids = torch.arange(
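Worker side of the change: each decode process sits in a broadcast-driven loop, and the patch swaps the bare `dist.broadcast(control, src=0)` for one that spells out the (default) `async_op=False`, making the blocking wait explicit, and drops the `t0` timing probe. The protocol itself is unchanged: -2 shuts the worker down, -1 says a fresh KV cache is arriving on the multiprocessing queue, anything else triggers one decode step. A self-contained sketch of that loop, with the process-group setup and the step body left as assumptions:

# Hedged sketch of the broadcast-driven worker loop in run_decode above;
# dist.init_process_group() is assumed to have been called, and
# run_one_step stands in for the real mask/position/multi_decoder logic.
import torch
import torch.distributed as dist

def worker_loop(input_queue, run_one_step):
    control = torch.empty((), dtype=torch.int)  # scalar slot filled by rank 0
    past_key_values = None  # the protocol sends -1 before the first step
    with torch.inference_mode():
        while True:
            # Block until rank 0 broadcasts; async_op=False is the default,
            # the patch just makes the synchronous wait explicit.
            dist.broadcast(control, src=0, async_op=False)
            if control.item() == -2:      # shutdown sentinel
                break
            elif control.item() == -1:    # new KV cache arrives out of band
                past_key_values = input_queue.get()
            else:                         # any other value: one decode step
                run_one_step(past_key_values)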
@@ -636,7 +629,6 @@ def run_decode(
                 )
                 padded_causal_mask[:, :, :, -1] = 0.0
                 dist.recv(hidden_states, src=rank - 1)
-                t1 = time.perf_counter()
                 layer_outputs = multi_decoder(
                     hidden_states,
                     attention_mask=padded_causal_mask,
@@ -645,11 +637,8 @@ def run_decode(
                     output_attentions=False,
                     use_cache=True,
                 )
-                t2 = time.perf_counter()
                 hidden_states = layer_outputs[0]
-                t3 = time.perf_counter()
                 dist.send(hidden_states, dst=(rank + 1) % world_size)
-                t4 = time.perf_counter()
                 past_key_values = layer_outputs[1]
                 new_keys = layer_outputs[2]
                 new_values = layer_outputs[3]
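Besides the removed `time.perf_counter()` probes, this hunk shows the data path each stage runs: receive activations from `rank - 1`, run this stage's fused decoder layers, send the result to `(rank + 1) % world_size` so the last stage wraps back around to rank 0. A minimal sketch of one stage step under those assumptions (the decoder call and its keyword arguments are placeholders):

# Hedged sketch of the pipeline-ring data path in run_decode above.
import torch.distributed as dist

def stage_step(hidden_states, multi_decoder, rank, world_size, **decoder_kwargs):
    dist.recv(hidden_states, src=rank - 1)   # activations from the previous stage
    layer_outputs = multi_decoder(hidden_states, **decoder_kwargs)
    hidden_states = layer_outputs[0]
    # The modulo wraps the final stage back to rank 0, which is blocked in
    # DecodeRunner.forward on a matching dist.recv.
    dist.send(hidden_states, dst=(rank + 1) % world_size)
    return hidden_states, layer_outputs[1:]  # new states plus cache outputs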
@@ -674,6 +663,7 @@ class DecodeRunner:
         self.input_queues = []
         self.output_queues = []
         self.decoder_processes = []
+        self.forward_signal = torch.tensor(0, dtype=torch.int)
 
         for rank in range(1, world_size):
             input_q = mp.Queue()
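The one-line addition pre-allocates the zero-valued `forward_signal` scalar in the constructor so the hot path can rebroadcast the same tensor on every token instead of building a fresh `torch.tensor(0, dtype=torch.int)` per call (see the forward() hunk below). For orientation, a sketch of how this constructor plausibly wires up the worker processes; the real `run_decode` argument list is much longer and elided here:

# Hedged sketch of the DecodeRunner constructor shown above; the actual
# module may import multiprocessing differently and passes more arguments.
import torch
import torch.multiprocessing as mp

class DecodeRunner:
    def __init__(self, world_size):
        self.input_queues = []
        self.output_queues = []
        self.decoder_processes = []
        # Reused on every forward(); value 0 means "run one decode step".
        self.forward_signal = torch.tensor(0, dtype=torch.int)

        for rank in range(1, world_size):
            input_q = mp.Queue()   # rank 0 -> worker: refreshed past_key_values
            output_q = mp.Queue()  # worker -> rank 0: status / results
            p = mp.Process(target=run_decode, args=(rank, world_size, input_q, output_q))
            p.start()
            self.input_queues.append(input_q)
            self.output_queues.append(output_q)
            self.decoder_processes.append(p)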
@@ -721,21 +711,17 @@ class DecodeRunner:
         output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = False,
     ):
-        t0 = time.perf_counter()
-
         if self.cache_past_key_value != past_key_value:
             control = torch.tensor(-1, dtype=torch.int)
             dist.broadcast(control, src=0)
             for i in range(len(self.decoder_processes)):
                 self.input_queues[i].put(past_key_value)
 
-        control = torch.tensor(0, dtype=torch.int)
-        dist.broadcast(control, src=0)
+        dist.broadcast(self.forward_signal, src=0, async_op=True)
         hidden_states = hidden_states.to(torch.float16)
         dist.send(hidden_states, dst=1)
         past_key_value.expand(self.transpose_value_cache)
         dist.recv(hidden_states, src=self.world_size - 1)
-        t1 = time.perf_counter()
         return hidden_states, past_key_value
 
     def shutdown(self):
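This hunk is where the additions pay off: the old code allocated a zero control tensor and broadcast it synchronously on every generated token, so rank 0 stalled until all workers picked up the signal; the new code fires the pre-allocated `forward_signal` with `async_op=True` and immediately moves on to pushing `hidden_states` into the pipeline. A condensed sketch of the resulting driver step (the real method's KV-cache synchronization is elided):

# Hedged sketch of the rank-0 step in DecodeRunner.forward above.
import torch
import torch.distributed as dist

def driver_step(self, hidden_states):
    # async_op=True returns a Work handle without blocking; the workers'
    # matching synchronous broadcast wakes them to run one step.
    dist.broadcast(self.forward_signal, src=0, async_op=True)
    hidden_states = hidden_states.to(torch.float16)
    dist.send(hidden_states, dst=1)                    # feed the first stage
    dist.recv(hidden_states, src=self.world_size - 1)  # collect the last stage
    return hidden_states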
@@ -918,7 +904,6 @@ def gen_baichuan_fused_model_forward(prefill_runner, decode_runner):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
     ) -> Union[Tuple, BaseModelOutputWithPast]:
-        t0 = time.perf_counter()
         output_attentions = (
             output_attentions if output_attentions is not None else self.config.output_attentions
         )
@@ -1026,8 +1011,6 @@ def gen_baichuan_fused_model_forward(prefill_runner, decode_runner):
                 for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
                 if v is not None
             )
-        t1 = time.perf_counter()
-        # print("fused model forward time: ", t1 - t0)
         return BaseModelOutputWithPast(
             last_hidden_state=hidden_states,
             past_key_values=next_cache,

@@ -330,11 +330,6 @@ class FusedLlamaLowBitMultiDecoderlayer(torch.nn.Module):
         return outputs
 
     def post_forward(self, past_key_value, new_keys, new_values, cache_position):
-        key_value_states = []
-        for i in range(self.intra_stages):
-            for j in range(1, len(self.backend_decoders[i].torch_out)):
-                key_value_states.append(self.backend_decoders[i].torch_out[j])
-
         cache_kwargs = {
             "cache_position": cache_position,
             "max_seq_len": self.max_seq_len,
@@ -474,7 +469,6 @@ def run_decode(
     head_dim = model.model.layers[layer_start].self_attn.head_dim
     rms_norm_eps = model.config.rms_norm_eps
     intermediate_size = model.config.intermediate_size
-    deocderlayers = []
     layer_weights = []
     input_layer_norm_weights = []
     post_attn_layernorm_weights = []
@@ -536,7 +530,6 @@ def run_decode(
             elif control.item() == -1:
                 past_key_values = input_queue.get()
             else:
-                t0 = time.perf_counter()
                 past_seen_tokens = past_key_values.get_seq_length()
                 attention_mask = torch.ones([1, past_seen_tokens + 1], dtype=torch.int64)
                 cache_position = torch.arange(
@@ -555,7 +548,6 @@ def run_decode(
                 )
                 padded_causal_mask[:, :, :, -1] = 0.0
                 dist.recv(hidden_states, src=rank - 1)
-                t1 = time.perf_counter()
                 layer_outputs = multi_decoder(
                     hidden_states,
                     attention_mask=padded_causal_mask,
@@ -565,11 +557,8 @@ def run_decode(
                     use_cache=True,
                     cache_position=cache_position,
                 )
-                t2 = time.perf_counter()
                 hidden_states = layer_outputs[0]
-                t3 = time.perf_counter()
                 dist.send(hidden_states, dst=(rank + 1) % world_size)
-                t4 = time.perf_counter()
                 past_key_values = layer_outputs[1]
                 new_keys = layer_outputs[2]
                 new_values = layer_outputs[3]
@@ -651,14 +640,11 @@ class DecodeRunner:
             dist.broadcast(control, src=0)
             for i in range(len(self.decoder_processes)):
                 self.input_queues[i].put(past_key_value)
-        t0 = time.perf_counter()
         dist.broadcast(self.forward_signal, src=0, async_op=True)
-        t1 = time.perf_counter()
         hidden_states = hidden_states.to(torch.float16)
         dist.send(hidden_states, dst=1)
         past_key_value.expand(self.transpose_value_cache)
         dist.recv(hidden_states, src=self.world_size - 1)
-        t2 = time.perf_counter()
         return hidden_states, past_key_value
 
     def shutdown(self):
@@ -847,7 +833,6 @@ def gen_llama_fused_model_forward(prefill_runner, decode_runner):
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
     ) -> Union[Tuple, BaseModelOutputWithPast]:
-        t0 = time.perf_counter()
         output_attentions = (
             output_attentions if output_attentions is not None else self.config.output_attentions
         )
@@ -938,8 +923,6 @@ def gen_llama_fused_model_forward(prefill_runner, decode_runner):
                 for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
                 if v is not None
             )
-        t1 = time.perf_counter()
-        # print("fused model forward time: ", t1 - t0)
         return BaseModelOutputWithPast(
             last_hidden_state=hidden_states,
             past_key_values=next_cache,

@@ -354,11 +354,6 @@ class FusedLlamaLowBitMultiDecoderlayer(torch.nn.Module):
         return outputs
 
     def post_forward(self, past_key_value, new_keys, new_values):
-        key_value_states = []
-        for i in range(self.intra_stages):
-            for j in range(1, len(self.backend_decoders[i].torch_out)):
-                key_value_states.append(self.backend_decoders[i].torch_out[j])
-
         cache_kwargs = {
             "max_seq_len": self.max_seq_len,
             "transpose": self.transpose_value,
@@ -501,7 +496,6 @@ def run_decode(
     rms_norm_eps = model.config.rms_norm_eps
     intermediate_size = model.config.intermediate_size
     num_hidden_layers = model.config.num_hidden_layers
-    deocderlayers = []
     layer_weights = []
     input_layer_norm_weights = []
     post_attn_layernorm_weights = []
@@ -559,13 +553,12 @@ def run_decode(
     with torch.inference_mode():
         while True:
 
-            dist.broadcast(control, src=0)
+            dist.broadcast(control, src=0, async_op=False)
             if control.item() == -2:
                 break
             elif control.item() == -1:
                 past_key_values = input_queue.get()
             else:
-                t0 = time.perf_counter()
                 past_seen_tokens = past_key_values.get_seq_length()
                 attention_mask = torch.ones([1, past_seen_tokens + 1], dtype=torch.int64)
                 cache_position = torch.arange(
@@ -589,7 +582,6 @@ def run_decode(
                 )
                 padded_causal_mask[:, :, :, -1] = 0.0
                 dist.recv(hidden_states, src=rank - 1)
-                t1 = time.perf_counter()
                 layer_outputs = multi_decoder(
                     hidden_states,
                     attention_mask=padded_causal_mask,
@@ -598,11 +590,8 @@ def run_decode(
                     output_attentions=False,
                     use_cache=True,
                 )
-                t2 = time.perf_counter()
                 hidden_states = layer_outputs[0]
-                t3 = time.perf_counter()
                 dist.send(hidden_states, dst=(rank + 1) % world_size)
-                t4 = time.perf_counter()
                 past_key_values = layer_outputs[1]
                 new_keys = layer_outputs[2]
                 new_values = layer_outputs[3]
@@ -628,6 +617,7 @@ class DecodeRunner:
         self.input_queues = []
         self.output_queues = []
         self.decoder_processes = []
+        self.forward_signal = torch.tensor(0, dtype=torch.int)
 
         for rank in range(1, world_size):
             input_q = mp.Queue()
@@ -677,21 +667,17 @@ class DecodeRunner:
         use_cache: bool = False,
         **kwargs,
     ):
-        t0 = time.perf_counter()
-
         if self.cache_past_key_value != past_key_value:
             control = torch.tensor(-1, dtype=torch.int)
             dist.broadcast(control, src=0)
             for i in range(len(self.decoder_processes)):
                 self.input_queues[i].put(past_key_value)
 
-        control = torch.tensor(0, dtype=torch.int)
-        dist.broadcast(control, src=0)
+        dist.broadcast(self.forward_signal, src=0, async_op=True)
         hidden_states = hidden_states.to(torch.float16)
         dist.send(hidden_states, dst=1)
         past_key_value.expand(self.transpose_value_cache)
         dist.recv(hidden_states, src=self.world_size - 1)
-        t1 = time.perf_counter()
         return hidden_states, past_key_value
 
     def shutdown(self):
@@ -889,7 +875,6 @@ def gen_minicpm_fused_model_forward(prefill_runner, decode_runner):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
     ) -> Union[Tuple, BaseModelOutputWithPast]:
-        t0 = time.perf_counter()
         output_attentions = (
             output_attentions if output_attentions is not None
             else self.config.output_attentions
@@ -978,7 +963,6 @@ def gen_minicpm_fused_model_forward(prefill_runner, decode_runner):
                 for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
                 if v is not None
             )
-        t1 = time.perf_counter()
         return BaseModelOutputWithPast(
             last_hidden_state=hidden_states,
             past_key_values=next_cache,

@@ -412,11 +412,6 @@ class FusedQwenLowBitMultiDecoderlayer(torch.nn.Module):
         return outputs
 
     def post_forward(self, past_key_value, new_keys, new_values):
-        key_value_states = []
-        for i in range(self.intra_stages):
-            for j in range(1, len(self.backend_decoders[i].torch_out)):
-                key_value_states.append(self.backend_decoders[i].torch_out[j])
-
         cache_kwargs = {
             "max_seq_len": self.max_seq_len,
             "transpose": self.transpose_value,
@@ -555,7 +550,6 @@ def run_decode(
     head_dim = model.model.layers[layer_start].self_attn.head_dim
     rms_norm_eps = model.config.rms_norm_eps
     intermediate_size = model.config.intermediate_size
-    deocderlayers = []
     layer_weights = []
     input_layer_norm_weights = []
     post_attn_layernorm_weights = []
@@ -640,7 +634,6 @@ def run_decode(
             elif control.item() == -1:
                 past_key_values = input_queue.get()
             else:
-                t0 = time.perf_counter()
                 past_seen_tokens = past_key_values.get_seq_length()
                 attention_mask = torch.ones([1, past_seen_tokens + 1], dtype=torch.int64)
                 position_ids = torch.arange(
@@ -669,7 +662,6 @@ def run_decode(
                 )
                 padded_causal_mask[:, :, :, -1] = 0.0
                 dist.recv(hidden_states, src=rank - 1)
-                t1 = time.perf_counter()
                 layer_outputs = multi_decoder(
                     hidden_states,
                     attention_mask=padded_causal_mask,
@@ -678,11 +670,8 @@ def run_decode(
                     output_attentions=False,
                     use_cache=True,
                 )
-                t2 = time.perf_counter()
                 hidden_states = layer_outputs[0]
-                t3 = time.perf_counter()
                 dist.send(hidden_states, dst=(rank + 1) % world_size)
-                t4 = time.perf_counter()
                 past_key_values = layer_outputs[1]
                 new_keys = layer_outputs[2]
                 new_values = layer_outputs[3]
@@ -757,22 +746,17 @@ class DecodeRunner:
         use_cache: bool = False,
         **kwargs,
     ):
-        t0 = time.perf_counter()
-
         if self.cache_past_key_value != past_key_value:
             control = torch.tensor(-1, dtype=torch.int)
             dist.broadcast(control, src=0)
             for i in range(len(self.decoder_processes)):
                 self.input_queues[i].put(past_key_value)
 
-        t0 = time.perf_counter()
         dist.broadcast(self.forward_signal, src=0, async_op=True)
-        t1 = time.perf_counter()
         hidden_states = hidden_states.to(torch.float16)
         dist.send(hidden_states, dst=1)
         past_key_value.expand(self.transpose_value_cache)
         dist.recv(hidden_states, src=self.world_size - 1)
-        t2 = time.perf_counter()
         return hidden_states, past_key_value
 
     def shutdown(self):