Fixes for xpu Bf16 training (#9156)

* Support bf16 training

* Use a stable transformers version

* remove env

* fix style
Yang Wang 2023-10-15 12:28:59 +08:00 committed by GitHub
parent 51a133de56
commit 7a2de00b48
4 changed files with 10 additions and 4 deletions


@@ -17,7 +17,7 @@ conda activate llm
 # below command will install intel_extension_for_pytorch==2.0.110+xpu as default
 # you can install specific ipex/torch version for your need
 pip install --pre --upgrade bigdl-llm[xpu] -f https://developer.intel.com/ipex-whl-stable-xpu
-pip install git+https://github.com/huggingface/transformers.git@95fe0f5
+pip install transformers==4.34.0
 pip install peft==0.5.0
 ```
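The README now pins released package versions instead of a transformers commit hash. A minimal post-install sanity check (an illustration, not part of the commit) to confirm the pinned versions after running the installs above:

```python
# Hypothetical check: verify the versions pinned in the README diff were installed.
import peft
import transformers

assert transformers.__version__ == "4.34.0", transformers.__version__
assert peft.__version__ == "0.5.0", peft.__version__
print("transformers", transformers.__version__, "| peft", peft.__version__)
```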


@@ -16,8 +16,6 @@
 import torch
 import os
-os.environ["ACCELERATE_USE_IPEX"] = "true"
-os.environ["ACCELERATE_USE_XPU"] = "true"
 import transformers
 from transformers import LlamaTokenizer
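With the ACCELERATE_USE_IPEX / ACCELERATE_USE_XPU overrides removed, the script presumably relies on Accelerate's own device detection. A minimal sketch (assuming an XPU build of intel_extension_for_pytorch and accelerate are installed) to confirm what gets picked up without those variables:

```python
# Sketch only: check device selection without the removed env-var overrides.
import torch
import intel_extension_for_pytorch as ipex  # noqa: F401  (registers the XPU backend)
from accelerate import Accelerator

print("XPU available:", torch.xpu.is_available())   # torch.xpu is provided by IPEX
print("Accelerate device:", Accelerator().device)   # expected e.g. xpu:0 on an Intel GPU
```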


@@ -308,7 +308,7 @@ class MatMulLowBit(torch.autograd.Function):
         A, weight = ctx.tensors
         grad_A, grad_weight = None, None
         if req_gradA:
-            dequant_weight = linear_q4_0.dequant(A, weight.data, weight.qtype)
+            dequant_weight = linear_q4_0.dequant(A, weight.data, weight.qtype).to(grad_output.dtype)
             grad_A = torch.matmul(grad_output, dequant_weight.reshape(weight._shape))
         return grad_A, grad_weight, None
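The `.to(grad_output.dtype)` cast is what makes the backward pass work under bf16: `torch.matmul` requires both operands to share a dtype, so a bf16 `grad_output` against a fp32 dequantized weight would raise. A standalone illustration, with plain tensors standing in for `linear_q4_0.dequant(...)`:

```python
import torch

grad_output = torch.randn(4, 8, dtype=torch.bfloat16)
dequant_weight = torch.randn(8, 16, dtype=torch.float32)  # stand-in for the dequantized weight

try:
    torch.matmul(grad_output, dequant_weight)          # mixed dtypes
except RuntimeError as e:
    print("without the cast:", e)

# Mirrors the fix in the diff: cast the weight to grad_output's dtype first.
grad_A = torch.matmul(grad_output, dequant_weight.to(grad_output.dtype))
print("with the cast:", grad_A.dtype)                  # torch.bfloat16
```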


@@ -188,3 +188,11 @@ class PeftModel:
         LoraModel._create_new_module = old_create_new_module
         return model
+
+
+def patch_prepare_ipex(self, *args):
+    return tuple(args)
+
+# workaround an IPEX bug that prevents resume training in bf16
+from accelerate import Accelerator
+Accelerator._prepare_ipex = patch_prepare_ipex
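The added lines patch Accelerate at the class level: every `Accelerator` created after this module is imported gets a `_prepare_ipex` that simply passes its arguments through, sidestepping the IPEX preparation step that broke resuming bf16 training. A small sketch of the effect, using dummy objects for illustration only:

```python
from accelerate import Accelerator

def patch_prepare_ipex(self, *args):
    return tuple(args)

Accelerator._prepare_ipex = patch_prepare_ipex  # same patch as in the diff above

accelerator = Accelerator()
model, optimizer = object(), object()           # dummy stand-ins
# The patched method hands the objects back untouched instead of re-wrapping them.
assert accelerator._prepare_ipex(model, optimizer) == (model, optimizer)
```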