LLM: fix BF16Linear related training & inference issue (#9755)
* fix bf16 related issue
* fix
* update based on comment & add arc lora script
* update readme
* update based on comment
* update based on comment
* update
* force to bf16
* fix style
* move check input dtype into function
* update convert
* meet code review
* meet code review
* update merged model to support new training_mode api
* fix typo
Parent: 30dab36f76
Commit: 1917bbe626

8 changed files with 75 additions and 20 deletions
@@ -101,6 +101,12 @@ bash qalora_finetune_llama2_7b_pvc_1550_1_tile.sh
 
 #### LoRA
 
+##### Finetuning LLaMA2-7B on single Arc A770
+
+```bash
+bash lora_finetune_llama2_7b_arc_1_card.sh
+```
+
 ##### Finetuning LLaMA2-7B on four Intel Data Center GPU Max 1100
 
 ```bash

@@ -0,0 +1,26 @@
+#
+# Copyright 2016 The BigDL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# You could also specify `--base_model` to the local path of the huggingface model checkpoint folder and `--data_path` to the local path of the dataset JSON file
+python ./alpaca_qlora_finetuning.py \
+    --micro_batch_size 8 \
+    --batch_size 128 \
+    --base_model "meta-llama/Llama-2-7b-hf" \
+    --data_path "yahma/alpaca-cleaned" \
+    --output_dir "./bigdl-lora-alpaca" \
+    --gradient_checkpointing True \
+    --lora_target_modules "['k_proj', 'q_proj', 'o_proj', 'v_proj']" \
+    --training_mode "lora"

@@ -54,7 +54,8 @@ if __name__ == "__main__":
     tokenizer = LlamaTokenizer.from_pretrained(base_model)
 
     lora_config = LoraConfig.from_json_file(os.path.join(adapter_path, "adapter_config.json"))
-    qa_lora = lora_config.get("qa_lora", False)
+    training_mode = lora_config.get("training_mode", "qlora")
+    qa_lora = training_mode == "qalora"
 
     temp_dir = None
     if qa_lora:

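The merge/export script now reads the adapter's `training_mode` field instead of a boolean `qa_lora` flag. As a quick illustration of the new API (a minimal sketch with made-up config dicts, not code from the commit), the three training modes map onto `qa_lora` like this:

```python
# Minimal sketch: how training_mode in adapter_config.json drives the qa_lora
# flag used by the merge script. The config dicts below are illustrative only.
sample_configs = [
    {"training_mode": "qlora"},   # low-bit QLoRA adapter (default)
    {"training_mode": "qalora"},  # QA-LoRA adapter -> takes the qa_lora merge path
    {"training_mode": "lora"},    # plain bf16 LoRA adapter added by this commit
]

for lora_config in sample_configs:
    training_mode = lora_config.get("training_mode", "qlora")
    qa_lora = training_mode == "qalora"
    print(training_mode, "-> qa_lora =", qa_lora)
```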
@@ -262,6 +262,7 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                     new_linear._parameters['bias'] = nn.Parameter(module.bias.data)\
                         .to(device_type)
             elif qtype == ggml_tensor_qtype["bf16"]:
+                module.to(torch.bfloat16)
                 new_linear = BF16Linear(
                     in_features,
                     out_features,

@@ -344,7 +345,7 @@ def _optimize_pre(model):
 def ggml_convert_low_bit(model, qtype, optimize_model=True,
                          convert_shape_only=False, device="cpu",
                          modules_to_not_convert=None, cpu_embedding=False,
-                         lightweight_bmm=False):
+                         lightweight_bmm=False, torch_dtype="auto"):
     logger.info(f"Converting the current model to "
                 f"{list(ggml_tensor_qtype.keys())[list(ggml_tensor_qtype.values()).index(qtype)]} "
                 f"format......")

@@ -366,7 +367,10 @@ def ggml_convert_low_bit(model, qtype, optimize_model=True,
         )
     elif device == "cpu":
         if not (getattr(model, "quantization_method", None) == "gptq"):
-            model.to(torch.float32)
+            if torch_dtype == "auto":
+                convert_bigdl_other_module(model, torch.float32)
+            else:
+                convert_bigdl_other_module(model, torch_dtype)
     elif device == "meta":
         # Do nothing here for weights are empty.
         pass

@@ -376,6 +380,17 @@ def ggml_convert_low_bit(model, qtype, optimize_model=True,
     return model
 
 
+def convert_bigdl_other_module(model, dtype):
+    # Convert modules outside of bigdl linear to corresponding dtype
+    from bigdl.llm.transformers.low_bit_linear import LowBitLinear, \
+        FP16Linear, BF16Linear
+    for module in model.modules():
+        if list(module.children()) == []:
+            # leaf module
+            if not isinstance(module, (LowBitLinear, FP16Linear, BF16Linear)):
+                module.to(dtype)
+
+
 def convert_forward(m, target_m, new_forward):
     for _, sub_m in m.named_children():
         if isinstance(sub_m, target_m):

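With the new `torch_dtype` argument, modules that are not BigDL low-bit linears are no longer forced to fp32 on CPU; they follow the requested dtype. The sketch below mirrors that leaf-module walk with plain PyTorch so it runs without bigdl-llm; the function name and the `keep_types` parameter are mine, standing in for `convert_bigdl_other_module` and its `(LowBitLinear, FP16Linear, BF16Linear)` check.

```python
import torch
import torch.nn as nn

def cast_non_lowbit_leaves(model: nn.Module, dtype: torch.dtype, keep_types=()):
    """Mirror of the convert_bigdl_other_module logic in this diff: cast every
    leaf module that is not one of the excluded (low-bit) types to `dtype`."""
    for module in model.modules():
        if list(module.children()) == []:  # leaf module
            if not isinstance(module, keep_types):
                module.to(dtype)
    return model

# Toy example: with torch_dtype="auto" the diff uses torch.float32,
# otherwise the dtype passed through from from_pretrained(torch_dtype=...).
toy = nn.Sequential(nn.Embedding(10, 8), nn.LayerNorm(8), nn.Linear(8, 2))
cast_non_lowbit_leaves(toy, torch.bfloat16)
print({name: p.dtype for name, p in toy.named_parameters()})
```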
@@ -599,27 +599,18 @@ class BF16Linear(nn.Linear):
         self.out_len = output_features
         self.weight_shape = (self.out_len, self.in_len)
         self.weight_length = self.out_len * self.in_len
+        self.qtype = ggml_tensor_qtype["bf16"]
         self.mp_group = mp_group
         self.compute_dtype = compute_dtype
 
     def forward(self, x: torch.Tensor):
-        # only work for GPU now
-        invalidInputError(x.device.type == "xpu",
-                          "bf16 only works for GPU now")
-        is_training = self.training and not torch.is_inference_mode_enabled()
-        if is_training:
-            # below logic is only for training
-            autocast_dtype = get_autocast_dtype(x)
-            if self.compute_dtype is not None and x.device.type == "xpu":
-                x = x.to(self.compute_dtype)  # solve GC issue for unlora module
-            elif autocast_dtype is not None:
-                x = x.to(autocast_dtype)
-
+        x = x.to(torch.bfloat16)
+        if self.weight is not None and self.weight.dtype != x.dtype:
+            self.weight.data = self.weight.data.to(x.dtype)
         if self.bias is not None and self.bias.dtype != x.dtype:
            self.bias.data = self.bias.data.to(x.dtype)
 
         result = F.linear(x, self.weight)
         if self.bias is not None:
             result += self.bias
-
         return result.to(x.dtype)

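The rewritten `forward` drops the XPU-only assertion and the training-only autocast branch: the input is cast to bf16, the weight and bias are lazily aligned to that dtype, and a plain `F.linear` does the rest, so the same path serves training and inference on CPU or XPU. Below is a standalone sketch of that behaviour; it is a simplified stand-in, not the BigDL `BF16Linear` class itself.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class Bf16LinearSketch(nn.Linear):
    """Simplified stand-in for the updated BF16Linear.forward in this diff:
    cast the input to bf16, lazily align weight/bias dtype, then plain linear."""
    def forward(self, x: torch.Tensor):
        x = x.to(torch.bfloat16)
        if self.weight is not None and self.weight.dtype != x.dtype:
            self.weight.data = self.weight.data.to(x.dtype)
        if self.bias is not None and self.bias.dtype != x.dtype:
            self.bias.data = self.bias.data.to(x.dtype)
        result = F.linear(x, self.weight)
        if self.bias is not None:
            result += self.bias
        return result.to(x.dtype)

layer = Bf16LinearSketch(16, 4)      # created in fp32
out = layer(torch.randn(2, 16))      # input is cast to bf16 on the fly
print(out.dtype)                     # torch.bfloat16, no XPU requirement
```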
@@ -304,7 +304,8 @@ class _BaseAutoModelClass:
             model = model.to("cpu")
             model = ggml_convert_low_bit(model, qtype, optimize_model,
                                          modules_to_not_convert=modules_to_not_convert,
-                                         cpu_embedding=cpu_embedding, lightweight_bmm=lightweight_bmm)
+                                         cpu_embedding=cpu_embedding, lightweight_bmm=lightweight_bmm,
+                                         torch_dtype=kwargs.get("torch_dtype", 'auto'))
             model.config.update({"bigdl_transformers_low_bit": q_k})
             model.config.update({"tie_word_embeddings": False})
 

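At the user level, `torch_dtype` now flows from `from_pretrained` into the conversion, so a bf16 load keeps the non-linear modules in bf16 instead of upcasting them to fp32. A hedged usage sketch, assuming bigdl-llm is installed and that `load_in_low_bit="bf16"` is the low-bit mode this fix targets; the checkpoint id is only an example.

```python
import torch
from bigdl.llm.transformers import AutoModelForCausalLM

# Sketch of the intended call path after this commit: torch_dtype is picked up
# from kwargs and forwarded to ggml_convert_low_bit, so modules outside the
# BF16Linear layers are converted to bf16 rather than forced to fp32.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",      # example checkpoint
    load_in_low_bit="bf16",
    torch_dtype=torch.bfloat16,
)
```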
@@ -519,6 +519,9 @@ def check_flash_attention_available(query):
         # ipex flash attention is only supported for xetla
         # may update this later
         return False
+    if query.dtype not in [torch.float32, torch.float16]:
+        # only use flash attention for fp32/fp16 input
+        return False
     return True
 

@@ -49,10 +49,12 @@
 # limitations under the License.
 
 import torch
+from torch.nn import Linear, Embedding
 from bigdl.llm.transformers.low_bit_linear import LowBitLinear, BF16Linear, get_qk_size
 from peft.tuners.lora import LoraLayer
 from bigdl.llm.utils.common import invalidInputError
 from bigdl.llm.transformers.utils import get_autocast_dtype
+from bigdl.llm.ggml.quantize import ggml_tensor_qtype
 import functools
 from bigdl.llm.transformers import training_patch
 

@@ -275,7 +277,17 @@ def prepare_model_for_kbit_training(model, use_gradient_checkpointing=True):
 
     if not is_gptq_quantized:
         # cast all non INT8 parameters to fp32
-        for param in model.parameters():
-            if (param.dtype == torch.float16) or (param.dtype == torch.bfloat16):
-                param.data = param.data.to(torch.float32)
+        # for param in model.parameters():
+        #     if (param.dtype == torch.float16) or (param.dtype == torch.bfloat16):
+        #         param.data = param.data.to(torch.float32)
+
+        # change to below way to reduce memory for Linear
+        # otherwise lora finetuning on arc may OOM at this convert
+        for module in model.modules():
+            if list(module.children()) == []:
+                # leaf module
+                if not isinstance(module, (Linear, Embedding)):
+                    for param in module.parameters():
+                        if (param.dtype == torch.float16) or (param.dtype == torch.bfloat16):
+                            param.data = param.data.to(torch.float32)
 

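The commented-out block above is replaced because casting `model.parameters()` wholesale upcasts every Linear and Embedding weight to fp32, which is what made LoRA finetuning OOM on Arc at this conversion step; walking leaf modules upcasts only the small non-Linear parameters. A standalone mirror of that pattern for illustration (the function name is mine, not the BigDL API):

```python
import torch
import torch.nn as nn

def cast_non_linear_params_to_fp32(model: nn.Module):
    """Mirror of the new loop in prepare_model_for_kbit_training: only leaf
    modules that are not Linear/Embedding get their fp16/bf16 params upcast,
    so the large Linear weights stay in their training dtype and peak memory
    during the conversion stays lower."""
    for module in model.modules():
        if list(module.children()) == []:  # leaf module
            if not isinstance(module, (nn.Linear, nn.Embedding)):
                for param in module.parameters():
                    if param.dtype in (torch.float16, torch.bfloat16):
                        param.data = param.data.to(torch.float32)
    return model

# Toy check: LayerNorm is upcast to fp32, Linear/Embedding keep bf16.
toy = nn.Sequential(nn.Embedding(10, 8), nn.LayerNorm(8), nn.Linear(8, 2)).to(torch.bfloat16)
cast_non_linear_params_to_fp32(toy)
print({name: p.dtype for name, p in toy.named_parameters()})
```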