deepspeed zero3 QLoRA finetuning (#11625)

* deepspeed zero3 QLoRA finetuning

* Update convert.py

* Update low_bit_linear.py

* Update utils.py

* Update qlora_finetune_llama2_13b_arch_2_card.sh

* Update low_bit_linear.py

* Update alpaca_qlora_finetuning.py

* Update low_bit_linear.py

* Update utils.py

* Update convert.py

* Update alpaca_qlora_finetuning.py

* Update alpaca_qlora_finetuning.py

* Update low_bit_linear.py

* Update deepspeed_zero3.json

* Update qlora_finetune_llama2_13b_arch_2_card.sh

* Update low_bit_linear.py

* Update low_bit_linear.py

* Update utils.py

* fix style

* fix style

* Update alpaca_qlora_finetuning.py

* Update qlora_finetune_llama2_13b_arch_2_card.sh

* Update convert.py

* Update low_bit_linear.py

* Update model.py

* Update alpaca_qlora_finetuning.py

* Update low_bit_linear.py

* Update low_bit_linear.py

* Update low_bit_linear.py
Heyang Sun authored on 2024-08-13 16:15:29 +08:00; committed by GitHub
commit 70c828b87c (parent a184b120c9)
5 changed files with 119 additions and 14 deletions

alpaca_qlora_finetuning.py

@@ -144,6 +144,14 @@ def train(
     prompter = Prompter(prompt_template_name)
+    if deepspeed is not None and "zero3" in deepspeed:
+        from ipex_llm.transformers.utils \
+            import _constant_buffered_norm2
+        from ipex_llm.llm_patching import replace_attr
+        import deepspeed as ds
+        replace_attr(ds.runtime.zero.stage3.DeepSpeedZeroOptimizer_Stage3,
+                     "_constant_buffered_norm2", _constant_buffered_norm2)
+
     device_map = "auto"
     world_size = int(os.environ.get("WORLD_SIZE", 1))
     ddp = world_size != 1
@@ -161,7 +169,7 @@
             optimize_model=False,
             torch_dtype=torch.bfloat16,
             modules_to_not_convert=["lm_head"],
-            trust_remote_code=True,
+            trust_remote_code=True
         )
     else:
         # According to the QLoRA paper, using "nf4" could yield better model quality than "int4"
@ -186,9 +194,10 @@ def train(
# # device_map=device_map, # # device_map=device_map,
# modules_to_not_convert=["lm_head"], # modules_to_not_convert=["lm_head"],
# ) # )
print(f"Model loaded on rank {os.environ.get('LOCAL_RANK')}") if deepspeed is not None and not "zero3" in deepspeed:
model = model.to(f'xpu:{os.environ.get("LOCAL_RANK", 0)}') print(f"Model loaded on rank {os.environ.get('LOCAL_RANK')}")
print(f"Model moved to rank {os.environ.get('LOCAL_RANK')}") model = model.to(f'xpu:{os.environ.get("LOCAL_RANK", 0)}')
print(f"Model moved to rank {os.environ.get('LOCAL_RANK')}")
tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True) tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
print(f"Tokenizer loaded on rank {os.environ.get('LOCAL_RANK')}") print(f"Tokenizer loaded on rank {os.environ.get('LOCAL_RANK')}")

deepspeed_zero3.json (new file)

@@ -0,0 +1,15 @@
{
"zero_optimization": {
"stage": 3,
"contiguous_gradients": true,
"overlap_comm": true,
"offload_optimizer": {"device": "cpu"}
},
"bf16": {
"enabled": true
},
"world_size": 2,
"train_batch_size": 32,
"train_micro_batch_size_per_gpu": 2,
"gradient_accumulation_steps": 8
}
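For reference, DeepSpeed checks that train_batch_size equals train_micro_batch_size_per_gpu × gradient_accumulation_steps × the number of data-parallel ranks. With the two-card setup used by this example the arithmetic works out as sketched below (values taken from the config above):

# DeepSpeed's batch-size consistency rule, with the values from deepspeed_zero3.json
micro_batch_per_gpu, grad_accum_steps, num_ranks = 2, 8, 2
assert micro_batch_per_gpu * grad_accum_steps * num_ranks == 32  # train_batch_size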

qlora_finetune_llama2_13b_arch_2_card.sh (new file)

@@ -0,0 +1,41 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
export MASTER_ADDR=127.0.0.1
export MASTER_PORT=29503
export FI_PROVIDER=tcp
export CCL_ATL_TRANSPORT=ofi
export CCL_ZE_IPC_EXCHANGE=sockets
export UR_L0_IN_ORDER_BARRIER_BY_SIGNAL=0
basekit_root=/opt/intel/oneapi
source $basekit_root/setvars.sh --force
source $basekit_root/ccl/latest/env/vars.sh --force
NUM_GPUS=2 # number of used GPU
export USE_XETLA=OFF
export SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=2
export TORCH_LLM_ALLREDUCE=0 # Different from PVC
export DS_SKIP_CUDA_CHECK=1
mpirun -n $NUM_GPUS \
python -u ./alpaca_qlora_finetuning.py \
--base_model "meta-llama/Llama-2-13b-hf" \
--data_path "yahma/alpaca-cleaned" \
--output_dir "./ipex-llm-qlora-alpaca" \
--gradient_checkpointing True \
--micro_batch_size 2 \
--batch_size 32 \
--deepspeed ./deepspeed_zero3.json

low_bit_linear.py

@@ -229,6 +229,13 @@ def ggml_convert_qtype(tensor: torch.Tensor, qtype: int,
         dst_tensor = dst_tensor.reshape(tensor.shape[0], tensor.shape[-1] // QK)
         scale = torch.empty(n // k, dtype=torch.float32,
                             device=device)
+    elif qtype == NF4:
+        # Deepspeed zero3 requires unified dtype,
+        # thus here uses bfloat16 consistent to other layers
+        # dst_size above is computed based on uint8, and for bfloat16,
+        # buffer size should be half
+        dst_tensor = torch.empty(dst_size // 2, dtype=torch.bfloat16,
+                                 device=device)
     else:
         dst_tensor = torch.empty(dst_size, dtype=torch.uint8,
                                  device=device)
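Here dst_size is a byte count sized for uint8 storage; a bfloat16 element takes two bytes, so the same byte budget holds dst_size // 2 bfloat16 elements. A standalone illustration of that equivalence, using an arbitrary example size rather than the real NF4 block layout:

import torch

dst_size = 1024  # hypothetical byte count; the real value comes from the NF4 block layout
as_bytes = torch.empty(dst_size, dtype=torch.uint8)
as_bf16 = torch.empty(dst_size // 2, dtype=torch.bfloat16)
# same number of bytes either way, half as many elements in bfloat16 form
assert as_bytes.numel() * as_bytes.element_size() == as_bf16.numel() * as_bf16.element_size()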
@@ -260,12 +267,15 @@ def ggml_convert_qtype(tensor: torch.Tensor, qtype: int,
 def ggml_q_format_convet_cpu2xpu(tensor: torch.Tensor, num_elem: int, qtype: int):
-    invalidInputError(tensor.dtype == torch.uint8,
-                      "Input tensor must be uint8")
+    if qtype == NF4:
+        invalidInputError(tensor.dtype == torch.bfloat16,
+                          "NF4 Input tensor must be bfloat16")
+    else:
+        invalidInputError(tensor.dtype == torch.uint8,
+                          "Input tensor except NF4 must be uint8")
     invalidInputError(tensor.device == torch.device('cpu'),
-                      "Input tensor must be uint8")
+                      "Input tensor must be on cpu")
     src = ctypes.c_void_p(tensor.data.data_ptr())
@@ -370,7 +380,6 @@ def use_batch_forward(x: torch.Tensor, qtype: int, output_len: int):
 # Rename to FP4Params to trigger initializing
 # the params layer with all parameters on the CPU
-# https://github.com/huggingface/accelerate/blob/main/src/accelerate/utils/modeling.py#L333
 class FP4Params(torch.nn.Parameter):
     def __new__(cls,
                 data=None,
@@ -582,7 +591,13 @@ class MatMulLowBit(torch.autograd.Function):
     def forward(ctx, A, weight, input_seq_size):
         ctx.is_empty = False
         import xe_linear
-        result = xe_linear.forward_new(A, weight.data, weight.qtype, input_seq_size)
+        if weight.qtype == NF4:
+            result = xe_linear.forward_new(A,
+                                           weight.data.view(torch.uint8),
+                                           weight.qtype,
+                                           input_seq_size)
+        else:
+            result = xe_linear.forward_new(A, weight.data, weight.qtype, input_seq_size)
         if any(ctx.needs_input_grad[:2]):
             ctx.tensors = (A, weight)
         else:
@@ -602,7 +617,12 @@ class MatMulLowBit(torch.autograd.Function):
         if req_gradA:
             if torch.xpu.is_autocast_xpu_enabled():
                 grad_output = grad_output.to(torch.xpu.get_autocast_xpu_dtype())
-            dequant_weight = xe_linear.dequant(A, weight.data, weight.qtype)
+            if weight.qtype == NF4:
+                dequant_weight = xe_linear.dequant(A,
+                                                   weight.data.view(torch.uint8),
+                                                   weight.qtype)
+            else:
+                dequant_weight = xe_linear.dequant(A, weight.data, weight.qtype)
             grad_A = torch.matmul(grad_output, dequant_weight.reshape(weight._shape))

         return grad_A, grad_weight, None
@@ -737,9 +757,16 @@ class LowBitLinear(nn.Linear):
                 if x_2d.requires_grad:
                     result = MatMulLowBit.apply(x_2d, self.weight, input_seq_size)
                 else:
-                    result = xe_linear.forward_new(x_2d, self.weight.data,
-                                                   self.weight.qtype,
-                                                   input_seq_size)
+                    if self.weight.qtype == NF4:
+                        result = xe_linear.forward_new(x_2d,
+                                                       self.weight.data.view(torch.uint8),
+                                                       self.weight.qtype,
+                                                       input_seq_size)
+                    else:
+                        result = xe_linear.forward_new(x_2d,
+                                                       self.weight.data,
+                                                       self.weight.qtype,
+                                                       input_seq_size)
             elif self.enable_xetla:
                 x_2d = x_2d.half()
                 result = xe_linear.mm_xetla(x_2d, self.weight.data, self.qtype)
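In all of the NF4 branches above, the packed weight bytes now live in a bfloat16 buffer solely to give ZeRO-3 a single parameter dtype, and .view(torch.uint8) reinterprets that buffer without copying before it reaches the xe_linear kernels. A small standalone illustration of the reinterpretation (xe_linear itself is an XPU-only extension and is not called here):

import torch

w = torch.empty(8, dtype=torch.bfloat16)   # stand-in for an NF4 payload stored as bfloat16
raw = w.view(torch.uint8)                  # reinterpret in place: 8 bf16 values -> 16 bytes
assert raw.data_ptr() == w.data_ptr()      # no copy; same underlying storage
assert raw.numel() == 2 * w.numel()        # two bytes per bfloat16 element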

utils.py

@@ -382,3 +382,16 @@ def check_hidden_size(qtype, hidden_size):
                           "required for fq6_k - using fallback quantization fp6.")
         return ggml_tensor_qtype["fp6"]
     return qtype
+
+
+# Arc platform does not support FP64.
+# Disable FP64 in DeepSpeedZeroOptimizer_Stage3's _constant_buffered_norm2 method
+# https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/runtime/zero/stage3.py#L1365
+def _constant_buffered_norm2(self, input, buffer_size=250000000):
+    norm = None
+    for part in input.view(-1).split(buffer_size):
+        if norm is None:
+            norm = part.data.norm(2)**2.0
+        else:
+            norm += part.data.norm(2)**2.0
+    return norm**0.5
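The replacement accumulates squared partial norms chunk by chunk in the tensor's own dtype, so ZeRO-3 never asks for an FP64 scratch buffer, yet the result is mathematically the plain 2-norm. A quick standalone check of that equivalence, using a deliberately small buffer size for illustration:

import torch

x = torch.randn(100_003)
norm = None
for part in x.view(-1).split(4096):        # same chunked accumulation as the helper above
    sq = part.norm(2) ** 2.0
    norm = sq if norm is None else norm + sq
assert torch.allclose(norm ** 0.5, x.norm(2), rtol=1e-4)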