Support XPU DDP training and autocast for LowBitMatmul (#9167)

* support autocast in low bit matmul

* Support XPU DDP training

* fix amp
Yang Wang 2023-10-17 11:47:19 +08:00 committed by GitHub
parent 77afb8796b
commit 7160afd4d1
3 changed files with 357 additions and 4 deletions
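
For context, a minimal sketch of the training-side usage this commit enables (model name and flags are illustrative, assuming the bigdl-llm transformers-style API of this period): a model loaded in 4 bit on XPU can now run under torch.xpu autocast, with the dtype handling done inside MatMulLowBit as patched below.

import torch
from bigdl.llm.transformers import AutoModelForCausalLM

# Illustrative: load_in_4bit=True swaps nn.Linear for LowBitLinear, whose
# matmul dispatches to the MatMulLowBit autograd function patched below.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf", load_in_4bit=True).to("xpu")

input_ids = torch.randint(0, 32000, (1, 32), device="xpu")
with torch.xpu.amp.autocast(dtype=torch.bfloat16):
    out = model(input_ids)  # activations cast to bf16 in MatMulLowBit.forward
# In (Q)LoRA training, backward through these layers likewise casts
# grad_output to the autocast dtype (see MatMulLowBit.backward below).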

bigdl/llm/transformers/low_bit_linear.py

@@ -49,6 +49,7 @@ import torch.nn.functional as F
 from torch import Tensor, device, dtype, nn
 from operator import mul
 from functools import reduce
+from bigdl.llm.transformers.xpu_customize_fwd import custom_fwd, custom_bwd

 T = TypeVar("T", bound="torch.nn.Module")
@@ -288,7 +289,10 @@ def ggml_matmul_src1_x_src0_t(src0: torch.Tensor,
 class MatMulLowBit(torch.autograd.Function):

     @staticmethod
+    @custom_fwd
     def forward(ctx, A, weight, input_seq_size):
+        if torch.xpu.is_autocast_xpu_enabled():
+            A = A.to(torch.xpu.get_autocast_xpu_dtype())
         ctx.is_empty = False
         import linear_q4_0
         result = linear_q4_0.forward_new(A, weight.data, weight.qtype, input_seq_size)
@@ -299,6 +303,7 @@ class MatMulLowBit(torch.autograd.Function):
         return result

     @staticmethod
+    @custom_bwd
     def backward(ctx, grad_output):
         import linear_q4_0
         if ctx.is_empty:
@@ -308,7 +313,9 @@ class MatMulLowBit(torch.autograd.Function):
         A, weight = ctx.tensors
         grad_A, grad_weight = None, None
         if req_gradA:
-            dequant_weight = linear_q4_0.dequant(A, weight.data, weight.qtype).to(grad_output.dtype)
+            if torch.xpu.is_autocast_xpu_enabled():
+                grad_output = grad_output.to(torch.xpu.get_autocast_xpu_dtype())
+            dequant_weight = linear_q4_0.dequant(A, weight.data, weight.qtype)
             grad_A = torch.matmul(grad_output, dequant_weight.reshape(weight._shape))

         return grad_A, grad_weight, None
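
For reference, a minimal sketch of the XPU autocast-state queries these hunks rely on (assumes an XPU build of PyTorch with intel_extension_for_pytorch installed, which is what provides the torch.xpu namespace):

import torch
import intel_extension_for_pytorch  # noqa: F401  # registers torch.xpu (assumption)

with torch.xpu.amp.autocast(enabled=True, dtype=torch.bfloat16):
    assert torch.xpu.is_autocast_xpu_enabled()
    assert torch.xpu.get_autocast_xpu_dtype() is torch.bfloat16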
@@ -378,8 +385,7 @@ class LowBitLinear(nn.Linear):
         result = result.view(new_shape)
         if self.bias is not None:
             result += self.bias
-        return result
+        return result.to(x.dtype)


 class FP16Linear(nn.Linear):

bigdl/llm/transformers/qlora.py

@@ -30,7 +30,23 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+#
+# Some parts of this file are adapted from
+# https://github.com/huggingface/transformers/blob/v4.34.0/src/transformers/training_args.py
+#
+# Copyright 2020 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.

 import torch
 from bigdl.llm.transformers.low_bit_linear import LowBitLinear
@@ -193,6 +209,147 @@ class PeftModel:
 def patch_prepare_ipex(self, *args):
     return tuple(args)

+from transformers.utils import (
+    requires_backends,
+    is_sagemaker_mp_enabled,
+    is_accelerate_available,
+    is_torch_xpu_available,
+    is_sagemaker_dp_enabled,
+    is_torch_tpu_available,
+    is_torch_npu_available)
+from transformers.utils.generic import strtobool
+from transformers.utils import cached_property
+from transformers.training_args import logger, ParallelMode, DistributedType
+import torch
+import torch.distributed as dist
+import os
+import warnings
+from datetime import timedelta
+
+if is_accelerate_available():
+    from accelerate.state import AcceleratorState, PartialState
+    from accelerate.utils import DistributedType
+
+if is_sagemaker_mp_enabled():
+    import smdistributed.modelparallel.torch as smp
+
+    smp.init()
+@cached_property
+def _setup_devices(self) -> "torch.device":
+    requires_backends(self, ["torch"])
+    logger.info("PyTorch: setting up devices")
+    if not is_sagemaker_mp_enabled():
+        if not is_accelerate_available(min_version="0.20.1"):
+            invalidInputError(
+                False,
+                "Using the `Trainer` with `PyTorch` requires `accelerate>=0.20.1`: "
+                "Please run `pip install transformers[torch]` or `pip install accelerate -U`"
+            )
+        AcceleratorState._reset_state(reset_partial_state=True)
+    self.distributed_state = None
+    if not self.use_ipex and "ACCELERATE_USE_IPEX" not in os.environ:
+        os.environ["ACCELERATE_USE_IPEX"] = "false"
+    if self.use_cpu or strtobool(os.environ.get("ACCELERATE_USE_CPU", "False")):
+        self.distributed_state = PartialState(cpu=True, backend=self.ddp_backend)
+        self._n_gpu = 0
+    elif is_sagemaker_mp_enabled():
+        local_rank = smp.local_rank()
+        device = torch.device("cuda", local_rank)
+        self._n_gpu = 1
+        torch.cuda.set_device(device)
+    elif is_torch_xpu_available() and "ACCELERATE_USE_XPU" not in os.environ:
+        os.environ["ACCELERATE_USE_XPU"] = "true"
+        self.distributed_state = PartialState(timeout=timedelta(seconds=self.ddp_timeout))
+        # device = torch.device("xpu:0")
+        device = self.distributed_state.device
+        self._n_gpu = 1
+    elif is_sagemaker_dp_enabled():
+        self.distributed_state = PartialState(_use_sagemaker_dp=True)
+        self._n_gpu = 1
+    elif self.deepspeed:
+        # Need to do similar for Accelerator init
+        os.environ["ACCELERATE_USE_DEEPSPEED"] = "true"
+        self.distributed_state = PartialState(timeout=timedelta(seconds=self.ddp_timeout))
+        del os.environ["ACCELERATE_USE_DEEPSPEED"]
+        self._n_gpu = 1
+    else:
+        self.distributed_state = PartialState(
+            backend=self.ddp_backend, timeout=timedelta(seconds=self.ddp_timeout)
+        )
+        self._n_gpu = 1
+    if not is_sagemaker_mp_enabled():
+        device = self.distributed_state.device
+        self.local_rank = self.distributed_state.local_process_index
+    if dist.is_available() and dist.is_initialized() and \
+            self.parallel_mode != ParallelMode.DISTRIBUTED:
+        logger.warning(
+            "torch.distributed process group is initialized, "
+            "but parallel_mode != ParallelMode.DISTRIBUTED. "
+            "In order to use Torch DDP, launch your script with `python -m torch.distributed.launch"
+        )
+    if is_torch_tpu_available():
+        device = self.distributed_state.device
+        self._n_gpu = 0
+    elif is_sagemaker_dp_enabled() or is_sagemaker_mp_enabled():
+        # Already set _n_gpu
+        pass
+    elif self.distributed_state.distributed_type == DistributedType.MULTI_XPU:
+        if "ACCELERATE_USE_XPU" not in os.environ:
+            os.environ["ACCELERATE_USE_XPU"] = "true"
+        # self._n_gpu = torch.xpu.device_count()
+        # device = torch.device("xpu:0")
+        # torch.xpu.set_device(device)
+    elif self.distributed_state.distributed_type == DistributedType.NO:
+        if self.use_mps_device:
+            warnings.warn(
+                "`use_mps_device` is deprecated and will be removed in"
+                " version 5.0 of 🤗 Transformers. "
+                "`mps` device will be used by default if available similar"
+                " to the way `cuda` device is used. "
+                "Therefore, no action from user is required. "
+            )
+            if device.type != "mps":
+                invalidInputError(False,
+                                  ("Either you do not have an MPS-enabled device"
+                                   " on this machine or MacOS"
+                                   " version is not 12.3+ "
+                                   "or current PyTorch install was not built with MPS enabled."))
+        if device.type == "mps":
+            self._n_gpu = 1
+        elif self.use_cpu:
+            device = torch.device("cpu")
+            self._n_gpu = 0
+        elif is_torch_xpu_available():
+            device = torch.device("xpu:0")
+            torch.xpu.set_device(device)
+            self._n_gpu = 1
+        elif is_torch_npu_available():
+            device = torch.device("npu:0")
+            torch.npu.set_device(device)
+            self._n_gpu = 1
+        else:
+            # if n_gpu is > 1 we'll use nn.DataParallel.
+            # If you only want to use a specific subset of GPUs use `CUDA_VISIBLE_DEVICES=0`
+            # Explicitly set CUDA to the first (index 0) CUDA device, otherwise `set_device` will
+            # trigger an error that a device index is missing. Index 0 takes into account the
+            # GPUs available in the environment, so `CUDA_VISIBLE_DEVICES=1,2` with `cuda:0`
+            # will use the first GPU in that env, i.e. GPU#1
+            device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+            # Sometimes the line in the postinit has not been run before we end up here,
+            # so just checking we're not at
+            # the default value.
+            self._n_gpu = torch.cuda.device_count()
+            if device.type == "cuda":
+                torch.cuda.set_device(device)
+    return device

 # workaround an IPEX bug that prevents resume training in bf16
 from accelerate import Accelerator
 Accelerator._prepare_ipex = patch_prepare_ipex
+
+# patch transformers for XPU DDP training
+from transformers import TrainingArguments
+TrainingArguments._setup_devices = _setup_devices
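
A hedged sketch of the effect (argument values are illustrative): with these patches imported, constructing TrainingArguments under a multi-process XPU launch (e.g. mpirun with the ccl backend, one process per XPU) resolves the device through the replacement _setup_devices above:

from transformers import TrainingArguments

args = TrainingArguments(output_dir="out", ddp_backend="ccl", ddp_timeout=7200)
print(args.device)      # xpu:<local rank>, taken from accelerate's PartialState
print(args.local_rank)  # set from distributed_state.local_process_index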

bigdl/llm/transformers/xpu_customize_fwd.py (new file)

@@ -0,0 +1,190 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Some parts of this file are adapted from
# https://github.com/pytorch/pytorch/blob/v2.1.0/torch/cuda/amp/autocast_mode.py
#
# From PyTorch:
# Copyright (c) 2016- Facebook, Inc (Adam Paszke)
# Copyright (c) 2014- Facebook, Inc (Soumith Chintala)
# Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)
# Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu)
# Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)
# Copyright (c) 2011-2013 NYU (Clement Farabet)
# Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert,
# Leon Bottou, Iain Melvin, Jason Weston)
# Copyright (c) 2006 Idiap Research Institute (Samy Bengio)
# Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)
# From Caffe2:
# Copyright (c) 2016-present, Facebook Inc. All rights reserved.
# All contributions by Facebook:
# Copyright (c) 2016 Facebook Inc.
# All contributions by Google:
# Copyright (c) 2015 Google Inc.
# All rights reserved.
# All contributions by Yangqing Jia:
# Copyright (c) 2015 Yangqing Jia
# All rights reserved.
# All contributions by Kakao Brain:
# Copyright 2019-2020 Kakao Brain
# All contributions by Cruise LLC:
# Copyright (c) 2022 Cruise LLC.
# All rights reserved.
# All contributions from Caffe:
# Copyright(c) 2013, 2014, 2015, the respective contributors
# All rights reserved.
# All other contributions:
# Copyright(c) 2015, 2016 the respective contributors
# All rights reserved.
# Caffe2 uses a copyright model similar to Caffe: each contributor holds
# copyright over their contributions to Caffe2. The project versioning records
# all such contribution and copyright details. If a contributor wants to further
# mark their specific copyright on a particular contribution, they should
# indicate their copyright solely in the commit message of the change when it is
# committed.
# All rights reserved.
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# 1. Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# 3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America
# and IDIAP Research Institute nor the names of its contributors may be
# used to endorse or promote products derived from this software without
# specific prior written permission.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.
import collections
import functools

import torch

try:
    import numpy as np

    HAS_NUMPY = True
except ModuleNotFoundError:
    HAS_NUMPY = False  # keep HAS_NUMPY defined when numpy is absent
    np = None  # type: ignore[assignment]

from typing import Any
def _cast(value, dtype):
    if isinstance(value, torch.Tensor):
        is_eligible = (
            value.is_floating_point()
            and value.is_xpu
            and (value.dtype is not torch.float64)
        )
        return value.to(dtype) if is_eligible else value
    elif isinstance(value, (str, bytes)):
        return value
    elif HAS_NUMPY and isinstance(value, np.ndarray):
        return value
    elif isinstance(value, collections.abc.Mapping):
        return {_cast(k, dtype): _cast(v, dtype) for k, v in value.items()}
    elif isinstance(value, collections.abc.Iterable):
        iterable = (_cast(v, dtype) for v in value)
        if isinstance(value, (list, tuple)):
            return type(value)(iterable)
        else:
            return iterable
    else:
        return value
def custom_fwd(fwd=None, *, cast_inputs=None):
    """
    Helper decorator for ``forward`` methods of custom autograd functions (subclasses of
    :class:`torch.autograd.Function`). See the :ref:`example page<amp-custom-examples>`
    for more detail.

    Args:
        cast_inputs (:class:`torch.dtype` or None, optional, default=None): If not ``None``,
            when ``forward`` runs in an autocast-enabled region, casts incoming
            floating-point XPU Tensors to the target dtype (non-floating-point Tensors
            are not affected), then executes ``forward`` with autocast disabled.
            If ``None``, ``forward``'s internal ops execute with the current autocast state.

    .. note::
        If the decorated ``forward`` is called outside an autocast-enabled region,
        :func:`custom_fwd<custom_fwd>` is a no-op and ``cast_inputs`` has no effect.
    """
    if fwd is None:
        return functools.partial(custom_fwd, cast_inputs=cast_inputs)

    @functools.wraps(fwd)
    def decorate_fwd(*args, **kwargs):
        args[0]._dtype = torch.xpu.get_autocast_xpu_dtype()
        if cast_inputs is None:
            args[0]._fwd_used_autocast = torch.xpu.is_autocast_xpu_enabled()
            return fwd(*args, **kwargs)
        else:
            autocast_context = torch.xpu.is_autocast_xpu_enabled()
            args[0]._fwd_used_autocast = False
            if autocast_context:
                with torch.xpu.autocast(enabled=False):
                    return fwd(*_cast(args, cast_inputs), **_cast(kwargs, cast_inputs))
            else:
                return fwd(*args, **kwargs)

    return decorate_fwd
# Autograd ensures incoming gradients are the same type as forward outputs. Allowing a separate
# cast_inputs argument on custom_bwd is unnecessary and could cause errors if it doesn't match
# cast_inputs supplied to custom_fwd.
def custom_bwd(bwd):
    """
    Helper decorator for backward methods of custom autograd functions (subclasses of
    :class:`torch.autograd.Function`).
    Ensures that ``backward`` executes with the same autocast state as ``forward``.
    See the :ref:`example page<amp-custom-examples>` for more detail.
    """

    @functools.wraps(bwd)
    def decorate_bwd(*args, **kwargs):
        with torch.xpu.autocast(enabled=args[0]._fwd_used_autocast, dtype=args[0]._dtype):
            return bwd(*args, **kwargs)

    return decorate_bwd
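
Finally, a minimal sketch of the two decorators on a toy autograd function (assumes an XPU device is available; cast_inputs exercises the _cast path above, while MatMulLowBit uses the bare @custom_fwd form):

import torch
from bigdl.llm.transformers.xpu_customize_fwd import custom_fwd, custom_bwd

class ToyMul(torch.autograd.Function):
    @staticmethod
    @custom_fwd(cast_inputs=torch.float32)  # cast inputs to fp32, run with autocast off
    def forward(ctx, a, b):
        ctx.save_for_backward(a, b)
        return a * b

    @staticmethod
    @custom_bwd  # replay forward's autocast state during backward
    def backward(ctx, grad_out):
        a, b = ctx.saved_tensors
        return grad_out * b, grad_out * a

a = torch.randn(4, device="xpu", requires_grad=True)
b = torch.randn(4, device="xpu", requires_grad=True)
with torch.xpu.amp.autocast(dtype=torch.bfloat16):
    out = ToyMul.apply(a, b)  # body runs in fp32 despite bf16 autocast
out.sum().backward()          # backward sees the same autocast state as forward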