[NPU] further fix of new_value_states (#12538)
				
					
				
			This commit is contained in:
		
							parent
							
								
									fa261b8af1
								
							
						
					
					
						commit
						7cc01fdc86
					
				
					 1 changed files with 1 additions and 1 deletions
				
			
		| 
						 | 
					@ -213,6 +213,7 @@ class LLMBaseNNFactory(NNFactory):
 | 
				
			||||||
                value_states = new_value_states
 | 
					                value_states = new_value_states
 | 
				
			||||||
        else:
 | 
					        else:
 | 
				
			||||||
            value_states = self.transpose(value_states, [0, 2, 1, 3])
 | 
					            value_states = self.transpose(value_states, [0, 2, 1, 3])
 | 
				
			||||||
 | 
					            new_value_states = value_states
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        query_states, key_states = self.apply_rotary_pos_emb(
 | 
					        query_states, key_states = self.apply_rotary_pos_emb(
 | 
				
			||||||
            q=query_states,
 | 
					            q=query_states,
 | 
				
			||||||
| 
						 | 
					@ -225,7 +226,6 @@ class LLMBaseNNFactory(NNFactory):
 | 
				
			||||||
            head_dim=head_dim,
 | 
					            head_dim=head_dim,
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
        new_key_states = key_states
 | 
					        new_key_states = key_states
 | 
				
			||||||
        new_value_states = value_states
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
        if mode == "decode":
 | 
					        if mode == "decode":
 | 
				
			||||||
            key_states = self.concat(past_key, key_states, axis=-2)
 | 
					            key_states = self.concat(past_key, key_states, axis=-2)
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in a new issue