Fixes for xpu Bf16 training (#9156)

* Support bf16 training
* Use a stable transformers version
* Remove env settings
* Fix style

parent 51a133de56
commit 7a2de00b48

4 changed files with 10 additions and 4 deletions
The example README swaps a pinned development commit of transformers for the stable 4.34.0 release:

````diff
@@ -17,7 +17,7 @@ conda activate llm
 # below command will install intel_extension_for_pytorch==2.0.110+xpu as default
 # you can install specific ipex/torch version for your need
 pip install --pre --upgrade bigdl-llm[xpu] -f https://developer.intel.com/ipex-whl-stable-xpu
-pip install git+https://github.com/huggingface/transformers.git@95fe0f5
+pip install transformers==4.34.0
 pip install peft==0.5.0
 ```
````
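As a quick post-install sanity check (a minimal sketch, assuming the pins above were installed into the active environment):

```python
# Verify the versions pinned in the README are what actually got installed.
import peft
import transformers

assert transformers.__version__ == "4.34.0", transformers.__version__
assert peft.__version__ == "0.5.0", peft.__version__
print("transformers", transformers.__version__, "| peft", peft.__version__)
```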
The example script drops the Accelerate environment overrides (the "remove env" item from the commit message):

```diff
@@ -16,8 +16,6 @@
 
 import torch
 import os
-os.environ["ACCELERATE_USE_IPEX"] = "true"
-os.environ["ACCELERATE_USE_XPU"] = "true"
 
 import transformers
 from transformers import LlamaTokenizer
```
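With those overrides gone, device selection is left to Accelerate itself. A minimal check (an illustrative sketch, assuming accelerate is installed; the exact device string depends on the machine):

```python
# Accelerate picks the backend on its own once the env overrides are removed.
from accelerate import Accelerator

accelerator = Accelerator()
print(accelerator.device)  # e.g. an XPU device on an IPEX/XPU setup, else cuda/cpu
```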
In MatMulLowBit.backward, the dequantized weight is cast to grad_output's dtype before the gradient matmul, so the bf16 backward pass no longer mixes dtypes:

```diff
@@ -308,7 +308,7 @@ class MatMulLowBit(torch.autograd.Function):
         A, weight = ctx.tensors
         grad_A, grad_weight = None, None
         if req_gradA:
-            dequant_weight = linear_q4_0.dequant(A, weight.data, weight.qtype)
+            dequant_weight = linear_q4_0.dequant(A, weight.data, weight.qtype).to(grad_output.dtype)
             grad_A = torch.matmul(grad_output, dequant_weight.reshape(weight._shape))
 
         return grad_A, grad_weight, None
```
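The pattern behind this fix can be shown in isolation. Below is an illustrative sketch (ToyLowBitMatMul and the fp32 "dequantized" weight are hypothetical stand-ins for MatMulLowBit and linear_q4_0.dequant): in bf16 training, grad_output arrives as bf16 while the dequantized weight comes back as fp32, and torch.matmul rejects mixed dtypes, so the backward must cast the weight first.

```python
import torch

class ToyLowBitMatMul(torch.autograd.Function):
    @staticmethod
    def forward(ctx, A, weight_fp32):
        ctx.save_for_backward(weight_fp32)
        return A @ weight_fp32.t().to(A.dtype)

    @staticmethod
    def backward(ctx, grad_output):
        (weight_fp32,) = ctx.saved_tensors
        # Without .to(grad_output.dtype), bf16 grad_output @ fp32 weight
        # raises a dtype-mismatch error on the matmul.
        grad_A = grad_output @ weight_fp32.to(grad_output.dtype)
        return grad_A, None

A = torch.randn(4, 8, dtype=torch.bfloat16, requires_grad=True)
W = torch.randn(16, 8)  # fp32, as a dequantized weight would be
ToyLowBitMatMul.apply(A, W).sum().backward()
print(A.grad.dtype)  # torch.bfloat16
```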
A module-level monkey-patch replaces Accelerator._prepare_ipex with a pass-through, working around an IPEX bug that prevents resuming training in bf16:

```diff
@@ -188,3 +188,11 @@ class PeftModel:
         LoraModel._create_new_module = old_create_new_module
 
         return model
+
+
+def patch_prepare_ipex(self, *args):
+    return tuple(args)
+
+# workaround an IPEX bug that prevents resume training in bf16
+from accelerate import Accelerator
+Accelerator._prepare_ipex = patch_prepare_ipex
```
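The effect of the patch can be checked directly. A minimal sketch (assumes accelerate is installed; the note about what _prepare_ipex normally does is my reading of Accelerate, not something stated in this commit):

```python
from accelerate import Accelerator


def patch_prepare_ipex(self, *args):
    # Pass everything through untouched instead of applying IPEX optimizations.
    return tuple(args)


# Same monkey-patch as in the diff above.
Accelerator._prepare_ipex = patch_prepare_ipex

accelerator = Accelerator()
model_like, optim_like = object(), object()
# The patched hook now hands back its inputs unchanged, so a bf16 run can
# resume without tripping the IPEX bug the comment refers to.
assert accelerator._prepare_ipex(model_like, optim_like) == (model_like, optim_like)
print("patched _prepare_ipex is a no-op pass-through")
```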