fix code geex (#12261)
This commit is contained in:
parent
f3a2b20e6b
commit
39c9d1de52
1 changed files with 3 additions and 2 deletions
|
|
@ -183,7 +183,7 @@ def chatglm2_encoder_forward(
|
|||
if not kv_caches and not use_compress_kv:
|
||||
kv_caches = [None for _ in range(self.num_layers)]
|
||||
presents = () if use_cache else None
|
||||
if self.gradient_checkpointing and self.training:
|
||||
if hasattr(self, "gradient_checkpointing") and self.gradient_checkpointing and self.training:
|
||||
use_cache = False
|
||||
|
||||
all_self_attentions = None
|
||||
|
|
@ -193,7 +193,8 @@ def chatglm2_encoder_forward(
|
|||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
layer = self._get_layer(index)
|
||||
if self.gradient_checkpointing and self.training:
|
||||
if hasattr(self, "gradient_checkpointing") and self.gradient_checkpointing \
|
||||
and self.training:
|
||||
layer_ret = torch.utils.checkpoint.checkpoint(
|
||||
layer,
|
||||
hidden_states,
|
||||
|
|
|
|||
Loading…
Reference in a new issue