clean NPU code (#12060)

* clean code

* remove time.perf_counter()
This commit is contained in:
Ruonan Wang 2024-09-11 00:10:35 -07:00 committed by GitHub
parent c75f3dd874
commit a0c73c26d8
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 6 additions and 72 deletions

View file

@ -415,11 +415,6 @@ class FusedBaichuanLowBitMultiDecoderlayer(torch.nn.Module):
return outputs return outputs
def post_forward(self, past_key_value, new_keys, new_values): def post_forward(self, past_key_value, new_keys, new_values):
key_value_states = []
for i in range(self.intra_stages):
for j in range(1, len(self.backend_decoders[i].torch_out)):
key_value_states.append(self.backend_decoders[i].torch_out[j])
cache_kwargs = { cache_kwargs = {
"max_seq_len": self.max_seq_len, "max_seq_len": self.max_seq_len,
"transpose": self.transpose_value, "transpose": self.transpose_value,
@ -556,7 +551,6 @@ def run_decode(
head_dim = model.model.layers[layer_start].self_attn.head_dim head_dim = model.model.layers[layer_start].self_attn.head_dim
rms_norm_eps = model.config.rms_norm_eps rms_norm_eps = model.config.rms_norm_eps
intermediate_size = model.config.intermediate_size intermediate_size = model.config.intermediate_size
deocderlayers = []
layer_weights = [] layer_weights = []
input_layer_norm_weights = [] input_layer_norm_weights = []
post_attn_layernorm_weights = [] post_attn_layernorm_weights = []
@ -610,13 +604,12 @@ def run_decode(
with torch.inference_mode(): with torch.inference_mode():
while True: while True:
dist.broadcast(control, src=0) dist.broadcast(control, src=0, async_op=False)
if control.item() == -2: if control.item() == -2:
break break
elif control.item() == -1: elif control.item() == -1:
past_key_values = input_queue.get() past_key_values = input_queue.get()
else: else:
t0 = time.perf_counter()
past_key_values_length = past_key_values.get_seq_length() past_key_values_length = past_key_values.get_seq_length()
seq_length_with_past = 1 + past_key_values_length seq_length_with_past = 1 + past_key_values_length
position_ids = torch.arange( position_ids = torch.arange(
@ -636,7 +629,6 @@ def run_decode(
) )
padded_causal_mask[:, :, :, -1] = 0.0 padded_causal_mask[:, :, :, -1] = 0.0
dist.recv(hidden_states, src=rank - 1) dist.recv(hidden_states, src=rank - 1)
t1 = time.perf_counter()
layer_outputs = multi_decoder( layer_outputs = multi_decoder(
hidden_states, hidden_states,
attention_mask=padded_causal_mask, attention_mask=padded_causal_mask,
@ -645,11 +637,8 @@ def run_decode(
output_attentions=False, output_attentions=False,
use_cache=True, use_cache=True,
) )
t2 = time.perf_counter()
hidden_states = layer_outputs[0] hidden_states = layer_outputs[0]
t3 = time.perf_counter()
dist.send(hidden_states, dst=(rank + 1) % world_size) dist.send(hidden_states, dst=(rank + 1) % world_size)
t4 = time.perf_counter()
past_key_values = layer_outputs[1] past_key_values = layer_outputs[1]
new_keys = layer_outputs[2] new_keys = layer_outputs[2]
new_values = layer_outputs[3] new_values = layer_outputs[3]
@ -674,6 +663,7 @@ class DecodeRunner:
self.input_queues = [] self.input_queues = []
self.output_queues = [] self.output_queues = []
self.decoder_processes = [] self.decoder_processes = []
self.forward_signal = torch.tensor(0, dtype=torch.int)
for rank in range(1, world_size): for rank in range(1, world_size):
input_q = mp.Queue() input_q = mp.Queue()
@ -721,21 +711,17 @@ class DecodeRunner:
output_attentions: Optional[bool] = False, output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False, use_cache: Optional[bool] = False,
): ):
t0 = time.perf_counter()
if self.cache_past_key_value != past_key_value: if self.cache_past_key_value != past_key_value:
control = torch.tensor(-1, dtype=torch.int) control = torch.tensor(-1, dtype=torch.int)
dist.broadcast(control, src=0) dist.broadcast(control, src=0)
for i in range(len(self.decoder_processes)): for i in range(len(self.decoder_processes)):
self.input_queues[i].put(past_key_value) self.input_queues[i].put(past_key_value)
control = torch.tensor(0, dtype=torch.int) dist.broadcast(self.forward_signal, src=0, async_op=True)
dist.broadcast(control, src=0)
hidden_states = hidden_states.to(torch.float16) hidden_states = hidden_states.to(torch.float16)
dist.send(hidden_states, dst=1) dist.send(hidden_states, dst=1)
past_key_value.expand(self.transpose_value_cache) past_key_value.expand(self.transpose_value_cache)
dist.recv(hidden_states, src=self.world_size - 1) dist.recv(hidden_states, src=self.world_size - 1)
t1 = time.perf_counter()
return hidden_states, past_key_value return hidden_states, past_key_value
def shutdown(self): def shutdown(self):
@ -918,7 +904,6 @@ def gen_baichuan_fused_model_forward(prefill_runner, decode_runner):
output_hidden_states: Optional[bool] = None, output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None, return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPast]: ) -> Union[Tuple, BaseModelOutputWithPast]:
t0 = time.perf_counter()
output_attentions = ( output_attentions = (
output_attentions if output_attentions is not None else self.config.output_attentions output_attentions if output_attentions is not None else self.config.output_attentions
) )
@ -1026,8 +1011,6 @@ def gen_baichuan_fused_model_forward(prefill_runner, decode_runner):
for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
if v is not None if v is not None
) )
t1 = time.perf_counter()
# print("fused model forward time: ", t1 - t0)
return BaseModelOutputWithPast( return BaseModelOutputWithPast(
last_hidden_state=hidden_states, last_hidden_state=hidden_states,
past_key_values=next_cache, past_key_values=next_cache,

View file

@ -330,11 +330,6 @@ class FusedLlamaLowBitMultiDecoderlayer(torch.nn.Module):
return outputs return outputs
def post_forward(self, past_key_value, new_keys, new_values, cache_position): def post_forward(self, past_key_value, new_keys, new_values, cache_position):
key_value_states = []
for i in range(self.intra_stages):
for j in range(1, len(self.backend_decoders[i].torch_out)):
key_value_states.append(self.backend_decoders[i].torch_out[j])
cache_kwargs = { cache_kwargs = {
"cache_position": cache_position, "cache_position": cache_position,
"max_seq_len": self.max_seq_len, "max_seq_len": self.max_seq_len,
@ -474,7 +469,6 @@ def run_decode(
head_dim = model.model.layers[layer_start].self_attn.head_dim head_dim = model.model.layers[layer_start].self_attn.head_dim
rms_norm_eps = model.config.rms_norm_eps rms_norm_eps = model.config.rms_norm_eps
intermediate_size = model.config.intermediate_size intermediate_size = model.config.intermediate_size
deocderlayers = []
layer_weights = [] layer_weights = []
input_layer_norm_weights = [] input_layer_norm_weights = []
post_attn_layernorm_weights = [] post_attn_layernorm_weights = []
@ -536,7 +530,6 @@ def run_decode(
elif control.item() == -1: elif control.item() == -1:
past_key_values = input_queue.get() past_key_values = input_queue.get()
else: else:
t0 = time.perf_counter()
past_seen_tokens = past_key_values.get_seq_length() past_seen_tokens = past_key_values.get_seq_length()
attention_mask = torch.ones([1, past_seen_tokens + 1], dtype=torch.int64) attention_mask = torch.ones([1, past_seen_tokens + 1], dtype=torch.int64)
cache_position = torch.arange( cache_position = torch.arange(
@ -555,7 +548,6 @@ def run_decode(
) )
padded_causal_mask[:, :, :, -1] = 0.0 padded_causal_mask[:, :, :, -1] = 0.0
dist.recv(hidden_states, src=rank - 1) dist.recv(hidden_states, src=rank - 1)
t1 = time.perf_counter()
layer_outputs = multi_decoder( layer_outputs = multi_decoder(
hidden_states, hidden_states,
attention_mask=padded_causal_mask, attention_mask=padded_causal_mask,
@ -565,11 +557,8 @@ def run_decode(
use_cache=True, use_cache=True,
cache_position=cache_position, cache_position=cache_position,
) )
t2 = time.perf_counter()
hidden_states = layer_outputs[0] hidden_states = layer_outputs[0]
t3 = time.perf_counter()
dist.send(hidden_states, dst=(rank + 1) % world_size) dist.send(hidden_states, dst=(rank + 1) % world_size)
t4 = time.perf_counter()
past_key_values = layer_outputs[1] past_key_values = layer_outputs[1]
new_keys = layer_outputs[2] new_keys = layer_outputs[2]
new_values = layer_outputs[3] new_values = layer_outputs[3]
@ -651,14 +640,11 @@ class DecodeRunner:
dist.broadcast(control, src=0) dist.broadcast(control, src=0)
for i in range(len(self.decoder_processes)): for i in range(len(self.decoder_processes)):
self.input_queues[i].put(past_key_value) self.input_queues[i].put(past_key_value)
t0 = time.perf_counter()
dist.broadcast(self.forward_signal, src=0, async_op=True) dist.broadcast(self.forward_signal, src=0, async_op=True)
t1 = time.perf_counter()
hidden_states = hidden_states.to(torch.float16) hidden_states = hidden_states.to(torch.float16)
dist.send(hidden_states, dst=1) dist.send(hidden_states, dst=1)
past_key_value.expand(self.transpose_value_cache) past_key_value.expand(self.transpose_value_cache)
dist.recv(hidden_states, src=self.world_size - 1) dist.recv(hidden_states, src=self.world_size - 1)
t2 = time.perf_counter()
return hidden_states, past_key_value return hidden_states, past_key_value
def shutdown(self): def shutdown(self):
@ -847,7 +833,6 @@ def gen_llama_fused_model_forward(prefill_runner, decode_runner):
return_dict: Optional[bool] = None, return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, BaseModelOutputWithPast]: ) -> Union[Tuple, BaseModelOutputWithPast]:
t0 = time.perf_counter()
output_attentions = ( output_attentions = (
output_attentions if output_attentions is not None else self.config.output_attentions output_attentions if output_attentions is not None else self.config.output_attentions
) )
@ -938,8 +923,6 @@ def gen_llama_fused_model_forward(prefill_runner, decode_runner):
for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
if v is not None if v is not None
) )
t1 = time.perf_counter()
# print("fused model forward time: ", t1 - t0)
return BaseModelOutputWithPast( return BaseModelOutputWithPast(
last_hidden_state=hidden_states, last_hidden_state=hidden_states,
past_key_values=next_cache, past_key_values=next_cache,

View file

@ -354,11 +354,6 @@ class FusedLlamaLowBitMultiDecoderlayer(torch.nn.Module):
return outputs return outputs
def post_forward(self, past_key_value, new_keys, new_values): def post_forward(self, past_key_value, new_keys, new_values):
key_value_states = []
for i in range(self.intra_stages):
for j in range(1, len(self.backend_decoders[i].torch_out)):
key_value_states.append(self.backend_decoders[i].torch_out[j])
cache_kwargs = { cache_kwargs = {
"max_seq_len": self.max_seq_len, "max_seq_len": self.max_seq_len,
"transpose": self.transpose_value, "transpose": self.transpose_value,
@ -501,7 +496,6 @@ def run_decode(
rms_norm_eps = model.config.rms_norm_eps rms_norm_eps = model.config.rms_norm_eps
intermediate_size = model.config.intermediate_size intermediate_size = model.config.intermediate_size
num_hidden_layers = model.config.num_hidden_layers num_hidden_layers = model.config.num_hidden_layers
deocderlayers = []
layer_weights = [] layer_weights = []
input_layer_norm_weights = [] input_layer_norm_weights = []
post_attn_layernorm_weights = [] post_attn_layernorm_weights = []
@ -559,13 +553,12 @@ def run_decode(
with torch.inference_mode(): with torch.inference_mode():
while True: while True:
dist.broadcast(control, src=0) dist.broadcast(control, src=0, async_op=False)
if control.item() == -2: if control.item() == -2:
break break
elif control.item() == -1: elif control.item() == -1:
past_key_values = input_queue.get() past_key_values = input_queue.get()
else: else:
t0 = time.perf_counter()
past_seen_tokens = past_key_values.get_seq_length() past_seen_tokens = past_key_values.get_seq_length()
attention_mask = torch.ones([1, past_seen_tokens + 1], dtype=torch.int64) attention_mask = torch.ones([1, past_seen_tokens + 1], dtype=torch.int64)
cache_position = torch.arange( cache_position = torch.arange(
@ -589,7 +582,6 @@ def run_decode(
) )
padded_causal_mask[:, :, :, -1] = 0.0 padded_causal_mask[:, :, :, -1] = 0.0
dist.recv(hidden_states, src=rank - 1) dist.recv(hidden_states, src=rank - 1)
t1 = time.perf_counter()
layer_outputs = multi_decoder( layer_outputs = multi_decoder(
hidden_states, hidden_states,
attention_mask=padded_causal_mask, attention_mask=padded_causal_mask,
@ -598,11 +590,8 @@ def run_decode(
output_attentions=False, output_attentions=False,
use_cache=True, use_cache=True,
) )
t2 = time.perf_counter()
hidden_states = layer_outputs[0] hidden_states = layer_outputs[0]
t3 = time.perf_counter()
dist.send(hidden_states, dst=(rank + 1) % world_size) dist.send(hidden_states, dst=(rank + 1) % world_size)
t4 = time.perf_counter()
past_key_values = layer_outputs[1] past_key_values = layer_outputs[1]
new_keys = layer_outputs[2] new_keys = layer_outputs[2]
new_values = layer_outputs[3] new_values = layer_outputs[3]
@ -628,6 +617,7 @@ class DecodeRunner:
self.input_queues = [] self.input_queues = []
self.output_queues = [] self.output_queues = []
self.decoder_processes = [] self.decoder_processes = []
self.forward_signal = torch.tensor(0, dtype=torch.int)
for rank in range(1, world_size): for rank in range(1, world_size):
input_q = mp.Queue() input_q = mp.Queue()
@ -677,21 +667,17 @@ class DecodeRunner:
use_cache: bool = False, use_cache: bool = False,
**kwargs, **kwargs,
): ):
t0 = time.perf_counter()
if self.cache_past_key_value != past_key_value: if self.cache_past_key_value != past_key_value:
control = torch.tensor(-1, dtype=torch.int) control = torch.tensor(-1, dtype=torch.int)
dist.broadcast(control, src=0) dist.broadcast(control, src=0)
for i in range(len(self.decoder_processes)): for i in range(len(self.decoder_processes)):
self.input_queues[i].put(past_key_value) self.input_queues[i].put(past_key_value)
control = torch.tensor(0, dtype=torch.int) dist.broadcast(self.forward_signal, src=0, async_op=True)
dist.broadcast(control, src=0)
hidden_states = hidden_states.to(torch.float16) hidden_states = hidden_states.to(torch.float16)
dist.send(hidden_states, dst=1) dist.send(hidden_states, dst=1)
past_key_value.expand(self.transpose_value_cache) past_key_value.expand(self.transpose_value_cache)
dist.recv(hidden_states, src=self.world_size - 1) dist.recv(hidden_states, src=self.world_size - 1)
t1 = time.perf_counter()
return hidden_states, past_key_value return hidden_states, past_key_value
def shutdown(self): def shutdown(self):
@ -889,7 +875,6 @@ def gen_minicpm_fused_model_forward(prefill_runner, decode_runner):
output_hidden_states: Optional[bool] = None, output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None, return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPast]: ) -> Union[Tuple, BaseModelOutputWithPast]:
t0 = time.perf_counter()
output_attentions = ( output_attentions = (
output_attentions if output_attentions is not None output_attentions if output_attentions is not None
else self.config.output_attentions else self.config.output_attentions
@ -978,7 +963,6 @@ def gen_minicpm_fused_model_forward(prefill_runner, decode_runner):
for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
if v is not None if v is not None
) )
t1 = time.perf_counter()
return BaseModelOutputWithPast( return BaseModelOutputWithPast(
last_hidden_state=hidden_states, last_hidden_state=hidden_states,
past_key_values=next_cache, past_key_values=next_cache,

View file

@ -412,11 +412,6 @@ class FusedQwenLowBitMultiDecoderlayer(torch.nn.Module):
return outputs return outputs
def post_forward(self, past_key_value, new_keys, new_values): def post_forward(self, past_key_value, new_keys, new_values):
key_value_states = []
for i in range(self.intra_stages):
for j in range(1, len(self.backend_decoders[i].torch_out)):
key_value_states.append(self.backend_decoders[i].torch_out[j])
cache_kwargs = { cache_kwargs = {
"max_seq_len": self.max_seq_len, "max_seq_len": self.max_seq_len,
"transpose": self.transpose_value, "transpose": self.transpose_value,
@ -555,7 +550,6 @@ def run_decode(
head_dim = model.model.layers[layer_start].self_attn.head_dim head_dim = model.model.layers[layer_start].self_attn.head_dim
rms_norm_eps = model.config.rms_norm_eps rms_norm_eps = model.config.rms_norm_eps
intermediate_size = model.config.intermediate_size intermediate_size = model.config.intermediate_size
deocderlayers = []
layer_weights = [] layer_weights = []
input_layer_norm_weights = [] input_layer_norm_weights = []
post_attn_layernorm_weights = [] post_attn_layernorm_weights = []
@ -640,7 +634,6 @@ def run_decode(
elif control.item() == -1: elif control.item() == -1:
past_key_values = input_queue.get() past_key_values = input_queue.get()
else: else:
t0 = time.perf_counter()
past_seen_tokens = past_key_values.get_seq_length() past_seen_tokens = past_key_values.get_seq_length()
attention_mask = torch.ones([1, past_seen_tokens + 1], dtype=torch.int64) attention_mask = torch.ones([1, past_seen_tokens + 1], dtype=torch.int64)
position_ids = torch.arange( position_ids = torch.arange(
@ -669,7 +662,6 @@ def run_decode(
) )
padded_causal_mask[:, :, :, -1] = 0.0 padded_causal_mask[:, :, :, -1] = 0.0
dist.recv(hidden_states, src=rank - 1) dist.recv(hidden_states, src=rank - 1)
t1 = time.perf_counter()
layer_outputs = multi_decoder( layer_outputs = multi_decoder(
hidden_states, hidden_states,
attention_mask=padded_causal_mask, attention_mask=padded_causal_mask,
@ -678,11 +670,8 @@ def run_decode(
output_attentions=False, output_attentions=False,
use_cache=True, use_cache=True,
) )
t2 = time.perf_counter()
hidden_states = layer_outputs[0] hidden_states = layer_outputs[0]
t3 = time.perf_counter()
dist.send(hidden_states, dst=(rank + 1) % world_size) dist.send(hidden_states, dst=(rank + 1) % world_size)
t4 = time.perf_counter()
past_key_values = layer_outputs[1] past_key_values = layer_outputs[1]
new_keys = layer_outputs[2] new_keys = layer_outputs[2]
new_values = layer_outputs[3] new_values = layer_outputs[3]
@ -757,22 +746,17 @@ class DecodeRunner:
use_cache: bool = False, use_cache: bool = False,
**kwargs, **kwargs,
): ):
t0 = time.perf_counter()
if self.cache_past_key_value != past_key_value: if self.cache_past_key_value != past_key_value:
control = torch.tensor(-1, dtype=torch.int) control = torch.tensor(-1, dtype=torch.int)
dist.broadcast(control, src=0) dist.broadcast(control, src=0)
for i in range(len(self.decoder_processes)): for i in range(len(self.decoder_processes)):
self.input_queues[i].put(past_key_value) self.input_queues[i].put(past_key_value)
t0 = time.perf_counter()
dist.broadcast(self.forward_signal, src=0, async_op=True) dist.broadcast(self.forward_signal, src=0, async_op=True)
t1 = time.perf_counter()
hidden_states = hidden_states.to(torch.float16) hidden_states = hidden_states.to(torch.float16)
dist.send(hidden_states, dst=1) dist.send(hidden_states, dst=1)
past_key_value.expand(self.transpose_value_cache) past_key_value.expand(self.transpose_value_cache)
dist.recv(hidden_states, src=self.world_size - 1) dist.recv(hidden_states, src=self.world_size - 1)
t2 = time.perf_counter()
return hidden_states, past_key_value return hidden_states, past_key_value
def shutdown(self): def shutdown(self):