Support XPU DDP training and autocast for LowBitMatmul (#9167)

* Support autocast in low-bit matmul

* Support XPU DDP training

* Fix AMP
Yang Wang 2023-10-17 11:47:19 +08:00 committed by GitHub
parent 77afb8796b
commit 7160afd4d1
3 changed files with 357 additions and 4 deletions
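
For orientation, the intended usage looks roughly like the sketch below. This is a hedged sketch, not taken from this commit: `model`, `dataloader`, and the optimizer are placeholders, and it assumes an XPU-enabled PyTorch/IPEX build where `torch.xpu.autocast` and the `"xpu"` device are available, with the distributed process group already initialized (e.g. by the patched TrainingArguments path added here).

import torch
from torch.nn.parallel import DistributedDataParallel as DDP

def train_one_epoch(model, dataloader, lr=1e-4):
    # model is assumed to already contain LowBitLinear layers (a bigdl-llm low-bit model)
    model = DDP(model.to("xpu"))
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    loss_fn = torch.nn.CrossEntropyLoss()
    for inputs, targets in dataloader:
        inputs, targets = inputs.to("xpu"), targets.to("xpu")
        optimizer.zero_grad()
        # autocast on XPU: MatMulLowBit.forward now casts its activation to this dtype
        with torch.xpu.autocast(enabled=True, dtype=torch.bfloat16):
            loss = loss_fn(model(inputs), targets)
        loss.backward()   # MatMulLowBit.backward casts grad_output the same way
        optimizer.step()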


@@ -49,6 +49,7 @@ import torch.nn.functional as F
from torch import Tensor, device, dtype, nn
from operator import mul
from functools import reduce
from bigdl.llm.transformers.xpu_customize_fwd import custom_fwd, custom_bwd
T = TypeVar("T", bound="torch.nn.Module")
@@ -288,7 +289,10 @@ def ggml_matmul_src1_x_src0_t(src0: torch.Tensor,
class MatMulLowBit(torch.autograd.Function):
    @staticmethod
    @custom_fwd
    def forward(ctx, A, weight, input_seq_size):
        if torch.xpu.is_autocast_xpu_enabled():
            A = A.to(torch.xpu.get_autocast_xpu_dtype())
        ctx.is_empty = False
        import linear_q4_0
        result = linear_q4_0.forward_new(A, weight.data, weight.qtype, input_seq_size)
@@ -299,6 +303,7 @@ class MatMulLowBit(torch.autograd.Function):
        return result

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_output):
        import linear_q4_0
        if ctx.is_empty:
@@ -308,7 +313,9 @@ class MatMulLowBit(torch.autograd.Function):
        A, weight = ctx.tensors
        grad_A, grad_weight = None, None
        if req_gradA:
            dequant_weight = linear_q4_0.dequant(A, weight.data, weight.qtype).to(grad_output.dtype)
            if torch.xpu.is_autocast_xpu_enabled():
                grad_output = grad_output.to(torch.xpu.get_autocast_xpu_dtype())
            dequant_weight = linear_q4_0.dequant(A, weight.data, weight.qtype)
            grad_A = torch.matmul(grad_output, dequant_weight.reshape(weight._shape))
        return grad_A, grad_weight, None
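
The forward pass now casts the activation A to the autocast dtype, and the backward pass casts grad_output the same way before dequantizing the weight: the matmul against the dequantized weight needs both operands in a single dtype. A minimal, self-contained illustration of that constraint (plain torch only, not the linear_q4_0 kernels; tensor names are placeholders):

import torch

grad_output = torch.randn(4, 8, dtype=torch.bfloat16)       # what autocast hands the backward
dequant_weight = torch.randn(8, 16, dtype=torch.float32)    # dequantized weight in a different dtype
try:
    torch.matmul(grad_output, dequant_weight)                # mixed dtypes: matmul raises
except RuntimeError as err:
    print("mixed-dtype matmul rejected:", err)
# aligning both operands on one dtype (bf16 here) keeps the backward matmul valid
grad_A = torch.matmul(grad_output, dequant_weight.to(grad_output.dtype))
print(grad_A.dtype)  # torch.bfloat16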
@@ -378,8 +385,7 @@ class LowBitLinear(nn.Linear):
            result = result.view(new_shape)
            if self.bias is not None:
                result += self.bias
            return result.to(x.dtype)
            return result
class FP16Linear(nn.Linear):


@@ -30,7 +30,23 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Some parts of this file are adapted from
# https://github.com/huggingface/transformers/blob/v4.34.0/src/transformers/training_args.py
#
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from bigdl.llm.transformers.low_bit_linear import LowBitLinear
@@ -193,6 +209,147 @@ class PeftModel:
def patch_prepare_ipex(self, *args):
    return tuple(args)

from transformers.utils import (
    requires_backends,
    is_sagemaker_mp_enabled,
    is_accelerate_available,
    is_torch_xpu_available,
    is_sagemaker_dp_enabled,
    is_torch_tpu_available,
    is_torch_npu_available)
from transformers.utils.generic import strtobool
from transformers.utils import cached_property
from transformers.training_args import logger, ParallelMode, DistributedType
import torch
import torch.distributed as dist
import os
import warnings
from datetime import timedelta
if is_accelerate_available():
    from accelerate.state import AcceleratorState, PartialState
    from accelerate.utils import DistributedType

if is_sagemaker_mp_enabled():
    import smdistributed.modelparallel.torch as smp
    smp.init()
@cached_property
def _setup_devices(self) -> "torch.device":
    requires_backends(self, ["torch"])
    logger.info("PyTorch: setting up devices")
    if not is_sagemaker_mp_enabled():
        if not is_accelerate_available(min_version="0.20.1"):
            invalidInputError(
                False,
                "Using the `Trainer` with `PyTorch` requires `accelerate>=0.20.1`: "
                "Please run `pip install transformers[torch]` or `pip install accelerate -U`"
            )
        AcceleratorState._reset_state(reset_partial_state=True)
    self.distributed_state = None
    if not self.use_ipex and "ACCELERATE_USE_IPEX" not in os.environ:
        os.environ["ACCELERATE_USE_IPEX"] = "false"
    if self.use_cpu or strtobool(os.environ.get("ACCELERATE_USE_CPU", "False")):
        self.distributed_state = PartialState(cpu=True, backend=self.ddp_backend)
        self._n_gpu = 0
    elif is_sagemaker_mp_enabled():
        local_rank = smp.local_rank()
        device = torch.device("cuda", local_rank)
        self._n_gpu = 1
        torch.cuda.set_device(device)
    elif is_torch_xpu_available() and "ACCELERATE_USE_XPU" not in os.environ:
        os.environ["ACCELERATE_USE_XPU"] = "true"
        self.distributed_state = PartialState(timeout=timedelta(seconds=self.ddp_timeout))
        # device = torch.device("xpu:0")
        device = self.distributed_state.device
        self._n_gpu = 1
    elif is_sagemaker_dp_enabled():
        self.distributed_state = PartialState(_use_sagemaker_dp=True)
        self._n_gpu = 1
    elif self.deepspeed:
        # Need to do similar for Accelerator init
        os.environ["ACCELERATE_USE_DEEPSPEED"] = "true"
        self.distributed_state = PartialState(timeout=timedelta(seconds=self.ddp_timeout))
        del os.environ["ACCELERATE_USE_DEEPSPEED"]
        self._n_gpu = 1
    else:
        self.distributed_state = PartialState(
            backend=self.ddp_backend, timeout=timedelta(seconds=self.ddp_timeout)
        )
        self._n_gpu = 1
    if not is_sagemaker_mp_enabled():
        device = self.distributed_state.device
        self.local_rank = self.distributed_state.local_process_index
    if dist.is_available() and dist.is_initialized() and \
            self.parallel_mode != ParallelMode.DISTRIBUTED:
        logger.warning(
            "torch.distributed process group is initialized, "
            "but parallel_mode != ParallelMode.DISTRIBUTED. "
            "In order to use Torch DDP, launch your script with `python -m torch.distributed.launch`"
        )
    if is_torch_tpu_available():
        device = self.distributed_state.device
        self._n_gpu = 0
    elif is_sagemaker_dp_enabled() or is_sagemaker_mp_enabled():
        # Already set _n_gpu
        pass
    elif self.distributed_state.distributed_type == DistributedType.MULTI_XPU:
        if "ACCELERATE_USE_XPU" not in os.environ:
            os.environ["ACCELERATE_USE_XPU"] = "true"
        # self._n_gpu = torch.xpu.device_count()
        # device = torch.device("xpu:0")
        # torch.xpu.set_device(device)
    elif self.distributed_state.distributed_type == DistributedType.NO:
        if self.use_mps_device:
            warnings.warn(
                "`use_mps_device` is deprecated and will be removed in"
                " version 5.0 of 🤗 Transformers."
                "`mps` device will be used by default if available similar"
                " to the way `cuda` device is used."
                "Therefore, no action from user is required. "
            )
            if device.type != "mps":
                invalidInputError(False,
                                  ("Either you do not have an MPS-enabled device"
                                   " on this machine or MacOS"
                                   " version is not 12.3+ "
                                   "or current PyTorch install was not built with MPS enabled."))
        if device.type == "mps":
            self._n_gpu = 1
        elif self.use_cpu:
            device = torch.device("cpu")
            self._n_gpu = 0
        elif is_torch_xpu_available():
            device = torch.device("xpu:0")
            torch.xpu.set_device(device)
            self._n_gpu = 1
        elif is_torch_npu_available():
            device = torch.device("npu:0")
            torch.npu.set_device(device)
            self._n_gpu = 1
        else:
            # if n_gpu is > 1 we'll use nn.DataParallel.
            # If you only want to use a specific subset of GPUs use `CUDA_VISIBLE_DEVICES=0`
            # Explicitly set CUDA to the first (index 0) CUDA device, otherwise `set_device` will
            # trigger an error that a device index is missing. Index 0 takes into account the
            # GPUs available in the environment, so `CUDA_VISIBLE_DEVICES=1,2` with `cuda:0`
            # will use the first GPU in that env, i.e. GPU#1
            device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
            # Sometimes the line in the postinit has not been run before we end up here,
            # so just checking we're not at the default value.
            self._n_gpu = torch.cuda.device_count()
            if device.type == "cuda":
                torch.cuda.set_device(device)
    return device
# work around an IPEX bug that prevents resuming training in bf16
from accelerate import Accelerator
Accelerator._prepare_ipex = patch_prepare_ipex

# patch transformers for XPU DDP training
from transformers import TrainingArguments
TrainingArguments._setup_devices = _setup_devices
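
With `_setup_devices` patched onto TrainingArguments, a normal 🤗 Trainer script can pick up XPU devices and the MULTI_XPU distributed type. A rough usage sketch follows (hedged, not from this commit: it assumes the module containing this patch has been imported, that intel_extension_for_pytorch and the oneCCL bindings are installed so the `ccl` backend is available, and that the script is launched with the usual RANK/WORLD_SIZE environment, e.g. via mpirun or torchrun):

from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="out",
    ddp_backend="ccl",     # oneCCL backend for XPU collectives (assumed installed)
    ddp_timeout=3600,
)
# accessing .device triggers the patched _setup_devices cached_property
print(args.device)         # expected: an xpu device for this local process rank
print(args.parallel_mode)  # ParallelMode.DISTRIBUTED under a multi-process launch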


@@ -0,0 +1,190 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Some parts of this file are adapted from
# https://github.com/pytorch/pytorch/blob/v2.1.0/torch/cuda/amp/autocast_mode.py
#
# From PyTorch:
# Copyright (c) 2016- Facebook, Inc (Adam Paszke)
# Copyright (c) 2014- Facebook, Inc (Soumith Chintala)
# Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)
# Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu)
# Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)
# Copyright (c) 2011-2013 NYU (Clement Farabet)
# Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert,
# Leon Bottou, Iain Melvin, Jason Weston)
# Copyright (c) 2006 Idiap Research Institute (Samy Bengio)
# Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)
# From Caffe2:
# Copyright (c) 2016-present, Facebook Inc. All rights reserved.
# All contributions by Facebook:
# Copyright (c) 2016 Facebook Inc.
# All contributions by Google:
# Copyright (c) 2015 Google Inc.
# All rights reserved.
# All contributions by Yangqing Jia:
# Copyright (c) 2015 Yangqing Jia
# All rights reserved.
# All contributions by Kakao Brain:
# Copyright 2019-2020 Kakao Brain
# All contributions by Cruise LLC:
# Copyright (c) 2022 Cruise LLC.
# All rights reserved.
# All contributions from Caffe:
# Copyright(c) 2013, 2014, 2015, the respective contributors
# All rights reserved.
# All other contributions:
# Copyright(c) 2015, 2016 the respective contributors
# All rights reserved.
# Caffe2 uses a copyright model similar to Caffe: each contributor holds
# copyright over their contributions to Caffe2. The project versioning records
# all such contribution and copyright details. If a contributor wants to further
# mark their specific copyright on a particular contribution, they should
# indicate their copyright solely in the commit message of the change when it is
# committed.
# All rights reserved.
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# 1. Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# 3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America
# and IDIAP Research Institute nor the names of its contributors may be
# used to endorse or promote products derived from this software without
# specific prior written permission.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.
import collections
import functools
import torch
try:
    import numpy as np
    HAS_NUMPY = True
except ModuleNotFoundError:
    np = None  # type: ignore[assignment]
from typing import Any
def _cast(value, dtype):
    if isinstance(value, torch.Tensor):
        is_eligible = (
            value.is_floating_point()
            and value.is_xpu
            and (value.dtype is not torch.float64)
        )
        return value.to(dtype) if is_eligible else value
    elif isinstance(value, (str, bytes)):
        return value
    elif HAS_NUMPY and isinstance(value, np.ndarray):
        return value
    elif isinstance(value, collections.abc.Mapping):
        return {_cast(k, dtype): _cast(v, dtype) for k, v in value.items()}
    elif isinstance(value, collections.abc.Iterable):
        iterable = (_cast(v, dtype) for v in value)
        if isinstance(value, (list, tuple)):
            return type(value)(iterable)
        else:
            return iterable
    else:
        return value
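
`_cast` recurses through containers and only converts floating-point XPU tensors, leaving everything else untouched. A small hedged example (it assumes an XPU device is present; on other devices the tensors pass through unchanged because `is_xpu` is False):

import torch

payload = {
    "hidden": torch.randn(2, 4, device="xpu"),                   # floating point: cast
    "mask": torch.ones(2, 4, dtype=torch.int64, device="xpu"),   # integer: left as-is
    "labels": ["a", "b"],                                         # strings: left as-is
}
casted = _cast(payload, torch.bfloat16)
print(casted["hidden"].dtype, casted["mask"].dtype)               # torch.bfloat16 torch.int64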
def custom_fwd(fwd=None, *, cast_inputs=None):
    """
    Helper decorator for ``forward`` methods of custom autograd functions (subclasses of
    :class:`torch.autograd.Function`). See the :ref:`example page<amp-custom-examples>`
    for more detail.

    Args:
        cast_inputs (:class:`torch.dtype` or None, optional, default=None): If not ``None``,
            when ``forward`` runs in an autocast-enabled region, casts incoming
            floating-point XPU Tensors to the target dtype (non-floating-point Tensors
            are not affected), then executes ``forward`` with autocast disabled.
            If ``None``, ``forward``'s internal ops execute with the current autocast state.

    .. note::
        If the decorated ``forward`` is called outside an autocast-enabled region,
        :func:`custom_fwd<custom_fwd>` is a no-op and ``cast_inputs`` has no effect.
    """
    if fwd is None:
        return functools.partial(custom_fwd, cast_inputs=cast_inputs)

    @functools.wraps(fwd)
    def decorate_fwd(*args, **kwargs):
        args[0]._dtype = torch.xpu.get_autocast_xpu_dtype()
        if cast_inputs is None:
            args[0]._fwd_used_autocast = torch.xpu.is_autocast_xpu_enabled()
            return fwd(*args, **kwargs)
        else:
            autocast_context = torch.xpu.is_autocast_xpu_enabled()
            args[0]._fwd_used_autocast = False
            if autocast_context:
                with torch.xpu.autocast(enabled=False):
                    return fwd(*_cast(args, cast_inputs), **_cast(kwargs, cast_inputs))
            else:
                return fwd(*args, **kwargs)

    return decorate_fwd
# Autograd ensures incoming gradients are the same type as forward outputs. Allowing a separate
# cast_inputs argument on custom_bwd is unnecessary and could cause errors if it doesn't match
# cast_inputs supplied to custom_fwd.
def custom_bwd(bwd):
    """
    Helper decorator for backward methods of custom autograd functions (subclasses of
    :class:`torch.autograd.Function`).
    Ensures that ``backward`` executes with the same autocast state as ``forward``.
    See the :ref:`example page<amp-custom-examples>` for more detail.
    """
    @functools.wraps(bwd)
    def decorate_bwd(*args, **kwargs):
        with torch.xpu.autocast(enabled=args[0]._fwd_used_autocast, dtype=args[0]._dtype):
            return bwd(*args, **kwargs)

    return decorate_bwd
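
A toy end-to-end use of the two decorators (a hedged sketch, not part of this commit; it assumes an XPU device and mirrors the way MatMulLowBit applies them):

import torch

class Square(torch.autograd.Function):
    @staticmethod
    @custom_fwd
    def forward(ctx, x):
        # custom_fwd records whether the caller was inside torch.xpu.autocast
        ctx.save_for_backward(x)
        return x * x

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_out):
        # custom_bwd re-enters autocast with the state recorded in forward,
        # so ops here see the same autocast dtype as the forward pass did
        (x,) = ctx.saved_tensors
        return 2 * x * grad_out

x = torch.randn(8, device="xpu", requires_grad=True)
with torch.xpu.autocast(enabled=True, dtype=torch.bfloat16):
    y = Square.apply(x)
y.sum().backward()
print(y.dtype, x.grad.dtype)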