diff --git a/python/llm/example/GPU/LLM-Finetuning/QLoRA/simple-example/README.md b/python/llm/example/GPU/LLM-Finetuning/QLoRA/simple-example/README.md
index 05806233..420359d0 100644
--- a/python/llm/example/GPU/LLM-Finetuning/QLoRA/simple-example/README.md
+++ b/python/llm/example/GPU/LLM-Finetuning/QLoRA/simple-example/README.md
@@ -17,9 +17,9 @@ conda create -n llm python=3.11
 conda activate llm
 # below command will install intel_extension_for_pytorch==2.1.10+xpu as default
 pip install --pre --upgrade ipex-llm[xpu] --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
-pip install transformers==4.36.0 datasets
+pip install transformers==4.45.0 "trl<0.12.0" datasets
 pip install peft==0.10.0
-pip install bitsandbytes scipy
+pip install bitsandbytes==0.45.1 scipy
 ```

 ### 2. Configures OneAPI environment variables
diff --git a/python/llm/example/GPU/LLM-Finetuning/QLoRA/simple-example/qlora_finetuning.py b/python/llm/example/GPU/LLM-Finetuning/QLoRA/simple-example/qlora_finetuning.py
index b8a1fb8e..75479d64 100644
--- a/python/llm/example/GPU/LLM-Finetuning/QLoRA/simple-example/qlora_finetuning.py
+++ b/python/llm/example/GPU/LLM-Finetuning/QLoRA/simple-example/qlora_finetuning.py
@@ -18,7 +18,7 @@
 import torch
 import os
 import transformers
-from transformers import LlamaTokenizer
+from transformers import AutoTokenizer
 from peft import LoraConfig
 from transformers import BitsAndBytesConfig
 from ipex_llm.transformers.qlora import get_peft_model, prepare_model_for_kbit_training
@@ -43,13 +43,13 @@ if __name__ == "__main__":
     args = parser.parse_args()
     model_path = args.repo_id_or_model_path
     dataset_path = args.dataset
-    tokenizer = LlamaTokenizer.from_pretrained(model_path, trust_remote_code=True)
+    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

     if dataset_path.endswith(".json") or dataset_path.endswith(".jsonl"):
         data = load_dataset("json", data_files=dataset_path)
     else:
         data = load_dataset(dataset_path)
-    
+
     # For illustration purpose, only use part of data to train
     data = data["train"].train_test_split(train_size=0.1, shuffle=False)

@@ -57,7 +57,7 @@ if __name__ == "__main__":
     prompter = Prompter("alpaca")
     train_data, _ = get_train_val_data(data, tokenizer, prompter, train_on_inputs=True,
                                        add_eos_token=False, cutoff_len=256, val_set_size=0, seed=42)
-    
+
     bnb_config = BitsAndBytesConfig(
         load_in_4bit=True,
         bnb_4bit_use_double_quant=False,
@@ -79,11 +79,11 @@
     # model.gradient_checkpointing_enable()
     model = prepare_model_for_kbit_training(model)
     config = LoraConfig(
-        r=8,
-        lora_alpha=32,
-        target_modules=["q_proj", "k_proj", "v_proj"],
-        lora_dropout=0.05,
-        bias="none",
+        r=8,
+        lora_alpha=32,
+        target_modules=["q_proj", "k_proj", "v_proj"],
+        lora_dropout=0.05,
+        bias="none",
         task_type="CAUSAL_LM",
     )
     model = get_peft_model(model, config)
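Note: the switch from `LlamaTokenizer` to `AutoTokenizer` above means the example no longer hard-wires a Llama-style sentencepiece tokenizer; `AutoTokenizer` resolves the concrete tokenizer class from the checkpoint's own config. A minimal sketch of the updated loading path (the model id below is a placeholder, not part of the diff):

```python
from transformers import AutoTokenizer

# AutoTokenizer picks the class the checkpoint declares (e.g. a fast Llama
# tokenizer), so the same script works across tokenizer families.
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf",
                                          trust_remote_code=True)
```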
"trl<0.12.0" datasets pip install peft==0.10.0 -pip install bitsandbytes scipy trl==0.9.6 +pip install bitsandbytes==0.45.1 scipy ``` ### 2. Configures OneAPI environment variables @@ -39,13 +39,13 @@ python ./qlora_finetuning.py --repo-id-or-model-path REPO_ID_OR_MODEL_PATH {'loss': 3.1854, 'learning_rate': 1.7777777777777777e-05, 'epoch': 0.03} {'loss': 3.0359, 'learning_rate': 1.555555555555556e-05, 'epoch': 0.05} {'loss': 2.9661, 'learning_rate': 1.3333333333333333e-05, 'epoch': 0.06} -{'loss': 2.7779, 'learning_rate': 1.1111111111111113e-05, 'epoch': 0.08} +{'loss': 2.7779, 'learning_rate': 1.1111111111111113e-05, 'epoch': 0.08} {'loss': 2.7795, 'learning_rate': 8.888888888888888e-06, 'epoch': 0.09} {'loss': 2.5149, 'learning_rate': 6.666666666666667e-06, 'epoch': 0.11} {'loss': 2.5759, 'learning_rate': 4.444444444444444e-06, 'epoch': 0.12} {'loss': 2.5976, 'learning_rate': 2.222222222222222e-06, 'epoch': 0.14} {'loss': 2.5744, 'learning_rate': 0.0, 'epoch': 0.15} -{'train_runtime': 116.1914, 'train_samples_per_second': 6.885, 'train_steps_per_second': 1.721, 'train_loss': 2.819730052947998, 'epoch': 0.15} +{'train_runtime': 116.1914, 'train_samples_per_second': 6.885, 'train_steps_per_second': 1.721, 'train_loss': 2.819730052947998, 'epoch': 0.15} 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [01:56<00:00, 1.72it/s] TrainOutput(global_step=200, training_loss=2.819730052947998, metrics={'train_runtime': 116.1914, 'train_samples_per_second': 6.885, 'train_steps_per_second': 1.721, 'train_loss': 2.819730052947998, 'epoch': 0.15}) ``` diff --git a/python/llm/example/GPU/LLM-Finetuning/QLoRA/trl-example/qlora_finetuning.py b/python/llm/example/GPU/LLM-Finetuning/QLoRA/trl-example/qlora_finetuning.py index 4d1fb72c..c8ee8947 100644 --- a/python/llm/example/GPU/LLM-Finetuning/QLoRA/trl-example/qlora_finetuning.py +++ b/python/llm/example/GPU/LLM-Finetuning/QLoRA/trl-example/qlora_finetuning.py @@ -18,7 +18,7 @@ import torch import os import transformers -from transformers import LlamaTokenizer +from transformers import AutoTokenizer from peft import LoraConfig from transformers import BitsAndBytesConfig from ipex_llm.transformers.qlora import get_peft_model, prepare_model_for_kbit_training @@ -44,7 +44,7 @@ if __name__ == "__main__": args = parser.parse_args() model_path = args.repo_id_or_model_path dataset_path = args.dataset - tokenizer = LlamaTokenizer.from_pretrained(model_path, trust_remote_code=True) + tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) # Avoid tokenizer doesn't have a padding token if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token @@ -53,7 +53,7 @@ if __name__ == "__main__": data = load_dataset("json", data_files=dataset_path) else: data = load_dataset(dataset_path) - + # For illustration purpose, only use part of data to train data = data["train"].train_test_split(train_size=0.1, shuffle=False) @@ -82,11 +82,11 @@ if __name__ == "__main__": # it will slowdown the training speed model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=True) config = LoraConfig( - r=8, - lora_alpha=32, - target_modules=["q_proj", "k_proj", "v_proj"], - lora_dropout=0.05, - bias="none", + r=8, + lora_alpha=32, + target_modules=["q_proj", "k_proj", "v_proj"], + lora_dropout=0.05, + bias="none", task_type="CAUSAL_LM", ) model = get_peft_model(model, config) diff --git 
diff --git a/python/llm/src/ipex_llm/transformers/low_bit_linear.py b/python/llm/src/ipex_llm/transformers/low_bit_linear.py
index 92d82f6a..1dd99684 100644
--- a/python/llm/src/ipex_llm/transformers/low_bit_linear.py
+++ b/python/llm/src/ipex_llm/transformers/low_bit_linear.py
@@ -51,7 +51,8 @@ from torch import Tensor, dtype, nn
 from operator import mul
 from functools import reduce
 from ipex_llm.transformers.xpu_customize_fwd import custom_fwd, custom_bwd
-from ipex_llm.transformers.utils import get_autocast_dtype, get_xpu_device_name
+from ipex_llm.transformers.utils import is_autocast_enabled, get_autocast_dtype
+from ipex_llm.transformers.utils import get_xpu_device_name
 from ipex_llm.transformers.convert import is_deepspeed_available, get_use_vllm

 T = TypeVar("T", bound="torch.nn.Module")
@@ -527,8 +528,8 @@ class MatMulLowBit(torch.autograd.Function):
         A, weight = ctx.tensors
         grad_A, grad_weight = None, None
         if req_gradA:
-            if torch.xpu.is_autocast_xpu_enabled():
-                grad_output = grad_output.to(torch.xpu.get_autocast_xpu_dtype())
+            if is_autocast_enabled("xpu"):
+                grad_output = grad_output.to(get_autocast_dtype("xpu"))
             if weight.qtype == NF4:
                 dequant_weight = xe_linear.dequant(A,
                                                    weight.data.view(torch.uint8),
@@ -615,7 +616,7 @@
         is_training = self.training and not torch.is_inference_mode_enabled()
         if is_training:
             # below logic is only for training
-            autocast_dtype = get_autocast_dtype(x)
+            autocast_dtype = get_autocast_dtype(x.device.type)
             if self.compute_dtype is not None and x.device.type == "xpu":
                 x = x.to(self.compute_dtype)  # solve GC issue for unlora module
             elif autocast_dtype is not None:
diff --git a/python/llm/src/ipex_llm/transformers/qlora.py b/python/llm/src/ipex_llm/transformers/qlora.py
index 1af0cf18..d6e32eec 100644
--- a/python/llm/src/ipex_llm/transformers/qlora.py
+++ b/python/llm/src/ipex_llm/transformers/qlora.py
@@ -109,7 +109,7 @@ class LoraLowBitLinear(Module, LoraLayer):
             self.qa_pool = torch.nn.Identity()

     def forward(self, x: torch.Tensor):
-        autocast_dtype = get_autocast_dtype(x)
+        autocast_dtype = get_autocast_dtype(x.device.type)
         if x.device.type == "xpu":
             # force to use bf16 on gpu
             x = x.to(torch.bfloat16)
@@ -177,7 +177,7 @@
         self.is_target_conv_1d_layer = is_target_conv_1d_layer

     def forward(self, x: torch.Tensor):
-        autocast_dtype = get_autocast_dtype(x)
+        autocast_dtype = get_autocast_dtype(x.device.type)
         if x.device.type == "xpu":
             # force to use bf16 on gpu
             x = x.to(torch.bfloat16)
diff --git a/python/llm/src/ipex_llm/transformers/utils.py b/python/llm/src/ipex_llm/transformers/utils.py
index 5bd24667..3f351440 100644
--- a/python/llm/src/ipex_llm/transformers/utils.py
+++ b/python/llm/src/ipex_llm/transformers/utils.py
@@ -138,26 +138,39 @@ def fix_key(key):
     return key


-def get_autocast_dtype(x):
+def is_autocast_enabled(device_type: str):
     if torch.__version__ >= '2.3':
-        if torch.is_autocast_enabled(x.device.type):
-            return torch.get_autocast_dtype(x.device.type)
+        return torch.is_autocast_enabled(device_type)
+    else:
+        if device_type == "xpu":
+            return torch.xpu.is_autocast_xpu_enabled()
+        elif device_type == "cpu":
+            return torch.is_autocast_cpu_enabled()
+        else:
+            invalidInputError(False,
+                              f"Device type {device_type} is not supported.")
+
+
+def get_autocast_dtype(device_type: str):
+    if torch.__version__ >= '2.3':
+        if torch.is_autocast_enabled(device_type):
+            return torch.get_autocast_dtype(device_type)
         else:
             return None
     else:
-        if x.device.type == "xpu":
+        if device_type == "xpu":
             if torch.xpu.is_autocast_xpu_enabled():
                 return torch.xpu.get_autocast_xpu_dtype()
             else:
                 return None
-        elif x.device.type == "cpu":
+        elif device_type == "cpu":
             if torch.is_autocast_cpu_enabled():
                 return torch.get_autocast_cpu_dtype()
             else:
                 return None
         else:
             invalidInputError(False,
-                              f"Device {x.device} is not supported.")
+                              f"Device type {device_type} is not supported.")


 def get_xpu_device_name(device: torch.device):
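With this refactor, both helpers key off a device-type string instead of a tensor, so call sites pass `x.device.type`; on torch >= 2.3 they route through the unified `torch.is_autocast_enabled` / `torch.get_autocast_dtype` API, and on older builds they fall back to the per-device `torch.xpu.*` and `torch.*_cpu_*` calls. A CPU-only sketch of the observable behavior, assuming the ipex-llm environment from the READMEs above:

```python
import torch
from ipex_llm.transformers.utils import is_autocast_enabled, get_autocast_dtype

print(is_autocast_enabled("cpu"))      # False outside an autocast region
print(get_autocast_dtype("cpu"))       # None outside an autocast region

with torch.autocast("cpu", dtype=torch.bfloat16):
    print(is_autocast_enabled("cpu"))  # True
    print(get_autocast_dtype("cpu"))   # torch.bfloat16
```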
diff --git a/python/llm/src/ipex_llm/transformers/xpu_customize_fwd.py b/python/llm/src/ipex_llm/transformers/xpu_customize_fwd.py
index bfed60d4..45cfd504 100644
--- a/python/llm/src/ipex_llm/transformers/xpu_customize_fwd.py
+++ b/python/llm/src/ipex_llm/transformers/xpu_customize_fwd.py
@@ -107,6 +107,8 @@ except ModuleNotFoundError:
     np = None  # type: ignore[assignment]
 from typing import Any

+from ipex_llm.transformers.utils import is_autocast_enabled, get_autocast_dtype
+

 def _cast(value, dtype):
     if isinstance(value, torch.Tensor):
@@ -155,12 +157,12 @@ def custom_fwd(fwd=None, *, cast_inputs=None):

         @functools.wraps(fwd)
         def decorate_fwd(*args, **kwargs):
-            args[0]._dtype = torch.xpu.get_autocast_xpu_dtype()
+            args[0]._dtype = get_autocast_dtype("xpu")
             if cast_inputs is None:
-                args[0]._fwd_used_autocast = torch.xpu.is_autocast_xpu_enabled()
+                args[0]._fwd_used_autocast = is_autocast_enabled("xpu")
                 return fwd(*args, **kwargs)
             else:
-                autocast_context = torch.xpu.is_autocast_xpu_enabled()
+                autocast_context = is_autocast_enabled("xpu")
                 args[0]._fwd_used_autocast = False
                 if autocast_context:
                     with torch.xpu.autocast(enabled=False):
@@ -184,7 +186,7 @@ def custom_bwd(bwd):

     @functools.wraps(bwd)
     def decorate_bwd(*args, **kwargs):
-        with torch.xpu.autocast(enabled=args[0]._fwd_used_autocast, dtype=args[0]._dtype):
+        with torch.autocast("xpu", enabled=args[0]._fwd_used_autocast, dtype=args[0]._dtype):
             return bwd(*args, **kwargs)

     return decorate_bwd
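The `custom_bwd` change swaps the XPU-specific `torch.xpu.autocast(...)` context manager for the portable `torch.autocast("xpu", ...)` spelling. For reference, a toy autograd `Function` showing how the patched decorators are meant to be applied (illustrative only; the class and names below are not from the diff, and running it requires an XPU-enabled PyTorch build):

```python
import torch
from ipex_llm.transformers.xpu_customize_fwd import custom_fwd, custom_bwd

class ScaledMatMul(torch.autograd.Function):
    @staticmethod
    @custom_fwd  # records the ambient autocast state on ctx (args[0])
    def forward(ctx, a, b):
        ctx.save_for_backward(a, b)
        return a @ b

    @staticmethod
    @custom_bwd  # replays that state via torch.autocast("xpu", ...) around backward
    def backward(ctx, grad_out):
        a, b = ctx.saved_tensors
        # d(a @ b)/da = grad_out @ b^T, d(a @ b)/db = a^T @ grad_out
        return grad_out @ b.t(), a.t() @ grad_out
```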