fix qlora finetune example (#12769)

Yishuo Wang 2025-02-06 11:18:28 +08:00 committed by GitHub
parent 094a25b740
commit 9697197f3e
8 changed files with 55 additions and 39 deletions

@@ -17,9 +17,9 @@ conda create -n llm python=3.11
conda activate llm
# below command will install intel_extension_for_pytorch==2.1.10+xpu as default
pip install --pre --upgrade ipex-llm[xpu] --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
-pip install transformers==4.36.0 datasets
+pip install transformers==4.45.0 "trl<0.12.0" datasets
pip install peft==0.10.0
-pip install bitsandbytes scipy
+pip install bitsandbytes==0.45.1 scipy
```
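
To confirm the upgraded pins are the ones actually active in the `llm` environment, a quick check can be run before moving on (a minimal sketch; it only reads installed package metadata):

```python
# Sanity-check the pinned packages from the install step above.
from importlib.metadata import PackageNotFoundError, version

for pkg in ("transformers", "trl", "peft", "bitsandbytes", "datasets"):
    try:
        print(f"{pkg}: {version(pkg)}")
    except PackageNotFoundError:
        print(f"{pkg}: not installed")
```
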
### 2. Configures OneAPI environment variables

@@ -18,7 +18,7 @@ import torch
import os
import transformers
-from transformers import LlamaTokenizer
+from transformers import AutoTokenizer
from peft import LoraConfig
from transformers import BitsAndBytesConfig
from ipex_llm.transformers.qlora import get_peft_model, prepare_model_for_kbit_training
@@ -43,13 +43,13 @@ if __name__ == "__main__":
args = parser.parse_args()
model_path = args.repo_id_or_model_path
dataset_path = args.dataset
-tokenizer = LlamaTokenizer.from_pretrained(model_path, trust_remote_code=True)
+tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
if dataset_path.endswith(".json") or dataset_path.endswith(".jsonl"):
data = load_dataset("json", data_files=dataset_path)
else:
data = load_dataset(dataset_path)
# For illustration purpose, only use part of data to train
data = data["train"].train_test_split(train_size=0.1, shuffle=False)
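
The move from `LlamaTokenizer` to `AutoTokenizer` lets the same script work with checkpoints whose config declares a different tokenizer class. A standalone sketch of the pattern (the repo id below is only an illustrative placeholder; the pad-token fallback mirrors the GPU example later in this commit):

```python
from transformers import AutoTokenizer

# Placeholder checkpoint: any causal-LM repo id or local path works here.
model_path = "meta-llama/Llama-2-7b-hf"

# AutoTokenizer resolves the concrete tokenizer class from the checkpoint config.
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

# Llama-style tokenizers ship without a padding token; reuse EOS so batching works.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
```
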
@@ -57,7 +57,7 @@ if __name__ == "__main__":
prompter = Prompter("alpaca")
train_data, _ = get_train_val_data(data, tokenizer, prompter, train_on_inputs=True,
add_eos_token=False, cutoff_len=256, val_set_size=0, seed=42)
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=False,
@@ -79,11 +79,11 @@ if __name__ == "__main__":
# model.gradient_checkpointing_enable()
model = prepare_model_for_kbit_training(model)
config = LoraConfig(
r=8,
lora_alpha=32,
target_modules=["q_proj", "k_proj", "v_proj"],
lora_dropout=0.05,
bias="none",
task_type="CAUSAL_LM",
)
model = get_peft_model(model, config)
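
To see which modules a `target_modules` list like the one above actually picks up, `print_trainable_parameters` can be run on any module exposing `q_proj`/`k_proj`/`v_proj`. A toy sketch using upstream `peft` directly (the `ToyAttention` module is invented purely for illustration and is not part of this example):

```python
import torch
from peft import LoraConfig, get_peft_model

class ToyAttention(torch.nn.Module):
    """Invented module with attention-style projection names, for illustration only."""
    def __init__(self, dim: int = 64):
        super().__init__()
        self.q_proj = torch.nn.Linear(dim, dim)
        self.k_proj = torch.nn.Linear(dim, dim)
        self.v_proj = torch.nn.Linear(dim, dim)
        self.o_proj = torch.nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        attn = torch.softmax(self.q_proj(x) @ self.k_proj(x).transpose(-1, -2), dim=-1)
        return self.o_proj(attn @ self.v_proj(x))

config = LoraConfig(r=8, lora_alpha=32,
                    target_modules=["q_proj", "k_proj", "v_proj"],
                    lora_dropout=0.05, bias="none")
peft_model = get_peft_model(ToyAttention(), config)
# Only the injected LoRA A/B matrices on q/k/v stay trainable; o_proj is frozen.
peft_model.print_trainable_parameters()
```
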

@@ -17,9 +17,9 @@ conda create -n llm python=3.11
conda activate llm
# below command will install intel_extension_for_pytorch==2.1.10+xpu as default
pip install --pre --upgrade ipex-llm[xpu] --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
-pip install transformers==4.36.0 datasets
+pip install transformers==4.45.0 "trl<0.12.0" datasets
pip install peft==0.10.0
-pip install bitsandbytes scipy trl==0.9.6
+pip install bitsandbytes==0.45.1 scipy
```
### 2. Configures OneAPI environment variables
@@ -39,13 +39,13 @@ python ./qlora_finetuning.py --repo-id-or-model-path REPO_ID_OR_MODEL_PATH
{'loss': 3.1854, 'learning_rate': 1.7777777777777777e-05, 'epoch': 0.03}
{'loss': 3.0359, 'learning_rate': 1.555555555555556e-05, 'epoch': 0.05}
{'loss': 2.9661, 'learning_rate': 1.3333333333333333e-05, 'epoch': 0.06}
{'loss': 2.7779, 'learning_rate': 1.1111111111111113e-05, 'epoch': 0.08}
{'loss': 2.7795, 'learning_rate': 8.888888888888888e-06, 'epoch': 0.09}
{'loss': 2.5149, 'learning_rate': 6.666666666666667e-06, 'epoch': 0.11}
{'loss': 2.5759, 'learning_rate': 4.444444444444444e-06, 'epoch': 0.12}
{'loss': 2.5976, 'learning_rate': 2.222222222222222e-06, 'epoch': 0.14}
{'loss': 2.5744, 'learning_rate': 0.0, 'epoch': 0.15}
{'train_runtime': 116.1914, 'train_samples_per_second': 6.885, 'train_steps_per_second': 1.721, 'train_loss': 2.819730052947998, 'epoch': 0.15}
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [01:56<00:00, 1.72it/s]
TrainOutput(global_step=200, training_loss=2.819730052947998, metrics={'train_runtime': 116.1914, 'train_samples_per_second': 6.885, 'train_steps_per_second': 1.721, 'train_loss': 2.819730052947998, 'epoch': 0.15})
```
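
Once a run like the one above finishes, the saved LoRA adapter can be loaded back onto the base model with `peft`. A rough sketch, assuming a Llama-2 base checkpoint and a checkpoint directory under the trainer's `output_dir` (both paths are placeholders):

```python
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

base_model_path = "meta-llama/Llama-2-7b-hf"   # placeholder base checkpoint
adapter_path = "./outputs/checkpoint-200"      # placeholder; depends on output_dir

base = AutoModelForCausalLM.from_pretrained(base_model_path, torch_dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained(base_model_path)

# Attach the trained LoRA weights, then optionally fold them into the base weights.
model = PeftModel.from_pretrained(base, adapter_path)
model = model.merge_and_unload()
```
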

@@ -18,7 +18,7 @@ import torch
import os
import transformers
-from transformers import LlamaTokenizer
+from transformers import AutoTokenizer
from peft import LoraConfig
from transformers import BitsAndBytesConfig
from ipex_llm.transformers.qlora import get_peft_model, prepare_model_for_kbit_training
@@ -44,7 +44,7 @@ if __name__ == "__main__":
args = parser.parse_args()
model_path = args.repo_id_or_model_path
dataset_path = args.dataset
-tokenizer = LlamaTokenizer.from_pretrained(model_path, trust_remote_code=True)
+tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
# Avoid tokenizer doesn't have a padding token
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
@@ -53,7 +53,7 @@ if __name__ == "__main__":
data = load_dataset("json", data_files=dataset_path)
else:
data = load_dataset(dataset_path)
# For illustration purpose, only use part of data to train
data = data["train"].train_test_split(train_size=0.1, shuffle=False)
@@ -82,11 +82,11 @@ if __name__ == "__main__":
# it will slowdown the training speed
model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=True)
config = LoraConfig(
r=8,
lora_alpha=32,
target_modules=["q_proj", "k_proj", "v_proj"],
lora_dropout=0.05,
bias="none",
task_type="CAUSAL_LM",
)
model = get_peft_model(model, config)

@@ -51,7 +51,8 @@ from torch import Tensor, dtype, nn
from operator import mul
from functools import reduce
from ipex_llm.transformers.xpu_customize_fwd import custom_fwd, custom_bwd
-from ipex_llm.transformers.utils import get_autocast_dtype, get_xpu_device_name
+from ipex_llm.transformers.utils import is_autocast_enabled, get_autocast_dtype
+from ipex_llm.transformers.utils import get_xpu_device_name
from ipex_llm.transformers.convert import is_deepspeed_available, get_use_vllm
T = TypeVar("T", bound="torch.nn.Module")
@@ -527,8 +528,8 @@ class MatMulLowBit(torch.autograd.Function):
A, weight = ctx.tensors
grad_A, grad_weight = None, None
if req_gradA:
-if torch.xpu.is_autocast_xpu_enabled():
-grad_output = grad_output.to(torch.xpu.get_autocast_xpu_dtype())
+if is_autocast_enabled("xpu"):
+grad_output = grad_output.to(get_autocast_dtype("xpu"))
if weight.qtype == NF4:
dequant_weight = xe_linear.dequant(A,
weight.data.view(torch.uint8),
@@ -615,7 +616,7 @@ class LowBitLinear(nn.Linear):
is_training = self.training and not torch.is_inference_mode_enabled()
if is_training:
# below logic is only for training
-autocast_dtype = get_autocast_dtype(x)
+autocast_dtype = get_autocast_dtype(x.device.type)
if self.compute_dtype is not None and x.device.type == "xpu":
x = x.to(self.compute_dtype) # solve GC issue for unlora module
elif autocast_dtype is not None:

@@ -109,7 +109,7 @@ class LoraLowBitLinear(Module, LoraLayer):
self.qa_pool = torch.nn.Identity()
def forward(self, x: torch.Tensor):
-autocast_dtype = get_autocast_dtype(x)
+autocast_dtype = get_autocast_dtype(x.device.type)
if x.device.type == "xpu":
# force to use bf16 on gpu
x = x.to(torch.bfloat16)
@@ -177,7 +177,7 @@ class LoraBF16Linear(Module, LoraLayer):
self.is_target_conv_1d_layer = is_target_conv_1d_layer
def forward(self, x: torch.Tensor):
-autocast_dtype = get_autocast_dtype(x)
+autocast_dtype = get_autocast_dtype(x.device.type)
if x.device.type == "xpu":
# force to use bf16 on gpu
x = x.to(torch.bfloat16)

@@ -138,26 +138,39 @@ def fix_key(key):
return key
-def get_autocast_dtype(x):
+def is_autocast_enabled(device_type: str):
if torch.__version__ >= '2.3':
-if torch.is_autocast_enabled(x.device.type):
-return torch.get_autocast_dtype(x.device.type)
+return torch.is_autocast_enabled(device_type)
else:
+if device_type == "xpu":
+return torch.xpu.is_autocast_xpu_enabled()
+elif device_type == "cpu":
+return torch.is_autocast_cpu_enabled()
+else:
+invalidInputError(False,
+f"Device type {device_type} is not supported.")
+def get_autocast_dtype(device_type: str):
+if torch.__version__ >= '2.3':
+if torch.is_autocast_enabled(device_type):
+return torch.get_autocast_dtype(device_type)
else:
return None
else:
-if x.device.type == "xpu":
+if device_type == "xpu":
if torch.xpu.is_autocast_xpu_enabled():
return torch.xpu.get_autocast_xpu_dtype()
else:
return None
-elif x.device.type == "cpu":
+elif device_type == "cpu":
if torch.is_autocast_cpu_enabled():
return torch.get_autocast_cpu_dtype()
else:
return None
else:
invalidInputError(False,
-f"Device {x.device} is not supported.")
+f"Device type {device_type} is not supported.")
def get_xpu_device_name(device: torch.device):
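
The refactor splits the old tensor-based helper into two helpers keyed by device type, so callers can query autocast state without a tensor in hand on both old and new torch versions. A small CPU-side sketch of the intended behaviour (the expected values in the comments follow from the branches above):

```python
import torch
from ipex_llm.transformers.utils import is_autocast_enabled, get_autocast_dtype

# Outside any autocast region the helpers report disabled / no dtype.
print(is_autocast_enabled("cpu"), get_autocast_dtype("cpu"))      # False None

# Inside an autocast region they report the active dtype, on either side
# of the torch 2.3 version check above.
with torch.autocast("cpu", dtype=torch.bfloat16):
    print(is_autocast_enabled("cpu"), get_autocast_dtype("cpu"))  # True torch.bfloat16
```
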

@@ -107,6 +107,8 @@ except ModuleNotFoundError:
np = None # type: ignore[assignment]
from typing import Any
+from ipex_llm.transformers.utils import is_autocast_enabled, get_autocast_dtype
def _cast(value, dtype):
if isinstance(value, torch.Tensor):
@@ -155,12 +157,12 @@ def custom_fwd(fwd=None, *, cast_inputs=None):
@functools.wraps(fwd)
def decorate_fwd(*args, **kwargs):
-args[0]._dtype = torch.xpu.get_autocast_xpu_dtype()
+args[0]._dtype = get_autocast_dtype("xpu")
if cast_inputs is None:
-args[0]._fwd_used_autocast = torch.xpu.is_autocast_xpu_enabled()
+args[0]._fwd_used_autocast = is_autocast_enabled("xpu")
return fwd(*args, **kwargs)
else:
-autocast_context = torch.xpu.is_autocast_xpu_enabled()
+autocast_context = is_autocast_enabled("xpu")
args[0]._fwd_used_autocast = False
if autocast_context:
with torch.xpu.autocast(enabled=False):
@@ -184,7 +186,7 @@ def custom_bwd(bwd):
@functools.wraps(bwd)
def decorate_bwd(*args, **kwargs):
-with torch.xpu.autocast(enabled=args[0]._fwd_used_autocast, dtype=args[0]._dtype):
+with torch.autocast("xpu", enabled=args[0]._fwd_used_autocast, dtype=args[0]._dtype):
return bwd(*args, **kwargs)
return decorate_bwd
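
For context on how these decorators are consumed, here is a toy `torch.autograd.Function` wired up the same way as `MatMulLowBit` above: `custom_fwd` records the ambient XPU autocast state on `ctx`, and `custom_bwd` replays it around the backward body. This is only an illustrative sketch that assumes an XPU-enabled PyTorch/IPEX environment; the matmul itself is ordinary, not low-bit.

```python
import torch
from ipex_llm.transformers.utils import is_autocast_enabled, get_autocast_dtype
from ipex_llm.transformers.xpu_customize_fwd import custom_bwd, custom_fwd

class ToyMatMul(torch.autograd.Function):
    """Illustrative only: a plain linear-style matmul decorated like MatMulLowBit."""

    @staticmethod
    @custom_fwd
    def forward(ctx, A, weight):
        # custom_fwd has already stashed the ambient autocast state/dtype on ctx.
        ctx.save_for_backward(A, weight)
        return A @ weight.t()

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_output):
        # custom_bwd re-enters the recorded autocast state around this body.
        A, weight = ctx.saved_tensors
        if is_autocast_enabled("xpu"):
            # Same cast as the MatMulLowBit fix: keep the incoming gradient in
            # the active autocast dtype before forming the input/weight grads.
            grad_output = grad_output.to(get_autocast_dtype("xpu"))
        grad_A = grad_output @ weight
        grad_weight = grad_output.t() @ A
        return grad_A, grad_weight

# Usage (on an XPU device): out = ToyMatMul.apply(A, w)
```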