deepspeed zero3 QLoRA finetuning (#11625)
* deepspeed zero3 QLoRA finetuning, with follow-up updates to alpaca_qlora_finetuning.py, convert.py, low_bit_linear.py, model.py, utils.py, deepspeed_zero3.json, and qlora_finetune_llama2_13b_arch_2_card.sh, plus style fixes
Parent: a184b120c9
Commit: 70c828b87c
5 changed files with 119 additions and 14 deletions
alpaca_qlora_finetuning.py
@@ -144,6 +144,14 @@ def train(
     prompter = Prompter(prompt_template_name)
+    if deepspeed is not None and "zero3" in deepspeed:
+        from ipex_llm.transformers.utils \
+            import _constant_buffered_norm2
+        from ipex_llm.llm_patching import replace_attr
+        import deepspeed as ds
+        replace_attr(ds.runtime.zero.stage3.DeepSpeedZeroOptimizer_Stage3,
+                     "_constant_buffered_norm2", _constant_buffered_norm2)
+
     device_map = "auto"
     world_size = int(os.environ.get("WORLD_SIZE", 1))
     ddp = world_size != 1
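For readers unfamiliar with the patching step above: before the ZeRO-3 optimizer is constructed, the call swaps DeepSpeed's float64-based _constant_buffered_norm2 method for the FP64-free version defined in ipex_llm.transformers.utils (shown in the utils.py hunk at the end of this diff). A minimal sketch of the same idea, using a plain setattr in place of ipex_llm's replace_attr helper and a stand-in class rather than the real DeepSpeed optimizer; names prefixed with an underscore here are illustrative, not part of either library:

import torch

# Stand-in for DeepSpeedZeroOptimizer_Stage3; only the method of interest.
class _Stage3Sketch:
    def _constant_buffered_norm2(self, input, buffer_size=250000000):
        # the upstream method promotes to float64, which Arc GPUs do not
        # support (per the comment in the utils.py hunk below)
        return input.double().norm(2)

def _norm2_without_fp64(self, input, buffer_size=250000000):
    # same quantity, computed in the tensor's own dtype
    return input.norm(2)

# The patch: rebind the method on the class, so every instance created later
# (including ones DeepSpeed builds internally) picks up the new version.
setattr(_Stage3Sketch, "_constant_buffered_norm2", _norm2_without_fp64)

opt = _Stage3Sketch()
print(opt._constant_buffered_norm2(torch.randn(8, dtype=torch.bfloat16)))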
@@ -161,7 +169,7 @@ def train(
             optimize_model=False,
             torch_dtype=torch.bfloat16,
             modules_to_not_convert=["lm_head"],
-            trust_remote_code=True,
+            trust_remote_code=True
         )
     else:
         # According to the QLoRA paper, using "nf4" could yield better model quality than "int4"
@@ -186,6 +194,7 @@ def train(
     # # device_map=device_map,
     # modules_to_not_convert=["lm_head"],
     # )
+    if deepspeed is not None and not "zero3" in deepspeed:
         print(f"Model loaded on rank {os.environ.get('LOCAL_RANK')}")
         model = model.to(f'xpu:{os.environ.get("LOCAL_RANK", 0)}')
         print(f"Model moved to rank {os.environ.get('LOCAL_RANK')}")
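Note the manual move to xpu:{LOCAL_RANK} is now skipped for ZeRO-3 runs, presumably because DeepSpeed stage 3 partitions and places parameters itself once deepspeed.initialize runs. A small sketch of how the target device string is derived from the launcher-provided rank (purely illustrative):

import os

# LOCAL_RANK is set per process by the launcher (mpirun plus the oneCCL
# bindings in the script below); each rank pins its model to its own XPU.
local_rank = os.environ.get("LOCAL_RANK", 0)
device = f"xpu:{local_rank}"
print(device)  # "xpu:0" on the first rank, "xpu:1" on the second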
deepspeed_zero3.json (new file)
@@ -0,0 +1,15 @@
+{
+  "zero_optimization": {
+    "stage": 3,
+    "contiguous_gradients": true,
+    "overlap_comm": true,
+    "offload_optimizer": {"device": "cpu"}
+  },
+  "bf16": {
+    "enabled": true
+  },
+  "world_size": 2,
+  "train_batch_size": 32,
+  "train_micro_batch_size_per_gpu": 2,
+  "gradient_accumulation_steps": 8
+}
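One consistency note on this config: DeepSpeed requires train_batch_size to equal train_micro_batch_size_per_gpu x gradient_accumulation_steps x world_size, which holds here (2 x 8 x 2 = 32) and matches the --micro_batch_size 2 / --batch_size 32 flags in the launch script below. A quick sketch of that check; the file path is an assumption for the example:

import json

# Verify the batch-size relation DeepSpeed enforces at initialization:
# train_batch_size == micro_batch_per_gpu * grad_accum_steps * world_size
with open("deepspeed_zero3.json") as f:  # path assumed relative to cwd
    cfg = json.load(f)

effective = (cfg["train_micro_batch_size_per_gpu"]
             * cfg["gradient_accumulation_steps"]
             * cfg["world_size"])
print(effective, cfg["train_batch_size"])  # 32 32
assert effective == cfg["train_batch_size"]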
qlora_finetune_llama2_13b_arch_2_card.sh (new file)
@@ -0,0 +1,41 @@
+#
+# Copyright 2016 The BigDL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+export MASTER_ADDR=127.0.0.1
+export MASTER_PORT=29503
+export FI_PROVIDER=tcp
+export CCL_ATL_TRANSPORT=ofi
+export CCL_ZE_IPC_EXCHANGE=sockets
+export UR_L0_IN_ORDER_BARRIER_BY_SIGNAL=0
+basekit_root=/opt/intel/oneapi
+source $basekit_root/setvars.sh --force
+source $basekit_root/ccl/latest/env/vars.sh --force
+
+NUM_GPUS=2 # number of used GPU
+export USE_XETLA=OFF
+export SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=2
+export TORCH_LLM_ALLREDUCE=0 # Different from PVC
+export DS_SKIP_CUDA_CHECK=1
+
+mpirun -n $NUM_GPUS \
+    python -u ./alpaca_qlora_finetuning.py \
+    --base_model "meta-llama/Llama-2-13b-hf" \
+    --data_path "yahma/alpaca-cleaned" \
+    --output_dir "./ipex-llm-qlora-alpaca" \
+    --gradient_checkpointing True \
+    --micro_batch_size 2 \
+    --batch_size 32 \
+    --deepspeed ./deepspeed_zero3.json
low_bit_linear.py
@@ -229,6 +229,13 @@ def ggml_convert_qtype(tensor: torch.Tensor, qtype: int,
         dst_tensor = dst_tensor.reshape(tensor.shape[0], tensor.shape[-1] // QK)
         scale = torch.empty(n // k, dtype=torch.float32,
                             device=device)
+    elif qtype == NF4:
+        # Deepspeed zero3 requires unified dtype,
+        # thus here uses bfloat16 consistent to other layers
+        # dst_size above is computed based on uint8, and for bfloat16,
+        # buffer size should be half
+        dst_tensor = torch.empty(dst_size // 2, dtype=torch.bfloat16,
+                                 device=device)
     else:
         dst_tensor = torch.empty(dst_size, dtype=torch.uint8,
                                  device=device)
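The comment block above carries the size arithmetic: dst_size was computed for a uint8 buffer, and a bfloat16 element occupies two bytes, so halving the element count keeps the underlying storage the same size. A small sketch of that equivalence with a made-up dst_size:

import torch

dst_size = 1024  # hypothetical byte count computed for the uint8 layout

buf_u8 = torch.empty(dst_size, dtype=torch.uint8)            # 1024 x 1 byte
buf_bf16 = torch.empty(dst_size // 2, dtype=torch.bfloat16)  # 512 x 2 bytes

# Same number of raw bytes either way.
assert buf_u8.numel() * buf_u8.element_size() == \
       buf_bf16.numel() * buf_bf16.element_size()

# Reinterpreting (not converting) the bfloat16 buffer recovers the uint8 view
# that the quantization kernels expect, matching the .view(torch.uint8) calls
# in the hunks below.
assert buf_bf16.view(torch.uint8).numel() == dst_size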
@@ -260,12 +267,15 @@ def ggml_convert_qtype(tensor: torch.Tensor, qtype: int,


 def ggml_q_format_convet_cpu2xpu(tensor: torch.Tensor, num_elem: int, qtype: int):
-    invalidInputError(tensor.dtype == torch.uint8,
-                      "Input tensor must be uint8")
+    if qtype == NF4:
+        invalidInputError(tensor.dtype == torch.bfloat16,
+                          "NF4 Input tensor must be bfloat16")
+    else:
+        invalidInputError(tensor.dtype == torch.uint8,
+                          "Input tensor except NF4 must be uint8")

     invalidInputError(tensor.device == torch.device('cpu'),
-                      "Input tensor must be uint8")
+                      "Input tensor must be on cpu")

     src = ctypes.c_void_p(tensor.data.data_ptr())
@@ -370,7 +380,6 @@ def use_batch_forward(x: torch.Tensor, qtype: int, output_len: int):


 # Rename to FP4Params to trigger initializing
 # the params layer with all parameters on the CPU
-# https://github.com/huggingface/accelerate/blob/main/src/accelerate/utils/modeling.py#L333
 class FP4Params(torch.nn.Parameter):
     def __new__(cls,
                 data=None,
@@ -582,6 +591,12 @@ class MatMulLowBit(torch.autograd.Function):
     def forward(ctx, A, weight, input_seq_size):
         ctx.is_empty = False
         import xe_linear
-        result = xe_linear.forward_new(A, weight.data, weight.qtype, input_seq_size)
+        if weight.qtype == NF4:
+            result = xe_linear.forward_new(A,
+                                           weight.data.view(torch.uint8),
+                                           weight.qtype,
+                                           input_seq_size)
+        else:
+            result = xe_linear.forward_new(A, weight.data, weight.qtype, input_seq_size)
         if any(ctx.needs_input_grad[:2]):
             ctx.tensors = (A, weight)
@@ -602,6 +617,11 @@ class MatMulLowBit(torch.autograd.Function):
         if req_gradA:
             if torch.xpu.is_autocast_xpu_enabled():
                 grad_output = grad_output.to(torch.xpu.get_autocast_xpu_dtype())
-            dequant_weight = xe_linear.dequant(A, weight.data, weight.qtype)
+            if weight.qtype == NF4:
+                dequant_weight = xe_linear.dequant(A,
+                                                   weight.data.view(torch.uint8),
+                                                   weight.qtype)
+            else:
+                dequant_weight = xe_linear.dequant(A, weight.data, weight.qtype)
             grad_A = torch.matmul(grad_output, dequant_weight.reshape(weight._shape))
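To make the forward/backward pair above easier to follow, here is a minimal autograd.Function with the same structure, with the xe_linear XPU kernels replaced by plain matmuls and an already-dequantized weight standing in for the NF4 data; it is a sketch, not the library's kernels. Only the activations receive a gradient, since the quantized base weight stays frozen under QLoRA:

import torch

class MatMulSketch(torch.autograd.Function):
    @staticmethod
    def forward(ctx, A, dequant_weight):
        # forward uses the (dequantized) weight like a normal linear layer
        ctx.save_for_backward(dequant_weight)
        return A @ dequant_weight.t()

    @staticmethod
    def backward(ctx, grad_output):
        (dequant_weight,) = ctx.saved_tensors
        # mirrors grad_A = grad_output @ dequantized weight above
        grad_A = grad_output @ dequant_weight
        return grad_A, None  # no gradient for the frozen low-bit weight

A = torch.randn(4, 8, requires_grad=True)
W = torch.randn(16, 8)  # stands in for the dequantized NF4 weight
out = MatMulSketch.apply(A, W)
out.sum().backward()
print(A.grad.shape)  # torch.Size([4, 8])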
@@ -737,7 +757,14 @@ class LowBitLinear(nn.Linear):
             if x_2d.requires_grad:
                 result = MatMulLowBit.apply(x_2d, self.weight, input_seq_size)
             else:
-                result = xe_linear.forward_new(x_2d, self.weight.data,
-                                               self.weight.qtype,
-                                               input_seq_size)
+                if self.weight.qtype == NF4:
+                    result = xe_linear.forward_new(x_2d,
+                                                   self.weight.data.view(torch.uint8),
+                                                   self.weight.qtype,
+                                                   input_seq_size)
+                else:
+                    result = xe_linear.forward_new(x_2d,
+                                                   self.weight.data,
+                                                   self.weight.qtype,
+                                                   input_seq_size)
         elif self.enable_xetla:
utils.py
@@ -382,3 +382,16 @@ def check_hidden_size(qtype, hidden_size):
                           "required for fq6_k - using fallback quantization fp6.")
             return ggml_tensor_qtype["fp6"]
     return qtype
+
+
+# Arc platform does not support FP64,
+# Disable FP64 in DeepSpeedZeroOptimizer_Stage3's _constant_buffered_norm2 method
+# https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/runtime/zero/stage3.py#L1365
+def _constant_buffered_norm2(self, input, buffer_size=250000000):
+    norm = None
+    for part in input.view(-1).split(buffer_size):
+        if norm is None:
+            norm = part.data.norm(2)**2.0
+        else:
+            norm += part.data.norm(2)**2.0
+    return norm**0.5
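This replacement computes the same 2-norm as a direct call, just accumulated blockwise and without promoting to float64: ||x||_2 = sqrt(sum_i ||x_i||_2^2) for any split of x into chunks x_i. A quick numerical check of that identity, with a small buffer size chosen only to force several chunks and a loose tolerance to account for bfloat16 precision:

import torch

x = torch.randn(10_000, dtype=torch.bfloat16)

norm = None
for part in x.view(-1).split(1024):  # small buffer_size just for the test
    sq = part.norm(2) ** 2.0
    norm = sq if norm is None else norm + sq
chunked = norm ** 0.5

print(torch.allclose(chunked, x.norm(2), rtol=1e-2))  # True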