fix code geex (#12261)
This commit is contained in:
		
							parent
							
								
									f3a2b20e6b
								
							
						
					
					
						commit
						39c9d1de52
					
				
					 1 changed files with 3 additions and 2 deletions
				
			
		| 
						 | 
				
			
			@ -183,7 +183,7 @@ def chatglm2_encoder_forward(
 | 
			
		|||
    if not kv_caches and not use_compress_kv:
 | 
			
		||||
        kv_caches = [None for _ in range(self.num_layers)]
 | 
			
		||||
    presents = () if use_cache else None
 | 
			
		||||
    if self.gradient_checkpointing and self.training:
 | 
			
		||||
    if hasattr(self, "gradient_checkpointing") and self.gradient_checkpointing and self.training:
 | 
			
		||||
        use_cache = False
 | 
			
		||||
 | 
			
		||||
    all_self_attentions = None
 | 
			
		||||
| 
						 | 
				
			
			@ -193,7 +193,8 @@ def chatglm2_encoder_forward(
 | 
			
		|||
            all_hidden_states = all_hidden_states + (hidden_states,)
 | 
			
		||||
 | 
			
		||||
        layer = self._get_layer(index)
 | 
			
		||||
        if self.gradient_checkpointing and self.training:
 | 
			
		||||
        if hasattr(self, "gradient_checkpointing") and self.gradient_checkpointing \
 | 
			
		||||
                and self.training:
 | 
			
		||||
            layer_ret = torch.utils.checkpoint.checkpoint(
 | 
			
		||||
                layer,
 | 
			
		||||
                hidden_states,
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in a new issue