fix baichuan (#11606)
This commit is contained in:

parent bfcdc35b04
commit e5c0058c0e

1 changed file with 5 additions and 4 deletions
@@ -37,6 +37,7 @@ from torch.nn import functional as F
 import importlib
 from typing import Optional, Tuple
 from ipex_llm.transformers.npu_models.common import merge_linear
+from ipex_llm.transformers.models.utils import update_past_key_value


 def merge_mlp(module: torch.nn.Module):
@@ -85,10 +86,10 @@ def baichuan_attention_fwd(
                                                     cos, sin, position_ids)
     # [bsz, nh, t, hd]

-    if past_key_value is not None:
-        # reuse k, v, self_attention
-        key_states = torch.cat([past_key_value[0], key_states], dim=2)
-        value_states = torch.cat([past_key_value[1], value_states], dim=2)
+    key_states, value_states = update_past_key_value(
+        past_key_value, key_states, value_states,
+        kv_seq_len, False, "cpu"
+    )

     past_key_value = (key_states, value_states) if use_cache else None

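The second hunk swaps the hand-written KV-cache concatenation for the shared update_past_key_value helper imported from ipex_llm.transformers.models.utils. As a rough guide to the behaviour the new call site relies on, here is a minimal sketch under the assumption that, with quantized KV cache disabled (the False argument) and "cpu" as the device, the helper reduces to the same concatenation the removed lines performed; the _sketch suffix marks it as a hypothetical stand-in, not the library's actual implementation.

import torch
from typing import Optional, Tuple


def update_past_key_value_sketch(
        past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]],
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        kv_seq_len: int,
        quantize_kv: bool,
        device: str) -> Tuple[torch.Tensor, torch.Tensor]:
    # Hypothetical stand-in: the real helper may pre-allocate or quantize the
    # cache; only the plain (non-quantized) path is sketched here.
    if past_key_value is not None:
        # Append the new tokens' keys/values to the cached ones along the
        # sequence dimension of [bsz, n_heads, seq_len, head_dim] tensors,
        # which is exactly what the removed inline torch.cat calls did.
        key_states = torch.cat([past_key_value[0], key_states], dim=2)
        value_states = torch.cat([past_key_value[1], value_states], dim=2)
    return key_states, value_states

Either way, the unchanged line that follows, past_key_value = (key_states, value_states) if use_cache else None, still returns the grown cache to the caller when caching is enabled.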