[NPU] further fix of new_value_states (#12538)

This commit is contained in:
Ruonan Wang 2024-12-12 21:42:00 -08:00 committed by GitHub
parent fa261b8af1
commit 7cc01fdc86
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -213,6 +213,7 @@ class LLMBaseNNFactory(NNFactory):
value_states = new_value_states value_states = new_value_states
else: else:
value_states = self.transpose(value_states, [0, 2, 1, 3]) value_states = self.transpose(value_states, [0, 2, 1, 3])
new_value_states = value_states
query_states, key_states = self.apply_rotary_pos_emb( query_states, key_states = self.apply_rotary_pos_emb(
q=query_states, q=query_states,
@ -225,7 +226,6 @@ class LLMBaseNNFactory(NNFactory):
head_dim=head_dim, head_dim=head_dim,
) )
new_key_states = key_states new_key_states = key_states
new_value_states = value_states
if mode == "decode": if mode == "decode":
key_states = self.concat(past_key, key_states, axis=-2) key_states = self.concat(past_key, key_states, axis=-2)