Support XPU DDP training and autocast for LowBitMatmul (#9167)
* Support autocast in low-bit matmul
* Support XPU DDP training
* Fix AMP
parent 77afb8796b
commit 7160afd4d1
3 changed files with 357 additions and 4 deletions
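For context, the net effect of the three changes is that LowBitLinear layers can train under XPU autocast and that a stock transformers Trainer can set up DDP on XPU. Below is a minimal sketch of the autocast half. The model-loading call and its load_in_low_bit argument are assumptions about bigdl-llm's existing API, an XPU build of PyTorch/IPEX with torch.xpu autocast support is assumed, and trainable parameters (e.g. LoRA adapters on top of the frozen low-bit weights) are assumed to be attached so that backward() has something to differentiate.

import torch
from bigdl.llm.transformers import AutoModelForCausalLM  # assumed bigdl-llm entry point

# Assumption: load a model whose nn.Linear layers are converted to LowBitLinear;
# the exact loading arguments may differ from bigdl-llm's real API.
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf",
                                             load_in_low_bit="nf4").to("xpu")

input_ids = torch.randint(0, model.config.vocab_size, (1, 32), device="xpu")

# With this commit, MatMulLowBit.forward casts its activation to the autocast dtype
# and backward casts grad_output the same way, so a bf16 autocast region works:
with torch.xpu.autocast(enabled=True, dtype=torch.bfloat16):
    loss = model(input_ids=input_ids, labels=input_ids).loss
loss.backward()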
[changed file 1 of 3; path not shown in the diff view]

@@ -49,6 +49,7 @@ import torch.nn.functional as F
 from torch import Tensor, device, dtype, nn
 from operator import mul
 from functools import reduce
+from bigdl.llm.transformers.xpu_customize_fwd import custom_fwd, custom_bwd

 T = TypeVar("T", bound="torch.nn.Module")

@@ -288,7 +289,10 @@ def ggml_matmul_src1_x_src0_t(src0: torch.Tensor,
 class MatMulLowBit(torch.autograd.Function):

     @staticmethod
+    @custom_fwd
     def forward(ctx, A, weight, input_seq_size):
+        if torch.xpu.is_autocast_xpu_enabled():
+            A = A.to(torch.xpu.get_autocast_xpu_dtype())
         ctx.is_empty = False
         import linear_q4_0
         result = linear_q4_0.forward_new(A, weight.data, weight.qtype, input_seq_size)
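The forward-side change boils down to a single query-and-cast against the XPU autocast state. A standalone sketch of that pattern, assuming an XPU-enabled PyTorch/IPEX build where the torch.xpu autocast helpers used in this diff are available:

import torch

def maybe_cast_for_xpu_autocast(t: torch.Tensor) -> torch.Tensor:
    # Same pattern as the hunk above: when an XPU autocast region is active,
    # move the floating-point activation to the autocast dtype (typically
    # bf16 or fp16) before handing it to the low-bit matmul kernel.
    if torch.xpu.is_autocast_xpu_enabled():
        return t.to(torch.xpu.get_autocast_xpu_dtype())
    return t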
@@ -299,6 +303,7 @@ class MatMulLowBit(torch.autograd.Function):
         return result

     @staticmethod
+    @custom_bwd
     def backward(ctx, grad_output):
         import linear_q4_0
         if ctx.is_empty:

@@ -308,7 +313,9 @@ class MatMulLowBit(torch.autograd.Function):
         A, weight = ctx.tensors
         grad_A, grad_weight = None, None
         if req_gradA:
-            dequant_weight = linear_q4_0.dequant(A, weight.data, weight.qtype).to(grad_output.dtype)
+            if torch.xpu.is_autocast_xpu_enabled():
+                grad_output = grad_output.to(torch.xpu.get_autocast_xpu_dtype())
+            dequant_weight = linear_q4_0.dequant(A, weight.data, weight.qtype)
             grad_A = torch.matmul(grad_output, dequant_weight.reshape(weight._shape))

         return grad_A, grad_weight, None

@@ -378,8 +385,7 @@ class LowBitLinear(nn.Linear):
         result = result.view(new_shape)
         if self.bias is not None:
             result += self.bias
-        return result
+        return result.to(x.dtype)


 class FP16Linear(nn.Linear):
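As a sanity check on the backward hunk: for the usual Linear convention Y = A @ W.T, the chain rule gives dL/dA = dL/dY @ W, which is what the grad_A = torch.matmul(grad_output, dequant_weight.reshape(weight._shape)) line computes once the low-bit weight has been dequantized. A tiny shape check with dense stand-ins (the sizes are hypothetical):

import torch

batch, in_features, out_features = 2, 8, 4
A = torch.randn(batch, in_features)            # activation, as in MatMulLowBit.forward
W = torch.randn(out_features, in_features)     # dense stand-in for the dequantized weight

Y = A @ W.t()                                  # forward: (batch, out_features)
grad_output = torch.ones_like(Y)               # dL/dY
grad_A = grad_output @ W                       # backward: matches the hunk above
assert grad_A.shape == A.shape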
[changed file 2 of 3; path not shown in the diff view]

@@ -30,7 +30,23 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+#
+# Some parts of this file is adapted from
+# https://github.com/huggingface/transformers/blob/v4.34.0/src/transformers/training_args.py
+#
+# Copyright 2020 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.

 import torch
 from bigdl.llm.transformers.low_bit_linear import LowBitLinear
@@ -193,6 +209,147 @@ class PeftModel:
 def patch_prepare_ipex(self, *args):
     return tuple(args)

+
+from transformers.utils import (
+    requires_backends,
+    is_sagemaker_mp_enabled,
+    is_accelerate_available,
+    is_torch_xpu_available,
+    is_sagemaker_dp_enabled,
+    is_torch_tpu_available,
+    is_torch_npu_available)
+from transformers.utils.generic import strtobool
+from transformers.utils import cached_property
+from transformers.training_args import logger, ParallelMode, DistributedType
+import torch
+import torch.distributed as dist
+import os
+import warnings
+from datetime import timedelta
+
+if is_accelerate_available():
+    from accelerate.state import AcceleratorState, PartialState
+    from accelerate.utils import DistributedType
+
+if is_sagemaker_mp_enabled():
+    import smdistributed.modelparallel.torch as smp
+
+    smp.init()
+
+
+@cached_property
+def _setup_devices(self) -> "torch.device":
+    requires_backends(self, ["torch"])
+    logger.info("PyTorch: setting up devices")
+    if not is_sagemaker_mp_enabled():
+        if not is_accelerate_available(min_version="0.20.1"):
+            invalidInputError(
+                False,
+                "Using the `Trainer` with `PyTorch` requires `accelerate>=0.20.1`: "
+                "Please run `pip install transformers[torch]` or `pip install accelerate -U`"
+            )
+        AcceleratorState._reset_state(reset_partial_state=True)
+    self.distributed_state = None
+    if not self.use_ipex and "ACCELERATE_USE_IPEX" not in os.environ:
+        os.environ["ACCELERATE_USE_IPEX"] = "false"
+    if self.use_cpu or strtobool(os.environ.get("ACCELERATE_USE_CPU", "False")):
+        self.distributed_state = PartialState(cpu=True, backend=self.ddp_backend)
+        self._n_gpu = 0
+    elif is_sagemaker_mp_enabled():
+        local_rank = smp.local_rank()
+        device = torch.device("cuda", local_rank)
+        self._n_gpu = 1
+        torch.cuda.set_device(device)
+    elif is_torch_xpu_available() and "ACCELERATE_USE_XPU" not in os.environ:
+        os.environ["ACCELERATE_USE_XPU"] = "true"
+        self.distributed_state = PartialState(timeout=timedelta(seconds=self.ddp_timeout))
+        # device = torch.device("xpu:0")
+        device = self.distributed_state.device
+        self._n_gpu = 1
+    elif is_sagemaker_dp_enabled():
+        self.distributed_state = PartialState(_use_sagemaker_dp=True)
+        self._n_gpu = 1
+    elif self.deepspeed:
+        # Need to do similar for Accelerator init
+        os.environ["ACCELERATE_USE_DEEPSPEED"] = "true"
+        self.distributed_state = PartialState(timeout=timedelta(seconds=self.ddp_timeout))
+        del os.environ["ACCELERATE_USE_DEEPSPEED"]
+        self._n_gpu = 1
+    else:
+        self.distributed_state = PartialState(
+            backend=self.ddp_backend, timeout=timedelta(seconds=self.ddp_timeout)
+        )
+        self._n_gpu = 1
+    if not is_sagemaker_mp_enabled():
+        device = self.distributed_state.device
+        self.local_rank = self.distributed_state.local_process_index
+    if dist.is_available() and dist.is_initialized() and \
+            self.parallel_mode != ParallelMode.DISTRIBUTED:
+        logger.warning(
+            "torch.distributed process group is initialized, "
+            "but parallel_mode != ParallelMode.DISTRIBUTED. "
+            "In order to use Torch DDP, launch your script with `python -m torch.distributed.launch"
+        )
+    if is_torch_tpu_available():
+        device = self.distributed_state.device
+        self._n_gpu = 0
+    elif is_sagemaker_dp_enabled() or is_sagemaker_mp_enabled():
+        # Already set _n_gpu
+        pass
+    elif self.distributed_state.distributed_type == DistributedType.MULTI_XPU:
+        if "ACCELERATE_USE_XPU" not in os.environ:
+            os.environ["ACCELERATE_USE_XPU"] = "true"
+        # self._n_gpu = torch.xpu.device_count()
+        # device = torch.device("xpu:0")
+        # torch.xpu.set_device(device)
+    elif self.distributed_state.distributed_type == DistributedType.NO:
+        if self.use_mps_device:
+            warnings.warn(
+                "`use_mps_device` is deprecated and will be removed in"
+                " version 5.0 of 🤗 Transformers."
+                "`mps` device will be used by default if available similar"
+                " to the way `cuda` device is used."
+                "Therefore, no action from user is required. "
+            )
+            if device.type != "mps":
+                invalidInputError(False,
+                                  ("Either you do not have an MPS-enabled device"
+                                   " on this machine or MacOS"
+                                   " version is not 12.3+ "
+                                   "or current PyTorch install was not built with MPS enabled."))
+        if device.type == "mps":
+            self._n_gpu = 1
+        elif self.use_cpu:
+            device = torch.device("cpu")
+            self._n_gpu = 0
+        elif is_torch_xpu_available():
+            device = torch.device("xpu:0")
+            torch.xpu.set_device(device)
+            self._n_gpu = 1
+        elif is_torch_npu_available():
+            device = torch.device("npu:0")
+            torch.npu.set_device(device)
+            self._n_gpu = 1
+        else:
+            # if n_gpu is > 1 we'll use nn.DataParallel.
+            # If you only want to use a specific subset of GPUs use `CUDA_VISIBLE_DEVICES=0`
+            # Explicitly set CUDA to the first (index 0) CUDA device, otherwise `set_device` will
+            # trigger an error that a device index is missing. Index 0 takes into account the
+            # GPUs available in the environment, so `CUDA_VISIBLE_DEVICES=1,2` with `cuda:0`
+            # will use the first GPU in that env, i.e. GPU#1
+            device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+            # Sometimes the line in the postinit has not been run before we end up here,
+            # so just checking we're not at
+            # the default value.
+            self._n_gpu = torch.cuda.device_count()
+            if device.type == "cuda":
+                torch.cuda.set_device(device)
+    return device
+

 # workaround a IPEX bug that prevents resume training in bf16
 from accelerate import Accelerator
 Accelerator._prepare_ipex = patch_prepare_ipex
+
+# patch transformer for xpu DDP training
+from transformers import TrainingArguments
+TrainingArguments._setup_devices = _setup_devices
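With _setup_devices patched onto TrainingArguments, a stock transformers Trainer resolves its device through accelerate's PartialState and lands on XPU instead of falling back to CUDA/CPU. A minimal sketch, under these assumptions: an XPU build with oneCCL bindings, ddp_backend="ccl" as the collective backend, and the bigdl-llm module containing this patch already imported so the monkey-patch is in effect.

import torch
from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="out",
    per_device_train_batch_size=1,
    ddp_backend="ccl",   # assumption: oneCCL backend for XPU collectives
    ddp_timeout=3600,
)

# The patched _setup_devices sets ACCELERATE_USE_XPU=true when is_torch_xpu_available()
# is True and takes the per-rank device from accelerate's PartialState.
print(args.device, args.n_gpu)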
[changed file 3 of 3] python/llm/src/bigdl/llm/transformers/xpu_customize_fwd.py (new file, 190 lines)
@@ -0,0 +1,190 @@
+#
+# Copyright 2016 The BigDL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Some parts of this file is adapted from
+# https://github.com/pytorch/pytorch/blob/v2.1.0/torch/cuda/amp/autocast_mode.py
+#
+# From PyTorch:
+#
+# Copyright (c) 2016- Facebook, Inc (Adam Paszke)
+# Copyright (c) 2014- Facebook, Inc (Soumith Chintala)
+# Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)
+# Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu)
+# Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)
+# Copyright (c) 2011-2013 NYU (Clement Farabet)
+# Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert,
+#                                                   Leon Bottou, Iain Melvin, Jason Weston)
+# Copyright (c) 2006 Idiap Research Institute (Samy Bengio)
+# Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)
+#
+# From Caffe2:
+#
+# Copyright (c) 2016-present, Facebook Inc. All rights reserved.
+#
+# All contributions by Facebook:
+# Copyright (c) 2016 Facebook Inc.
+#
+# All contributions by Google:
+# Copyright (c) 2015 Google Inc.
+# All rights reserved.
+#
+# All contributions by Yangqing Jia:
+# Copyright (c) 2015 Yangqing Jia
+# All rights reserved.
+#
+# All contributions by Kakao Brain:
+# Copyright 2019-2020 Kakao Brain
+#
+# All contributions by Cruise LLC:
+# Copyright (c) 2022 Cruise LLC.
+# All rights reserved.
+#
+# All contributions from Caffe:
+# Copyright(c) 2013, 2014, 2015, the respective contributors
+# All rights reserved.
+#
+# All other contributions:
+# Copyright(c) 2015, 2016 the respective contributors
+# All rights reserved.
+#
+# Caffe2 uses a copyright model similar to Caffe: each contributor holds
+# copyright over their contributions to Caffe2. The project versioning records
+# all such contribution and copyright details. If a contributor wants to further
+# mark their specific copyright on a particular contribution, they should
+# indicate their copyright solely in the commit message of the change when it is
+# committed.
+#
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright
+#    notice, this list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright
+#    notice, this list of conditions and the following disclaimer in the
+#    documentation and/or other materials provided with the distribution.
+#
+# 3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America
+#    and IDIAP Research Institute nor the names of its contributors may be
+#    used to endorse or promote products derived from this software without
+#    specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
+# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+# POSSIBILITY OF SUCH DAMAGE.
+
+import collections
+import functools
+
+import torch
+
+try:
+    import numpy as np
+
+    HAS_NUMPY = True
+except ModuleNotFoundError:
+    np = None  # type: ignore[assignment]
+from typing import Any
+
+
+def _cast(value, dtype):
+    if isinstance(value, torch.Tensor):
+        is_eligible = (
+            value.is_floating_point()
+            and value.is_xpu
+            and (value.dtype is not torch.float64)
+        )
+        return value.to(dtype) if is_eligible else value
+    elif isinstance(value, (str, bytes)):
+        return value
+    elif HAS_NUMPY and isinstance(value, np.ndarray):
+        return value
+    elif isinstance(value, collections.abc.Mapping):
+        return {_cast(k, dtype): _cast(v, dtype) for k, v in value.items()}
+    elif isinstance(value, collections.abc.Iterable):
+        iterable = (_cast(v, dtype) for v in value)
+        if isinstance(value, (list, tuple)):
+            return type(value)(iterable)
+        else:
+            return iterable
+    else:
+        return value
+
+
+def custom_fwd(fwd=None, *, cast_inputs=None):
+    """
+    Helper decorator for ``forward`` methods of custom autograd functions (subclasses of
+    :class:`torch.autograd.Function`). See the :ref:`example page<amp-custom-examples>`
+    for more detail.
+
+    Args:
+        cast_inputs (:class:`torch.dtype` or None, optional, default=None): If not ``None``,
+            when ``forward`` runs in an autocast-enabled region, casts incoming
+            floating-point CUDA Tensors to the target dtype (non-floating-point Tensors
+            are not affected),
+            then executes ``forward`` with autocast disabled.
+            If ``None``, ``forward``'s internal ops execute with the current autocast state.
+
+    .. note::
+        If the decorated ``forward`` is called outside an autocast-enabled region,
+        :func:`custom_fwd<custom_fwd>` is a no-op and ``cast_inputs`` has no effect.
+    """
+    if fwd is None:
+        return functools.partial(custom_fwd, cast_inputs=cast_inputs)
+
+    @functools.wraps(fwd)
+    def decorate_fwd(*args, **kwargs):
+        args[0]._dtype = torch.xpu.get_autocast_xpu_dtype()
+        if cast_inputs is None:
+            args[0]._fwd_used_autocast = torch.xpu.is_autocast_xpu_enabled()
+            return fwd(*args, **kwargs)
+        else:
+            autocast_context = torch.xpu.is_autocast_xpu_enabled()
+            args[0]._fwd_used_autocast = False
+            if autocast_context:
+                with torch.xpu.autocast(enabled=False):
+                    return fwd(*_cast(args, cast_inputs), **_cast(kwargs, cast_inputs))
+            else:
+                return fwd(*args, **kwargs)
+
+    return decorate_fwd
+
+
+# Autograd ensures incoming gradients are the same type as forward outputs. Allowing a separate
+# cast_inputs argument on custom_bwd is unnecessary and could cause errors if it doesn't match
+# cast_inputs supplied to custom_fwd.
+def custom_bwd(bwd):
+    """
+    Helper decorator for backward methods of custom autograd functions (subclasses of
+    :class:`torch.autograd.Function`).
+    Ensures that ``backward`` executes with the same autocast state as ``forward``.
+    See the :ref:`example page<amp-custom-examples>` for more detail.
+    """
+
+    @functools.wraps(bwd)
+    def decorate_bwd(*args, **kwargs):
+        with torch.xpu.autocast(enabled=args[0]._fwd_used_autocast, dtype=args[0]._dtype):
+            return bwd(*args, **kwargs)
+
+    return decorate_bwd
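To make the intended use of these helpers concrete, here is a minimal custom autograd Function decorated with them. The op is a plain matmul chosen purely for illustration, and an XPU-enabled PyTorch build (where the torch.xpu autocast helpers above are available) is assumed.

import torch
from bigdl.llm.transformers.xpu_customize_fwd import custom_fwd, custom_bwd


class MyMatMul(torch.autograd.Function):
    @staticmethod
    @custom_fwd(cast_inputs=torch.float32)   # run this op in fp32 even under autocast
    def forward(ctx, a, b):
        ctx.save_for_backward(a, b)
        return a @ b

    @staticmethod
    @custom_bwd                              # backward reruns with forward's autocast state
    def backward(ctx, grad_out):
        a, b = ctx.saved_tensors
        return grad_out @ b.t(), a.t() @ grad_out


a = torch.randn(4, 8, device="xpu", requires_grad=True)
b = torch.randn(8, 2, device="xpu", requires_grad=True)
with torch.xpu.autocast(enabled=True, dtype=torch.bfloat16):
    out = MyMatMul.apply(a, b)
out.sum().backward()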