diff --git a/python/llm/src/bigdl/llm/transformers/low_bit_linear.py b/python/llm/src/bigdl/llm/transformers/low_bit_linear.py
index d9a0ffe2..4940b13e 100644
--- a/python/llm/src/bigdl/llm/transformers/low_bit_linear.py
+++ b/python/llm/src/bigdl/llm/transformers/low_bit_linear.py
@@ -49,6 +49,7 @@ import torch.nn.functional as F
 from torch import Tensor, device, dtype, nn
 from operator import mul
 from functools import reduce
+from bigdl.llm.transformers.xpu_customize_fwd import custom_fwd, custom_bwd
 
 
 T = TypeVar("T", bound="torch.nn.Module")
@@ -288,7 +289,10 @@ def ggml_matmul_src1_x_src0_t(src0: torch.Tensor,
 
 class MatMulLowBit(torch.autograd.Function):
     @staticmethod
+    @custom_fwd
     def forward(ctx, A, weight, input_seq_size):
+        if torch.xpu.is_autocast_xpu_enabled():
+            A = A.to(torch.xpu.get_autocast_xpu_dtype())
         ctx.is_empty = False
         import linear_q4_0
         result = linear_q4_0.forward_new(A, weight.data, weight.qtype, input_seq_size)
@@ -299,6 +303,7 @@ class MatMulLowBit(torch.autograd.Function):
         return result
 
     @staticmethod
+    @custom_bwd
     def backward(ctx, grad_output):
         import linear_q4_0
         if ctx.is_empty:
@@ -308,7 +313,9 @@ class MatMulLowBit(torch.autograd.Function):
         A, weight = ctx.tensors
         grad_A, grad_weight = None, None
         if req_gradA:
-            dequant_weight = linear_q4_0.dequant(A, weight.data, weight.qtype).to(grad_output.dtype)
+            if torch.xpu.is_autocast_xpu_enabled():
+                grad_output = grad_output.to(torch.xpu.get_autocast_xpu_dtype())
+            dequant_weight = linear_q4_0.dequant(A, weight.data, weight.qtype)
             grad_A = torch.matmul(grad_output, dequant_weight.reshape(weight._shape))
 
         return grad_A, grad_weight, None
@@ -378,8 +385,7 @@ class LowBitLinear(nn.Linear):
             result = result.view(new_shape)
             if self.bias is not None:
                 result += self.bias
-
-        return result.to(x.dtype)
+        return result
 
 
 class FP16Linear(nn.Linear):
diff --git a/python/llm/src/bigdl/llm/transformers/qlora.py b/python/llm/src/bigdl/llm/transformers/qlora.py
index a52abd94..cb6b5d36 100644
--- a/python/llm/src/bigdl/llm/transformers/qlora.py
+++ b/python/llm/src/bigdl/llm/transformers/qlora.py
@@ -30,7 +30,23 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
+#
+# Some parts of this file are adapted from
+# https://github.com/huggingface/transformers/blob/v4.34.0/src/transformers/training_args.py
+#
+# Copyright 2020 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 
 import torch
 from bigdl.llm.transformers.low_bit_linear import LowBitLinear
@@ -193,6 +209,147 @@ class PeftModel:
 def patch_prepare_ipex(self, *args):
     return tuple(args)
+
+
+from transformers.utils import (
+    requires_backends,
+    is_sagemaker_mp_enabled,
+    is_accelerate_available,
+    is_torch_xpu_available,
+    is_sagemaker_dp_enabled,
+    is_torch_tpu_available,
+    is_torch_npu_available)
+from transformers.utils.generic import strtobool
+from transformers.utils import cached_property
+from transformers.training_args import logger, ParallelMode, DistributedType
+import torch
+import torch.distributed as dist
+import os
+import warnings
+from datetime import timedelta
+
+if is_accelerate_available():
+    from accelerate.state import AcceleratorState, PartialState
+    from accelerate.utils import DistributedType
+
+if is_sagemaker_mp_enabled():
+    import smdistributed.modelparallel.torch as smp
+
+    smp.init()
+
+
+@cached_property
+def _setup_devices(self) -> "torch.device":
+    requires_backends(self, ["torch"])
+    logger.info("PyTorch: setting up devices")
+    if not is_sagemaker_mp_enabled():
+        if not is_accelerate_available(min_version="0.20.1"):
+            invalidInputError(
+                False,
+                "Using the `Trainer` with `PyTorch` requires `accelerate>=0.20.1`: "
+                "Please run `pip install transformers[torch]` or `pip install accelerate -U`"
+            )
+        AcceleratorState._reset_state(reset_partial_state=True)
+    self.distributed_state = None
+    if not self.use_ipex and "ACCELERATE_USE_IPEX" not in os.environ:
+        os.environ["ACCELERATE_USE_IPEX"] = "false"
+    if self.use_cpu or strtobool(os.environ.get("ACCELERATE_USE_CPU", "False")):
+        self.distributed_state = PartialState(cpu=True, backend=self.ddp_backend)
+        self._n_gpu = 0
+    elif is_sagemaker_mp_enabled():
+        local_rank = smp.local_rank()
+        device = torch.device("cuda", local_rank)
+        self._n_gpu = 1
+        torch.cuda.set_device(device)
+    elif is_torch_xpu_available() and "ACCELERATE_USE_XPU" not in os.environ:
+        os.environ["ACCELERATE_USE_XPU"] = "true"
+        self.distributed_state = PartialState(timeout=timedelta(seconds=self.ddp_timeout))
+        # device = torch.device("xpu:0")
+        device = self.distributed_state.device
+        self._n_gpu = 1
+    elif is_sagemaker_dp_enabled():
+        self.distributed_state = PartialState(_use_sagemaker_dp=True)
+        self._n_gpu = 1
+    elif self.deepspeed:
+        # Need to do similar for Accelerator init
+        os.environ["ACCELERATE_USE_DEEPSPEED"] = "true"
+        self.distributed_state = PartialState(timeout=timedelta(seconds=self.ddp_timeout))
+        del os.environ["ACCELERATE_USE_DEEPSPEED"]
+        self._n_gpu = 1
+    else:
+        self.distributed_state = PartialState(
+            backend=self.ddp_backend, timeout=timedelta(seconds=self.ddp_timeout)
+        )
+        self._n_gpu = 1
+    if not is_sagemaker_mp_enabled():
+        device = self.distributed_state.device
+        self.local_rank = self.distributed_state.local_process_index
+    if dist.is_available() and dist.is_initialized() and \
+            self.parallel_mode != ParallelMode.DISTRIBUTED:
+        logger.warning(
+            "torch.distributed process group is initialized, "
+            "but parallel_mode != ParallelMode.DISTRIBUTED. "
" + "In order to use Torch DDP, launch your script with `python -m torch.distributed.launch" + ) + if is_torch_tpu_available(): + device = self.distributed_state.device + self._n_gpu = 0 + elif is_sagemaker_dp_enabled() or is_sagemaker_mp_enabled(): + # Already set _n_gpu + pass + elif self.distributed_state.distributed_type == DistributedType.MULTI_XPU: + if "ACCELERATE_USE_XPU" not in os.environ: + os.environ["ACCELERATE_USE_XPU"] = "true" + # self._n_gpu = torch.xpu.device_count() + # device = torch.device("xpu:0") + # torch.xpu.set_device(device) + elif self.distributed_state.distributed_type == DistributedType.NO: + if self.use_mps_device: + warnings.warn( + "`use_mps_device` is deprecated and will be removed in" + " version 5.0 of 🤗 Transformers." + "`mps` device will be used by default if available similar" + " to the way `cuda` device is used." + "Therefore, no action from user is required. " + ) + if device.type != "mps": + invalidInputError(False, + ("Either you do not have an MPS-enabled device" + " on this machine or MacOS" + " version is not 12.3+ " + "or current PyTorch install was not built with MPS enabled.")) + if device.type == "mps": + self._n_gpu = 1 + elif self.use_cpu: + device = torch.device("cpu") + self._n_gpu = 0 + elif is_torch_xpu_available(): + device = torch.device("xpu:0") + torch.xpu.set_device(device) + self._n_gpu = 1 + elif is_torch_npu_available(): + device = torch.device("npu:0") + torch.npu.set_device(device) + self._n_gpu = 1 + else: + # if n_gpu is > 1 we'll use nn.DataParallel. + # If you only want to use a specific subset of GPUs use `CUDA_VISIBLE_DEVICES=0` + # Explicitly set CUDA to the first (index 0) CUDA device, otherwise `set_device` will + # trigger an error that a device index is missing. Index 0 takes into account the + # GPUs available in the environment, so `CUDA_VISIBLE_DEVICES=1,2` with `cuda:0` + # will use the first GPU in that env, i.e. GPU#1 + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + # Sometimes the line in the postinit has not been run before we end up here, + # so just checking we're not at + # the default value. + self._n_gpu = torch.cuda.device_count() + if device.type == "cuda": + torch.cuda.set_device(device) + return device + # workaround a IPEX bug that prevents resume training in bf16 from accelerate import Accelerator Accelerator._prepare_ipex = patch_prepare_ipex + +# patch transformer for xpu DDP traing +from transformers import TrainingArguments +TrainingArguments._setup_devices = _setup_devices diff --git a/python/llm/src/bigdl/llm/transformers/xpu_customize_fwd.py b/python/llm/src/bigdl/llm/transformers/xpu_customize_fwd.py new file mode 100644 index 00000000..bfed60d4 --- /dev/null +++ b/python/llm/src/bigdl/llm/transformers/xpu_customize_fwd.py @@ -0,0 +1,190 @@ +# +# Copyright 2016 The BigDL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+#
+# Some parts of this file are adapted from
+# https://github.com/pytorch/pytorch/blob/v2.1.0/torch/cuda/amp/autocast_mode.py
+#
+# From PyTorch:
+
+# Copyright (c) 2016- Facebook, Inc (Adam Paszke)
+# Copyright (c) 2014- Facebook, Inc (Soumith Chintala)
+# Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)
+# Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu)
+# Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)
+# Copyright (c) 2011-2013 NYU (Clement Farabet)
+# Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert,
+# Leon Bottou, Iain Melvin, Jason Weston)
+# Copyright (c) 2006 Idiap Research Institute (Samy Bengio)
+# Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)
+
+# From Caffe2:
+
+# Copyright (c) 2016-present, Facebook Inc. All rights reserved.
+
+# All contributions by Facebook:
+# Copyright (c) 2016 Facebook Inc.
+
+# All contributions by Google:
+# Copyright (c) 2015 Google Inc.
+# All rights reserved.
+
+# All contributions by Yangqing Jia:
+# Copyright (c) 2015 Yangqing Jia
+# All rights reserved.
+
+# All contributions by Kakao Brain:
+# Copyright 2019-2020 Kakao Brain
+
+# All contributions by Cruise LLC:
+# Copyright (c) 2022 Cruise LLC.
+# All rights reserved.
+
+# All contributions from Caffe:
+# Copyright(c) 2013, 2014, 2015, the respective contributors
+# All rights reserved.
+
+# All other contributions:
+# Copyright(c) 2015, 2016 the respective contributors
+# All rights reserved.
+
+# Caffe2 uses a copyright model similar to Caffe: each contributor holds
+# copyright over their contributions to Caffe2. The project versioning records
+# all such contribution and copyright details. If a contributor wants to further
+# mark their specific copyright on a particular contribution, they should
+# indicate their copyright solely in the commit message of the change when it is
+# committed.
+
+# All rights reserved.
+
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+
+# 1. Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+
+# 2. Redistributions in binary form must reproduce the above copyright
+# notice, this list of conditions and the following disclaimer in the
+# documentation and/or other materials provided with the distribution.
+
+# 3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America
+# and IDIAP Research Institute nor the names of its contributors may be
+# used to endorse or promote products derived from this software without
+# specific prior written permission.
+
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
+# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+# POSSIBILITY OF SUCH DAMAGE.
+import collections
+import functools
+
+import torch
+
+try:
+    import numpy as np
+
+    HAS_NUMPY = True
+except ModuleNotFoundError:
+    np = None  # type: ignore[assignment]
+from typing import Any
+
+
+def _cast(value, dtype):
+    if isinstance(value, torch.Tensor):
+        is_eligible = (
+            value.is_floating_point()
+            and value.is_xpu
+            and (value.dtype is not torch.float64)
+        )
+        return value.to(dtype) if is_eligible else value
+    elif isinstance(value, (str, bytes)):
+        return value
+    elif HAS_NUMPY and isinstance(value, np.ndarray):
+        return value
+    elif isinstance(value, collections.abc.Mapping):
+        return {_cast(k, dtype): _cast(v, dtype) for k, v in value.items()}
+    elif isinstance(value, collections.abc.Iterable):
+        iterable = (_cast(v, dtype) for v in value)
+        if isinstance(value, (list, tuple)):
+            return type(value)(iterable)
+        else:
+            return iterable
+    else:
+        return value
+
+
+def custom_fwd(fwd=None, *, cast_inputs=None):
+    """
+    Helper decorator for ``forward`` methods of custom autograd functions (subclasses of
+    :class:`torch.autograd.Function`). See the :ref:`example page`
+    for more detail.
+
+    Args:
+        cast_inputs (:class:`torch.dtype` or None, optional, default=None): If not ``None``,
+            when ``forward`` runs in an autocast-enabled region, casts incoming
+            floating-point XPU Tensors to the target dtype (non-floating-point Tensors
+            are not affected),
+            then executes ``forward`` with autocast disabled.
+            If ``None``, ``forward``'s internal ops execute with the current autocast state.
+
+    .. note::
+        If the decorated ``forward`` is called outside an autocast-enabled region,
+        :func:`custom_fwd` is a no-op and ``cast_inputs`` has no effect.
+    """
+    if fwd is None:
+        return functools.partial(custom_fwd, cast_inputs=cast_inputs)
+
+    @functools.wraps(fwd)
+    def decorate_fwd(*args, **kwargs):
+        args[0]._dtype = torch.xpu.get_autocast_xpu_dtype()
+        if cast_inputs is None:
+            args[0]._fwd_used_autocast = torch.xpu.is_autocast_xpu_enabled()
+            return fwd(*args, **kwargs)
+        else:
+            autocast_context = torch.xpu.is_autocast_xpu_enabled()
+            args[0]._fwd_used_autocast = False
+            if autocast_context:
+                with torch.xpu.autocast(enabled=False):
+                    return fwd(*_cast(args, cast_inputs), **_cast(kwargs, cast_inputs))
+            else:
+                return fwd(*args, **kwargs)
+
+    return decorate_fwd
+
+
+# Autograd ensures incoming gradients are the same type as forward outputs. Allowing a separate
+# cast_inputs argument on custom_bwd is unnecessary and could cause errors if it doesn't match
+# cast_inputs supplied to custom_fwd.
+def custom_bwd(bwd):
+    """
+    Helper decorator for backward methods of custom autograd functions (subclasses of
+    :class:`torch.autograd.Function`).
+    Ensures that ``backward`` executes with the same autocast state as ``forward``.
+    See the :ref:`example page` for more detail.
+    """
+
+    @functools.wraps(bwd)
+    def decorate_bwd(*args, **kwargs):
+        with torch.xpu.autocast(enabled=args[0]._fwd_used_autocast, dtype=args[0]._dtype):
+            return bwd(*args, **kwargs)
+
+    return decorate_bwd
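
For reference, a minimal usage sketch (not part of the patch) showing how the custom_fwd/custom_bwd decorators above combine with XPU autocast, mirroring the MatMulLowBit change; it assumes an XPU-enabled PyTorch/IPEX build where torch.xpu.autocast is available, and DoubleIt is a made-up example function:

    import torch
    from bigdl.llm.transformers.xpu_customize_fwd import custom_fwd, custom_bwd

    class DoubleIt(torch.autograd.Function):
        @staticmethod
        @custom_fwd                  # records the surrounding autocast state/dtype on ctx
        def forward(ctx, x):
            return x * 2

        @staticmethod
        @custom_bwd                  # re-enters the same autocast state for backward
        def backward(ctx, grad_output):
            return grad_output * 2

    # hypothetical run on an XPU device under bf16 autocast
    x = torch.randn(4, 4, device="xpu", requires_grad=True)
    with torch.xpu.autocast(enabled=True, dtype=torch.bfloat16):
        y = DoubleIt.apply(x)
    y.sum().backward()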