[NPU] further fix of new_value_states (#12538)
This commit is contained in:
parent
fa261b8af1
commit
7cc01fdc86
1 changed files with 1 additions and 1 deletions
|
|
@ -213,6 +213,7 @@ class LLMBaseNNFactory(NNFactory):
|
|||
value_states = new_value_states
|
||||
else:
|
||||
value_states = self.transpose(value_states, [0, 2, 1, 3])
|
||||
new_value_states = value_states
|
||||
|
||||
query_states, key_states = self.apply_rotary_pos_emb(
|
||||
q=query_states,
|
||||
|
|
@ -225,7 +226,6 @@ class LLMBaseNNFactory(NNFactory):
|
|||
head_dim=head_dim,
|
||||
)
|
||||
new_key_states = key_states
|
||||
new_value_states = value_states
|
||||
|
||||
if mode == "decode":
|
||||
key_states = self.concat(past_key, key_states, axis=-2)
|
||||
|
|
|
|||
Loading…
Reference in a new issue