fix code geex (#12261)
This commit is contained in:
parent
f3a2b20e6b
commit
39c9d1de52
1 changed files with 3 additions and 2 deletions
|
|
@ -183,7 +183,7 @@ def chatglm2_encoder_forward(
|
||||||
if not kv_caches and not use_compress_kv:
|
if not kv_caches and not use_compress_kv:
|
||||||
kv_caches = [None for _ in range(self.num_layers)]
|
kv_caches = [None for _ in range(self.num_layers)]
|
||||||
presents = () if use_cache else None
|
presents = () if use_cache else None
|
||||||
if self.gradient_checkpointing and self.training:
|
if hasattr(self, "gradient_checkpointing") and self.gradient_checkpointing and self.training:
|
||||||
use_cache = False
|
use_cache = False
|
||||||
|
|
||||||
all_self_attentions = None
|
all_self_attentions = None
|
||||||
|
|
@ -193,7 +193,8 @@ def chatglm2_encoder_forward(
|
||||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||||
|
|
||||||
layer = self._get_layer(index)
|
layer = self._get_layer(index)
|
||||||
if self.gradient_checkpointing and self.training:
|
if hasattr(self, "gradient_checkpointing") and self.gradient_checkpointing \
|
||||||
|
and self.training:
|
||||||
layer_ret = torch.utils.checkpoint.checkpoint(
|
layer_ret = torch.utils.checkpoint.checkpoint(
|
||||||
layer,
|
layer,
|
||||||
hidden_states,
|
hidden_states,
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue