From 70c828b87cce149655ecbb58f3033a6fb3b9417f Mon Sep 17 00:00:00 2001
From: Heyang Sun <60865256+Uxito-Ada@users.noreply.github.com>
Date: Tue, 13 Aug 2024 16:15:29 +0800
Subject: [PATCH] deepspeed zero3 QLoRA finetuning (#11625)

* deepspeed zero3 QLoRA finetuning
* Update convert.py
* Update low_bit_linear.py
* Update utils.py
* Update qlora_finetune_llama2_13b_arch_2_card.sh
* Update low_bit_linear.py
* Update alpaca_qlora_finetuning.py
* Update low_bit_linear.py
* Update utils.py
* Update convert.py
* Update alpaca_qlora_finetuning.py
* Update alpaca_qlora_finetuning.py
* Update low_bit_linear.py
* Update deepspeed_zero3.json
* Update qlora_finetune_llama2_13b_arch_2_card.sh
* Update low_bit_linear.py
* Update low_bit_linear.py
* Update utils.py
* fix style
* fix style
* Update alpaca_qlora_finetuning.py
* Update qlora_finetune_llama2_13b_arch_2_card.sh
* Update convert.py
* Update low_bit_linear.py
* Update model.py
* Update alpaca_qlora_finetuning.py
* Update low_bit_linear.py
* Update low_bit_linear.py
* Update low_bit_linear.py
---
 .../alpaca-qlora/alpaca_qlora_finetuning.py   | 17 +++++--
 .../QLoRA/alpaca-qlora/deepspeed_zero3.json   | 15 ++++++
 .../qlora_finetune_llama2_13b_arch_2_card.sh  | 41 ++++++++++++++++
 .../ipex_llm/transformers/low_bit_linear.py   | 47 +++++++++++++++----
 python/llm/src/ipex_llm/transformers/utils.py | 13 +++++
 5 files changed, 119 insertions(+), 14 deletions(-)
 create mode 100644 python/llm/example/GPU/LLM-Finetuning/QLoRA/alpaca-qlora/deepspeed_zero3.json
 create mode 100644 python/llm/example/GPU/LLM-Finetuning/QLoRA/alpaca-qlora/qlora_finetune_llama2_13b_arch_2_card.sh

diff --git a/python/llm/example/GPU/LLM-Finetuning/QLoRA/alpaca-qlora/alpaca_qlora_finetuning.py b/python/llm/example/GPU/LLM-Finetuning/QLoRA/alpaca-qlora/alpaca_qlora_finetuning.py
index 61916fff..c1df15db 100644
--- a/python/llm/example/GPU/LLM-Finetuning/QLoRA/alpaca-qlora/alpaca_qlora_finetuning.py
+++ b/python/llm/example/GPU/LLM-Finetuning/QLoRA/alpaca-qlora/alpaca_qlora_finetuning.py
@@ -144,6 +144,14 @@ def train(
 
     prompter = Prompter(prompt_template_name)
 
+    if deepspeed is not None and "zero3" in deepspeed:
+        from ipex_llm.transformers.utils \
+            import _constant_buffered_norm2
+        from ipex_llm.llm_patching import replace_attr
+        import deepspeed as ds
+        replace_attr(ds.runtime.zero.stage3.DeepSpeedZeroOptimizer_Stage3,
+                     "_constant_buffered_norm2", _constant_buffered_norm2)
+
     device_map = "auto"
     world_size = int(os.environ.get("WORLD_SIZE", 1))
     ddp = world_size != 1
@@ -161,7 +169,7 @@ def train(
             optimize_model=False,
             torch_dtype=torch.bfloat16,
             modules_to_not_convert=["lm_head"],
-            trust_remote_code=True,
+            trust_remote_code=True
         )
     else:
         # According to the QLoRA paper, using "nf4" could yield better model quality than "int4"
@@ -186,9 +194,10 @@ def train(
         #     # device_map=device_map,
         #     modules_to_not_convert=["lm_head"],
         # )
-    print(f"Model loaded on rank {os.environ.get('LOCAL_RANK')}")
-    model = model.to(f'xpu:{os.environ.get("LOCAL_RANK", 0)}')
-    print(f"Model moved to rank {os.environ.get('LOCAL_RANK')}")
+    if deepspeed is not None and not "zero3" in deepspeed:
+        print(f"Model loaded on rank {os.environ.get('LOCAL_RANK')}")
+        model = model.to(f'xpu:{os.environ.get("LOCAL_RANK", 0)}')
+        print(f"Model moved to rank {os.environ.get('LOCAL_RANK')}")
 
     tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
     print(f"Tokenizer loaded on rank {os.environ.get('LOCAL_RANK')}")
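
The hunk above monkey-patches DeepSpeed so that ZeRO-3's gradient-norm computation avoids FP64 (see the _constant_buffered_norm2 replacement added to utils.py at the end of this patch). The replace_attr helper from ipex_llm.llm_patching is not shown in this diff; below is a minimal sketch of the attribute swap it is assumed to perform. Only the call signature comes from the hunk above; the body and the keep-the-original behavior are assumptions, not the library's confirmed implementation.

    def replace_attr(obj, name, value):
        # Rebind an attribute on the class so every instance, including the
        # ZeRO-3 optimizer DeepSpeed creates later, picks up the replacement.
        original = getattr(obj, name, None)  # keep a handle to the original (assumption)
        setattr(obj, name, value)
        return original
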
diff --git a/python/llm/example/GPU/LLM-Finetuning/QLoRA/alpaca-qlora/deepspeed_zero3.json b/python/llm/example/GPU/LLM-Finetuning/QLoRA/alpaca-qlora/deepspeed_zero3.json
new file mode 100644
index 00000000..7ee8a787
--- /dev/null
+++ b/python/llm/example/GPU/LLM-Finetuning/QLoRA/alpaca-qlora/deepspeed_zero3.json
@@ -0,0 +1,15 @@
+{
+  "zero_optimization": {
+    "stage": 3,
+    "contiguous_gradients": true,
+    "overlap_comm": true,
+    "offload_optimizer": {"device": "cpu"}
+  },
+  "bf16": {
+    "enabled": true
+  },
+  "world_size": 2,
+  "train_batch_size": 32,
+  "train_micro_batch_size_per_gpu": 2,
+  "gradient_accumulation_steps": 8
+}
diff --git a/python/llm/example/GPU/LLM-Finetuning/QLoRA/alpaca-qlora/qlora_finetune_llama2_13b_arch_2_card.sh b/python/llm/example/GPU/LLM-Finetuning/QLoRA/alpaca-qlora/qlora_finetune_llama2_13b_arch_2_card.sh
new file mode 100644
index 00000000..ba5a11b0
--- /dev/null
+++ b/python/llm/example/GPU/LLM-Finetuning/QLoRA/alpaca-qlora/qlora_finetune_llama2_13b_arch_2_card.sh
@@ -0,0 +1,41 @@
+#
+# Copyright 2016 The BigDL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+export MASTER_ADDR=127.0.0.1
+export MASTER_PORT=29503
+export FI_PROVIDER=tcp
+export CCL_ATL_TRANSPORT=ofi
+export CCL_ZE_IPC_EXCHANGE=sockets
+export UR_L0_IN_ORDER_BARRIER_BY_SIGNAL=0
+basekit_root=/opt/intel/oneapi
+source $basekit_root/setvars.sh --force
+source $basekit_root/ccl/latest/env/vars.sh --force
+
+NUM_GPUS=2 # number of GPUs used
+export USE_XETLA=OFF
+export SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=2
+export TORCH_LLM_ALLREDUCE=0 # Different from PVC
+export DS_SKIP_CUDA_CHECK=1
+
+mpirun -n $NUM_GPUS \
+       python -u ./alpaca_qlora_finetuning.py \
+       --base_model "meta-llama/Llama-2-13b-hf" \
+       --data_path "yahma/alpaca-cleaned" \
+       --output_dir "./ipex-llm-qlora-alpaca" \
+       --gradient_checkpointing True \
+       --micro_batch_size 2 \
+       --batch_size 32 \
+       --deepspeed ./deepspeed_zero3.json
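
The ZeRO-3 config and the launch script above have to agree on batch sizes: DeepSpeed requires train_batch_size = train_micro_batch_size_per_gpu x gradient_accumulation_steps x number of ranks, which the values above satisfy (2 x 8 x 2 = 32) and which matches the --micro_batch_size 2 / --batch_size 32 flags passed to alpaca_qlora_finetuning.py. A quick standalone check, with values copied from deepspeed_zero3.json (not part of the patch):

    # Sanity-check the ZeRO-3 batch-size settings used above.
    micro_batch_per_gpu = 2   # train_micro_batch_size_per_gpu
    grad_accum_steps = 8      # gradient_accumulation_steps
    num_ranks = 2             # NUM_GPUS in the launch script

    assert micro_batch_per_gpu * grad_accum_steps * num_ranks == 32  # train_batch_size
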
diff --git a/python/llm/src/ipex_llm/transformers/low_bit_linear.py b/python/llm/src/ipex_llm/transformers/low_bit_linear.py
index e182394a..aacd288c 100644
--- a/python/llm/src/ipex_llm/transformers/low_bit_linear.py
+++ b/python/llm/src/ipex_llm/transformers/low_bit_linear.py
@@ -229,6 +229,13 @@ def ggml_convert_qtype(tensor: torch.Tensor, qtype: int,
         dst_tensor = dst_tensor.reshape(tensor.shape[0], tensor.shape[-1] // QK)
         scale = torch.empty(n // k, dtype=torch.float32,
                             device=device)
+    elif qtype == NF4:
+        # DeepSpeed zero3 requires a unified dtype, so bfloat16 is used
+        # here to stay consistent with the other layers.
+        # dst_size above is computed based on uint8; for bfloat16 the
+        # buffer size should be halved.
+        dst_tensor = torch.empty(dst_size // 2, dtype=torch.bfloat16,
+                                 device=device)
     else:
         dst_tensor = torch.empty(dst_size, dtype=torch.uint8,
                                  device=device)
@@ -260,12 +267,15 @@ def ggml_convert_qtype(tensor: torch.Tensor, qtype: int,
 
 
 def ggml_q_format_convet_cpu2xpu(tensor: torch.Tensor, num_elem: int, qtype: int):
-
-    invalidInputError(tensor.dtype == torch.uint8,
-                      "Input tensor must be uint8")
+    if qtype == NF4:
+        invalidInputError(tensor.dtype == torch.bfloat16,
+                          "NF4 Input tensor must be bfloat16")
+    else:
+        invalidInputError(tensor.dtype == torch.uint8,
+                          "Input tensor except NF4 must be uint8")
 
     invalidInputError(tensor.device == torch.device('cpu'),
-                      "Input tensor must be uint8")
+                      "Input tensor must be on cpu")
 
     src = ctypes.c_void_p(tensor.data.data_ptr())
 
@@ -370,7 +380,6 @@ def use_batch_forward(x: torch.Tensor, qtype: int, output_len: int):
 
 # Rename to FP4Params to trigger initializing
 # the params layer with all parameters on the CPU
-# https://github.com/huggingface/accelerate/blob/main/src/accelerate/utils/modeling.py#L333
 class FP4Params(torch.nn.Parameter):
     def __new__(cls,
                 data=None,
@@ -582,7 +591,13 @@ class MatMulLowBit(torch.autograd.Function):
     def forward(ctx, A, weight, input_seq_size):
         ctx.is_empty = False
         import xe_linear
-        result = xe_linear.forward_new(A, weight.data, weight.qtype, input_seq_size)
+        if weight.qtype == NF4:
+            result = xe_linear.forward_new(A,
+                                           weight.data.view(torch.uint8),
+                                           weight.qtype,
+                                           input_seq_size)
+        else:
+            result = xe_linear.forward_new(A, weight.data, weight.qtype, input_seq_size)
         if any(ctx.needs_input_grad[:2]):
             ctx.tensors = (A, weight)
         else:
@@ -602,7 +617,12 @@ class MatMulLowBit(torch.autograd.Function):
         if req_gradA:
             if torch.xpu.is_autocast_xpu_enabled():
                 grad_output = grad_output.to(torch.xpu.get_autocast_xpu_dtype())
-            dequant_weight = xe_linear.dequant(A, weight.data, weight.qtype)
+            if weight.qtype == NF4:
+                dequant_weight = xe_linear.dequant(A,
+                                                   weight.data.view(torch.uint8),
+                                                   weight.qtype)
+            else:
+                dequant_weight = xe_linear.dequant(A, weight.data, weight.qtype)
             grad_A = torch.matmul(grad_output, dequant_weight.reshape(weight._shape))
 
         return grad_A, grad_weight, None
@@ -737,9 +757,16 @@ class LowBitLinear(nn.Linear):
             if x_2d.requires_grad:
                 result = MatMulLowBit.apply(x_2d, self.weight, input_seq_size)
             else:
-                result = xe_linear.forward_new(x_2d, self.weight.data,
-                                               self.weight.qtype,
-                                               input_seq_size)
+                if self.weight.qtype == NF4:
+                    result = xe_linear.forward_new(x_2d,
+                                                   self.weight.data.view(torch.uint8),
+                                                   self.weight.qtype,
+                                                   input_seq_size)
+                else:
+                    result = xe_linear.forward_new(x_2d,
+                                                   self.weight.data,
+                                                   self.weight.qtype,
+                                                   input_seq_size)
         elif self.enable_xetla:
             x_2d = x_2d.half()
             result = xe_linear.mm_xetla(x_2d, self.weight.data, self.qtype)
diff --git a/python/llm/src/ipex_llm/transformers/utils.py b/python/llm/src/ipex_llm/transformers/utils.py
index cf8c5612..5cd706c2 100644
--- a/python/llm/src/ipex_llm/transformers/utils.py
+++ b/python/llm/src/ipex_llm/transformers/utils.py
@@ -382,3 +382,16 @@ def check_hidden_size(qtype, hidden_size):
                           "required for fq6_k - using fallback quantization fp6.")
         return ggml_tensor_qtype["fp6"]
     return qtype
+
+
+# The Arc platform does not support FP64, so disable FP64 in
+# DeepSpeedZeroOptimizer_Stage3's _constant_buffered_norm2 method.
+# https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/runtime/zero/stage3.py#L1365
+def _constant_buffered_norm2(self, input, buffer_size=250000000):
+    norm = None
+    for part in input.view(-1).split(buffer_size):
+        if norm is None:
+            norm = part.data.norm(2)**2.0
+        else:
+            norm += part.data.norm(2)**2.0
+    return norm**0.5
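
The _constant_buffered_norm2 override added to utils.py computes the same L2 norm as a direct torch.norm call, only accumulated over fixed-size chunks and kept in the tensor's own dtype instead of being promoted to FP64 (which Arc GPUs lack). A small standalone check of that equivalence, illustrative only and not part of the patch:

    import torch

    def chunked_norm2(t, buffer_size=1000):
        # Same accumulation as the patched _constant_buffered_norm2, minus `self`.
        norm = None
        for part in t.view(-1).split(buffer_size):
            if norm is None:
                norm = part.data.norm(2) ** 2.0
            else:
                norm += part.data.norm(2) ** 2.0
        return norm ** 0.5

    x = torch.randn(10000)
    # Chunked and direct norms agree up to float rounding.
    assert torch.allclose(chunked_norm2(x), torch.norm(x), rtol=1e-4)
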