From a0c73c26d838694ffa698f323dc1338fddba6fc5 Mon Sep 17 00:00:00 2001
From: Ruonan Wang
Date: Wed, 11 Sep 2024 00:10:35 -0700
Subject: [PATCH] clean NPU code (#12060)

* clean code

* remove time.perf_counter()
---
 .../transformers/npu_models/baichuan_mp.py    | 23 +++----------------
 .../transformers/npu_models/llama_mp.py       | 17 --------------
 .../transformers/npu_models/minicpm_mp.py     | 22 +++---------------
 .../transformers/npu_models/qwen2_mp.py       | 16 -------------
 4 files changed, 6 insertions(+), 72 deletions(-)

diff --git a/python/llm/src/ipex_llm/transformers/npu_models/baichuan_mp.py b/python/llm/src/ipex_llm/transformers/npu_models/baichuan_mp.py
index 25cb790d..fdfdc528 100644
--- a/python/llm/src/ipex_llm/transformers/npu_models/baichuan_mp.py
+++ b/python/llm/src/ipex_llm/transformers/npu_models/baichuan_mp.py
@@ -415,11 +415,6 @@ class FusedBaichuanLowBitMultiDecoderlayer(torch.nn.Module):
         return outputs
 
     def post_forward(self, past_key_value, new_keys, new_values):
-        key_value_states = []
-        for i in range(self.intra_stages):
-            for j in range(1, len(self.backend_decoders[i].torch_out)):
-                key_value_states.append(self.backend_decoders[i].torch_out[j])
-
         cache_kwargs = {
             "max_seq_len": self.max_seq_len,
             "transpose": self.transpose_value,
@@ -556,7 +551,6 @@ def run_decode(
     head_dim = model.model.layers[layer_start].self_attn.head_dim
     rms_norm_eps = model.config.rms_norm_eps
     intermediate_size = model.config.intermediate_size
-    deocderlayers = []
     layer_weights = []
     input_layer_norm_weights = []
     post_attn_layernorm_weights = []
@@ -610,13 +604,12 @@ def run_decode(
 
     with torch.inference_mode():
         while True:
-            dist.broadcast(control, src=0)
+            dist.broadcast(control, src=0, async_op=False)
             if control.item() == -2:
                 break
             elif control.item() == -1:
                 past_key_values = input_queue.get()
             else:
-                t0 = time.perf_counter()
                 past_key_values_length = past_key_values.get_seq_length()
                 seq_length_with_past = 1 + past_key_values_length
                 position_ids = torch.arange(
@@ -636,7 +629,6 @@ def run_decode(
                 )
                 padded_causal_mask[:, :, :, -1] = 0.0
                 dist.recv(hidden_states, src=rank - 1)
-                t1 = time.perf_counter()
                 layer_outputs = multi_decoder(
                     hidden_states,
                     attention_mask=padded_causal_mask,
@@ -645,11 +637,8 @@ def run_decode(
                     output_attentions=False,
                     use_cache=True,
                 )
-                t2 = time.perf_counter()
                 hidden_states = layer_outputs[0]
-                t3 = time.perf_counter()
                 dist.send(hidden_states, dst=(rank + 1) % world_size)
-                t4 = time.perf_counter()
                 past_key_values = layer_outputs[1]
                 new_keys = layer_outputs[2]
                 new_values = layer_outputs[3]
@@ -674,6 +663,7 @@ class DecodeRunner:
         self.input_queues = []
         self.output_queues = []
         self.decoder_processes = []
+        self.forward_signal = torch.tensor(0, dtype=torch.int)
 
         for rank in range(1, world_size):
             input_q = mp.Queue()
@@ -721,21 +711,17 @@ class DecodeRunner:
         output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = False,
     ):
-        t0 = time.perf_counter()
-
         if self.cache_past_key_value != past_key_value:
             control = torch.tensor(-1, dtype=torch.int)
             dist.broadcast(control, src=0)
             for i in range(len(self.decoder_processes)):
                 self.input_queues[i].put(past_key_value)
 
-        control = torch.tensor(0, dtype=torch.int)
-        dist.broadcast(control, src=0)
+        dist.broadcast(self.forward_signal, src=0, async_op=True)
         hidden_states = hidden_states.to(torch.float16)
         dist.send(hidden_states, dst=1)
         past_key_value.expand(self.transpose_value_cache)
         dist.recv(hidden_states, src=self.world_size - 1)
-        t1 = time.perf_counter()
         return hidden_states, past_key_value
 
     def shutdown(self):
@@ -918,7 +904,6 @@ def gen_baichuan_fused_model_forward(prefill_runner, decode_runner):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
     ) -> Union[Tuple, BaseModelOutputWithPast]:
-        t0 = time.perf_counter()
         output_attentions = (
             output_attentions if output_attentions is not None else self.config.output_attentions
         )
@@ -1026,8 +1011,6 @@ def gen_baichuan_fused_model_forward(prefill_runner, decode_runner):
                 for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
                 if v is not None
             )
-        t1 = time.perf_counter()
-        # print("fused model forward time: ", t1 - t0)
         return BaseModelOutputWithPast(
             last_hidden_state=hidden_states,
             past_key_values=next_cache,
diff --git a/python/llm/src/ipex_llm/transformers/npu_models/llama_mp.py b/python/llm/src/ipex_llm/transformers/npu_models/llama_mp.py
index f14dd859..b375c305 100644
--- a/python/llm/src/ipex_llm/transformers/npu_models/llama_mp.py
+++ b/python/llm/src/ipex_llm/transformers/npu_models/llama_mp.py
@@ -330,11 +330,6 @@ class FusedLlamaLowBitMultiDecoderlayer(torch.nn.Module):
         return outputs
 
     def post_forward(self, past_key_value, new_keys, new_values, cache_position):
-        key_value_states = []
-        for i in range(self.intra_stages):
-            for j in range(1, len(self.backend_decoders[i].torch_out)):
-                key_value_states.append(self.backend_decoders[i].torch_out[j])
-
         cache_kwargs = {
             "cache_position": cache_position,
             "max_seq_len": self.max_seq_len,
@@ -474,7 +469,6 @@ def run_decode(
     head_dim = model.model.layers[layer_start].self_attn.head_dim
     rms_norm_eps = model.config.rms_norm_eps
     intermediate_size = model.config.intermediate_size
-    deocderlayers = []
     layer_weights = []
     input_layer_norm_weights = []
     post_attn_layernorm_weights = []
@@ -536,7 +530,6 @@ def run_decode(
             elif control.item() == -1:
                 past_key_values = input_queue.get()
             else:
-                t0 = time.perf_counter()
                 past_seen_tokens = past_key_values.get_seq_length()
                 attention_mask = torch.ones([1, past_seen_tokens + 1], dtype=torch.int64)
                 cache_position = torch.arange(
@@ -555,7 +548,6 @@ def run_decode(
                 )
                 padded_causal_mask[:, :, :, -1] = 0.0
                 dist.recv(hidden_states, src=rank - 1)
-                t1 = time.perf_counter()
                 layer_outputs = multi_decoder(
                     hidden_states,
                     attention_mask=padded_causal_mask,
@@ -565,11 +557,8 @@ def run_decode(
                     use_cache=True,
                     cache_position=cache_position,
                 )
-                t2 = time.perf_counter()
                 hidden_states = layer_outputs[0]
-                t3 = time.perf_counter()
                 dist.send(hidden_states, dst=(rank + 1) % world_size)
-                t4 = time.perf_counter()
                 past_key_values = layer_outputs[1]
                 new_keys = layer_outputs[2]
                 new_values = layer_outputs[3]
@@ -651,14 +640,11 @@ class DecodeRunner:
             dist.broadcast(control, src=0)
             for i in range(len(self.decoder_processes)):
                 self.input_queues[i].put(past_key_value)
-        t0 = time.perf_counter()
         dist.broadcast(self.forward_signal, src=0, async_op=True)
-        t1 = time.perf_counter()
         hidden_states = hidden_states.to(torch.float16)
         dist.send(hidden_states, dst=1)
         past_key_value.expand(self.transpose_value_cache)
         dist.recv(hidden_states, src=self.world_size - 1)
-        t2 = time.perf_counter()
         return hidden_states, past_key_value
 
     def shutdown(self):
@@ -847,7 +833,6 @@ def gen_llama_fused_model_forward(prefill_runner, decode_runner):
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
     ) -> Union[Tuple, BaseModelOutputWithPast]:
-        t0 = time.perf_counter()
         output_attentions = (
             output_attentions if output_attentions is not None else self.config.output_attentions
         )
@@ -938,8 +923,6 @@ def gen_llama_fused_model_forward(prefill_runner, decode_runner):
                 for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
                 if v is not None
             )
-        t1 = time.perf_counter()
-        # print("fused model forward time: ", t1 - t0)
         return BaseModelOutputWithPast(
             last_hidden_state=hidden_states,
             past_key_values=next_cache,
diff --git a/python/llm/src/ipex_llm/transformers/npu_models/minicpm_mp.py b/python/llm/src/ipex_llm/transformers/npu_models/minicpm_mp.py
index 48271271..55299608 100644
--- a/python/llm/src/ipex_llm/transformers/npu_models/minicpm_mp.py
+++ b/python/llm/src/ipex_llm/transformers/npu_models/minicpm_mp.py
@@ -354,11 +354,6 @@ class FusedLlamaLowBitMultiDecoderlayer(torch.nn.Module):
         return outputs
 
     def post_forward(self, past_key_value, new_keys, new_values):
-        key_value_states = []
-        for i in range(self.intra_stages):
-            for j in range(1, len(self.backend_decoders[i].torch_out)):
-                key_value_states.append(self.backend_decoders[i].torch_out[j])
-
         cache_kwargs = {
             "max_seq_len": self.max_seq_len,
             "transpose": self.transpose_value,
@@ -501,7 +496,6 @@ def run_decode(
     rms_norm_eps = model.config.rms_norm_eps
     intermediate_size = model.config.intermediate_size
     num_hidden_layers = model.config.num_hidden_layers
-    deocderlayers = []
     layer_weights = []
     input_layer_norm_weights = []
     post_attn_layernorm_weights = []
@@ -559,13 +553,12 @@ def run_decode(
 
     with torch.inference_mode():
         while True:
-            dist.broadcast(control, src=0)
+            dist.broadcast(control, src=0, async_op=False)
             if control.item() == -2:
                 break
             elif control.item() == -1:
                 past_key_values = input_queue.get()
             else:
-                t0 = time.perf_counter()
                 past_seen_tokens = past_key_values.get_seq_length()
                 attention_mask = torch.ones([1, past_seen_tokens + 1], dtype=torch.int64)
                 cache_position = torch.arange(
@@ -589,7 +582,6 @@ def run_decode(
                 )
                 padded_causal_mask[:, :, :, -1] = 0.0
                 dist.recv(hidden_states, src=rank - 1)
-                t1 = time.perf_counter()
                 layer_outputs = multi_decoder(
                     hidden_states,
                     attention_mask=padded_causal_mask,
@@ -598,11 +590,8 @@ def run_decode(
                     output_attentions=False,
                     use_cache=True,
                 )
-                t2 = time.perf_counter()
                 hidden_states = layer_outputs[0]
-                t3 = time.perf_counter()
                 dist.send(hidden_states, dst=(rank + 1) % world_size)
-                t4 = time.perf_counter()
                 past_key_values = layer_outputs[1]
                 new_keys = layer_outputs[2]
                 new_values = layer_outputs[3]
@@ -628,6 +617,7 @@ class DecodeRunner:
         self.input_queues = []
         self.output_queues = []
         self.decoder_processes = []
+        self.forward_signal = torch.tensor(0, dtype=torch.int)
 
         for rank in range(1, world_size):
             input_q = mp.Queue()
@@ -677,21 +667,17 @@ class DecodeRunner:
         use_cache: bool = False,
         **kwargs,
     ):
-        t0 = time.perf_counter()
-
         if self.cache_past_key_value != past_key_value:
             control = torch.tensor(-1, dtype=torch.int)
             dist.broadcast(control, src=0)
             for i in range(len(self.decoder_processes)):
                 self.input_queues[i].put(past_key_value)
 
-        control = torch.tensor(0, dtype=torch.int)
-        dist.broadcast(control, src=0)
+        dist.broadcast(self.forward_signal, src=0, async_op=True)
         hidden_states = hidden_states.to(torch.float16)
         dist.send(hidden_states, dst=1)
         past_key_value.expand(self.transpose_value_cache)
         dist.recv(hidden_states, src=self.world_size - 1)
-        t1 = time.perf_counter()
         return hidden_states, past_key_value
 
     def shutdown(self):
@@ -889,7 +875,6 @@ def gen_minicpm_fused_model_forward(prefill_runner, decode_runner):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
     ) -> Union[Tuple, BaseModelOutputWithPast]:
-        t0 = time.perf_counter()
         output_attentions = (
             output_attentions if output_attentions is not None else
             self.config.output_attentions
@@ -978,7 +963,6 @@ def gen_minicpm_fused_model_forward(prefill_runner, decode_runner):
                 for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
                 if v is not None
             )
-        t1 = time.perf_counter()
         return BaseModelOutputWithPast(
             last_hidden_state=hidden_states,
             past_key_values=next_cache,
diff --git a/python/llm/src/ipex_llm/transformers/npu_models/qwen2_mp.py b/python/llm/src/ipex_llm/transformers/npu_models/qwen2_mp.py
index ad639f67..b576c0d3 100644
--- a/python/llm/src/ipex_llm/transformers/npu_models/qwen2_mp.py
+++ b/python/llm/src/ipex_llm/transformers/npu_models/qwen2_mp.py
@@ -412,11 +412,6 @@ class FusedQwenLowBitMultiDecoderlayer(torch.nn.Module):
         return outputs
 
     def post_forward(self, past_key_value, new_keys, new_values):
-        key_value_states = []
-        for i in range(self.intra_stages):
-            for j in range(1, len(self.backend_decoders[i].torch_out)):
-                key_value_states.append(self.backend_decoders[i].torch_out[j])
-
         cache_kwargs = {
             "max_seq_len": self.max_seq_len,
             "transpose": self.transpose_value,
@@ -555,7 +550,6 @@ def run_decode(
     head_dim = model.model.layers[layer_start].self_attn.head_dim
     rms_norm_eps = model.config.rms_norm_eps
     intermediate_size = model.config.intermediate_size
-    deocderlayers = []
     layer_weights = []
     input_layer_norm_weights = []
     post_attn_layernorm_weights = []
@@ -640,7 +634,6 @@ def run_decode(
             elif control.item() == -1:
                 past_key_values = input_queue.get()
             else:
-                t0 = time.perf_counter()
                 past_seen_tokens = past_key_values.get_seq_length()
                 attention_mask = torch.ones([1, past_seen_tokens + 1], dtype=torch.int64)
                 position_ids = torch.arange(
@@ -669,7 +662,6 @@ def run_decode(
                 )
                 padded_causal_mask[:, :, :, -1] = 0.0
                 dist.recv(hidden_states, src=rank - 1)
-                t1 = time.perf_counter()
                 layer_outputs = multi_decoder(
                     hidden_states,
                     attention_mask=padded_causal_mask,
@@ -678,11 +670,8 @@ def run_decode(
                     output_attentions=False,
                     use_cache=True,
                 )
-                t2 = time.perf_counter()
                 hidden_states = layer_outputs[0]
-                t3 = time.perf_counter()
                 dist.send(hidden_states, dst=(rank + 1) % world_size)
-                t4 = time.perf_counter()
                 past_key_values = layer_outputs[1]
                 new_keys = layer_outputs[2]
                 new_values = layer_outputs[3]
@@ -757,22 +746,17 @@ class DecodeRunner:
         use_cache: bool = False,
         **kwargs,
     ):
-        t0 = time.perf_counter()
-
         if self.cache_past_key_value != past_key_value:
             control = torch.tensor(-1, dtype=torch.int)
             dist.broadcast(control, src=0)
             for i in range(len(self.decoder_processes)):
                 self.input_queues[i].put(past_key_value)
 
-        t0 = time.perf_counter()
         dist.broadcast(self.forward_signal, src=0, async_op=True)
-        t1 = time.perf_counter()
         hidden_states = hidden_states.to(torch.float16)
         dist.send(hidden_states, dst=1)
         past_key_value.expand(self.transpose_value_cache)
         dist.recv(hidden_states, src=self.world_size - 1)
-        t2 = time.perf_counter()
         return hidden_states, past_key_value
 
     def shutdown(self):