fix baichuan (#11606)

parent bfcdc35b04
commit e5c0058c0e

1 changed file with 5 additions and 4 deletions
@@ -37,6 +37,7 @@ from torch.nn import functional as F
 import importlib
 from typing import Optional, Tuple
 from ipex_llm.transformers.npu_models.common import merge_linear
+from ipex_llm.transformers.models.utils import update_past_key_value


 def merge_mlp(module: torch.nn.Module):
@@ -85,10 +86,10 @@ def baichuan_attention_fwd(
                                                     cos, sin, position_ids)
     # [bsz, nh, t, hd]

-    if past_key_value is not None:
-        # reuse k, v, self_attention
-        key_states = torch.cat([past_key_value[0], key_states], dim=2)
-        value_states = torch.cat([past_key_value[1], value_states], dim=2)
+    key_states, value_states = update_past_key_value(
+        past_key_value, key_states, value_states,
+        kv_seq_len, False, "cpu"
+    )

     past_key_value = (key_states, value_states) if use_cache else None

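The change replaces the inline torch.cat cache concatenation with the shared update_past_key_value helper from ipex_llm.transformers.models.utils. For context, below is a minimal sketch of the plain (non-quantized) path this call site appears to exercise. The sketch is hypothetical: the parameter names, and the reading of the trailing False and "cpu" arguments as a quantize-KV flag and a device hint, are assumptions, and the real helper does more than this (e.g. quantized and pre-allocated KV caches).

import torch
from typing import Optional, Tuple

def update_past_key_value_sketch(
    past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]],
    key_states: torch.Tensor,     # [bsz, n_heads, seq_len, head_dim]
    value_states: torch.Tensor,   # [bsz, n_heads, seq_len, head_dim]
    kv_seq_len: int,              # total K/V length after this step (unused in this sketch)
    quantize_kv: bool,            # assumed meaning of the `False` argument at the call site
    device: str,                  # assumed meaning of the `"cpu"` argument at the call site
) -> Tuple[torch.Tensor, torch.Tensor]:
    # Hypothetical sketch: reproduces only the plain-concatenation path that the
    # removed lines implemented inline; not the library's actual implementation.
    if past_key_value is not None:
        # Reuse cached K/V: append this step's keys/values along the sequence
        # dimension (dim=2), exactly as the old inline code did.
        key_states = torch.cat([past_key_value[0], key_states], dim=2)
        value_states = torch.cat([past_key_value[1], value_states], dim=2)
    return key_states, value_states

At the call site in baichuan_attention_fwd this preserves the behavior of the removed branch while routing the model through the same cache-update path the other model implementations use.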