fix code geex (#12261)

This commit is contained in:
Xin Qiu 2024-10-24 14:34:01 +08:00 committed by GitHub
parent f3a2b20e6b
commit 39c9d1de52
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -183,7 +183,7 @@ def chatglm2_encoder_forward(
if not kv_caches and not use_compress_kv: if not kv_caches and not use_compress_kv:
kv_caches = [None for _ in range(self.num_layers)] kv_caches = [None for _ in range(self.num_layers)]
presents = () if use_cache else None presents = () if use_cache else None
if self.gradient_checkpointing and self.training: if hasattr(self, "gradient_checkpointing") and self.gradient_checkpointing and self.training:
use_cache = False use_cache = False
all_self_attentions = None all_self_attentions = None
@ -193,7 +193,8 @@ def chatglm2_encoder_forward(
all_hidden_states = all_hidden_states + (hidden_states,) all_hidden_states = all_hidden_states + (hidden_states,)
layer = self._get_layer(index) layer = self._get_layer(index)
if self.gradient_checkpointing and self.training: if hasattr(self, "gradient_checkpointing") and self.gradient_checkpointing \
and self.training:
layer_ret = torch.utils.checkpoint.checkpoint( layer_ret = torch.utils.checkpoint.checkpoint(
layer, layer,
hidden_states, hidden_states,