diff --git a/python/llm/example/GPU/QLoRA-FineTuning/README.md b/python/llm/example/GPU/QLoRA-FineTuning/README.md
index 7b98b1b6..237c642d 100644
--- a/python/llm/example/GPU/QLoRA-FineTuning/README.md
+++ b/python/llm/example/GPU/QLoRA-FineTuning/README.md
@@ -17,7 +17,7 @@ conda activate llm
 # below command will install intel_extension_for_pytorch==2.0.110+xpu as default
 # you can install specific ipex/torch version for your need
 pip install --pre --upgrade bigdl-llm[xpu] -f https://developer.intel.com/ipex-whl-stable-xpu
-pip install git+https://github.com/huggingface/transformers.git@95fe0f5
+pip install transformers==4.34.0
 pip install peft==0.5.0
 ```

diff --git a/python/llm/example/GPU/QLoRA-FineTuning/qlora_finetuning.py b/python/llm/example/GPU/QLoRA-FineTuning/qlora_finetuning.py
index 85b5642e..9c8e2d1f 100644
--- a/python/llm/example/GPU/QLoRA-FineTuning/qlora_finetuning.py
+++ b/python/llm/example/GPU/QLoRA-FineTuning/qlora_finetuning.py
@@ -16,8 +16,6 @@

 import torch
 import os
-os.environ["ACCELERATE_USE_IPEX"] = "true"
-os.environ["ACCELERATE_USE_XPU"] = "true"

 import transformers
 from transformers import LlamaTokenizer
diff --git a/python/llm/src/bigdl/llm/transformers/low_bit_linear.py b/python/llm/src/bigdl/llm/transformers/low_bit_linear.py
index 0e3b7dba..d9a0ffe2 100644
--- a/python/llm/src/bigdl/llm/transformers/low_bit_linear.py
+++ b/python/llm/src/bigdl/llm/transformers/low_bit_linear.py
@@ -308,7 +308,7 @@ class MatMulLowBit(torch.autograd.Function):
         A, weight = ctx.tensors
         grad_A, grad_weight = None, None
         if req_gradA:
-            dequant_weight = linear_q4_0.dequant(A, weight.data, weight.qtype)
+            dequant_weight = linear_q4_0.dequant(A, weight.data, weight.qtype).to(grad_output.dtype)
             grad_A = torch.matmul(grad_output, dequant_weight.reshape(weight._shape))

         return grad_A, grad_weight, None
diff --git a/python/llm/src/bigdl/llm/transformers/qlora.py b/python/llm/src/bigdl/llm/transformers/qlora.py
index 2b074105..a52abd94 100644
--- a/python/llm/src/bigdl/llm/transformers/qlora.py
+++ b/python/llm/src/bigdl/llm/transformers/qlora.py
@@ -188,3 +188,11 @@ class PeftModel:
             LoraModel._create_new_module = old_create_new_module

         return model
+
+
+def patch_prepare_ipex(self, *args):
+    return tuple(args)
+
+# workaround an IPEX bug that prevents resuming training in bf16
+from accelerate import Accelerator
+Accelerator._prepare_ipex = patch_prepare_ipex
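
Note on the backward-pass change in `low_bit_linear.py`: under bf16 fine-tuning, `grad_output` arrives in `torch.bfloat16`, while the dequantized weight may come back in a different floating dtype, and `torch.matmul` does not accept mixed float dtypes. The snippet below is a minimal standalone sketch of that failure mode; the shapes are hypothetical and a plain fp32 tensor stands in for the result of `linear_q4_0.dequant(...)`, so it only illustrates why the `.to(grad_output.dtype)` cast in the patch is needed.

```python
import torch

# Hypothetical shapes; a dense fp32 tensor stands in for the dequantized weight,
# which is not guaranteed to share grad_output's dtype during bf16 training.
grad_output = torch.randn(4, 8, dtype=torch.bfloat16)
dequant_weight = torch.randn(8, 16, dtype=torch.float32)

try:
    torch.matmul(grad_output, dequant_weight)  # mismatched dtypes raise a RuntimeError
except RuntimeError as e:
    print("without cast:", e)

# Casting the dequantized weight to grad_output.dtype (as the patched line does)
# makes the gradient matmul valid and keeps grad_A in bf16.
grad_A = torch.matmul(grad_output, dequant_weight.to(grad_output.dtype))
print("with cast:", grad_A.dtype)  # torch.bfloat16
```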