[NPU] further fix of new_value_states (#12538)
This commit is contained in:
parent
fa261b8af1
commit
7cc01fdc86
1 changed files with 1 additions and 1 deletions
|
|
@ -213,6 +213,7 @@ class LLMBaseNNFactory(NNFactory):
|
||||||
value_states = new_value_states
|
value_states = new_value_states
|
||||||
else:
|
else:
|
||||||
value_states = self.transpose(value_states, [0, 2, 1, 3])
|
value_states = self.transpose(value_states, [0, 2, 1, 3])
|
||||||
|
new_value_states = value_states
|
||||||
|
|
||||||
query_states, key_states = self.apply_rotary_pos_emb(
|
query_states, key_states = self.apply_rotary_pos_emb(
|
||||||
q=query_states,
|
q=query_states,
|
||||||
|
|
@ -225,7 +226,6 @@ class LLMBaseNNFactory(NNFactory):
|
||||||
head_dim=head_dim,
|
head_dim=head_dim,
|
||||||
)
|
)
|
||||||
new_key_states = key_states
|
new_key_states = key_states
|
||||||
new_value_states = value_states
|
|
||||||
|
|
||||||
if mode == "decode":
|
if mode == "decode":
|
||||||
key_states = self.concat(past_key, key_states, axis=-2)
|
key_states = self.concat(past_key, key_states, axis=-2)
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue