LLM: fix BF16Linear related training & inference issue (#9755)
* fix bf16 related issue
* fix
* update based on comment & add arc lora script
* update readme
* update based on comment
* update based on comment
* update
* force to bf16
* fix style
* move check input dtype into function
* update convert
* meet code review
* meet code review
* update merged model to support new training_mode api
* fix typo
parent 30dab36f76
commit 1917bbe626

8 changed files with 75 additions and 20 deletions
@@ -101,6 +101,12 @@ bash qalora_finetune_llama2_7b_pvc_1550_1_tile.sh

#### LoRA

##### Finetuning LLaMA2-7B on single Arc A770

```bash
bash lora_finetune_llama2_7b_arc_1_card.sh
```

##### Finetuning LLaMA2-7B on four Intel Data Center GPU Max 1100

```bash
@@ -0,0 +1,26 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

# You could also specify `--base_model` to the local path of the huggingface model checkpoint folder and `--data_path` to the local path of the dataset JSON file
python ./alpaca_qlora_finetuning.py \
    --micro_batch_size 8 \
    --batch_size 128 \
    --base_model "meta-llama/Llama-2-7b-hf" \
    --data_path "yahma/alpaca-cleaned" \
    --output_dir "./bigdl-lora-alpaca" \
    --gradient_checkpointing True \
    --lora_target_modules "['k_proj', 'q_proj', 'o_proj', 'v_proj']" \
    --training_mode "lora"
@@ -54,7 +54,8 @@ if __name__ == "__main__":
    tokenizer = LlamaTokenizer.from_pretrained(base_model)

    lora_config = LoraConfig.from_json_file(os.path.join(adapter_path, "adapter_config.json"))
    qa_lora = lora_config.get("qa_lora", False)
    training_mode = lora_config.get("training_mode", "qlora")
    qa_lora = training_mode == "qalora"

    temp_dir = None
    if qa_lora:
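A minimal sketch of the idea behind this hunk: the adapter config now carries a `training_mode` field ("qlora" by default, or "lora"/"qalora"), and QA-LoRA handling is derived from it rather than from a dedicated flag. The helper name and the exact fallback behavior below are my assumptions for illustration, not code from this PR:

```python
# Hypothetical helper mirroring the hunk above; not the repository's code.
import json
import os

def resolve_qa_lora(adapter_path: str) -> bool:
    with open(os.path.join(adapter_path, "adapter_config.json")) as f:
        lora_config = json.load(f)
    # Newer checkpoints record training_mode; older ones may only carry qa_lora.
    training_mode = lora_config.get("training_mode", "qlora")
    return lora_config.get("qa_lora", training_mode == "qalora")
```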
@@ -262,6 +262,7 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
            new_linear._parameters['bias'] = nn.Parameter(module.bias.data)\
                .to(device_type)
        elif qtype == ggml_tensor_qtype["bf16"]:
            module.to(torch.bfloat16)
            new_linear = BF16Linear(
                in_features,
                out_features,
@@ -344,7 +345,7 @@ def _optimize_pre(model):
def ggml_convert_low_bit(model, qtype, optimize_model=True,
                         convert_shape_only=False, device="cpu",
                         modules_to_not_convert=None, cpu_embedding=False,
                         lightweight_bmm=False):
                         lightweight_bmm=False, torch_dtype="auto"):
    logger.info(f"Converting the current model to "
                f"{list(ggml_tensor_qtype.keys())[list(ggml_tensor_qtype.values()).index(qtype)]} "
                f"format......")
@@ -366,7 +367,10 @@ def ggml_convert_low_bit(model, qtype, optimize_model=True,
        )
    elif device == "cpu":
        if not (getattr(model, "quantization_method", None) == "gptq"):
            model.to(torch.float32)
            if torch_dtype == "auto":
                convert_bigdl_other_module(model, torch.float32)
            else:
                convert_bigdl_other_module(model, torch_dtype)
    elif device == "meta":
        # Do nothing here for weights are empty.
        pass
@@ -376,6 +380,17 @@ def ggml_convert_low_bit(model, qtype, optimize_model=True,
    return model


def convert_bigdl_other_module(model, dtype):
    # Convert modules outside of bigdl linear to corresponding dtype
    from bigdl.llm.transformers.low_bit_linear import LowBitLinear, \
        FP16Linear, BF16Linear
    for module in model.modules():
        if list(module.children()) == []:
            # leaf module
            if not isinstance(module, (LowBitLinear, FP16Linear, BF16Linear)):
                module.to(dtype)


def convert_forward(m, target_m, new_forward):
    for _, sub_m in m.named_children():
        if isinstance(sub_m, target_m):
@@ -599,27 +599,18 @@ class BF16Linear(nn.Linear):
        self.out_len = output_features
        self.weight_shape = (self.out_len, self.in_len)
        self.weight_length = self.out_len * self.in_len
        self.qtype = ggml_tensor_qtype["bf16"]
        self.mp_group = mp_group
        self.compute_dtype = compute_dtype

    def forward(self, x: torch.Tensor):
        # only work for GPU now
        invalidInputError(x.device.type == "xpu",
                          "bf16 only works for GPU now")
        is_training = self.training and not torch.is_inference_mode_enabled()
        if is_training:
            # below logic is only for training
            autocast_dtype = get_autocast_dtype(x)
            if self.compute_dtype is not None and x.device.type == "xpu":
                x = x.to(self.compute_dtype)  # solve GC issue for unlora module
            elif autocast_dtype is not None:
                x = x.to(autocast_dtype)

        x = x.to(torch.bfloat16)
        if self.weight is not None and self.weight.dtype != x.dtype:
            self.weight.data = self.weight.data.to(x.dtype)
        if self.bias is not None and self.bias.dtype != x.dtype:
            self.bias.data = self.bias.data.to(x.dtype)

        result = F.linear(x, self.weight)
        if self.bias is not None:
            result += self.bias

        return result.to(x.dtype)
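Numerically, the forward path above reduces to casting the activation, weight, and bias to bfloat16 before the matmul. A minimal standalone sketch of that core (illustrative only: it drops the XPU check, the autocast handling, and the in-place weight cast that the real layer performs):

```python
import torch
import torch.nn.functional as F

def bf16_linear_forward(x: torch.Tensor, weight: torch.Tensor,
                        bias: torch.Tensor = None) -> torch.Tensor:
    # Cast everything to bfloat16 so the matmul itself runs in bf16.
    x = x.to(torch.bfloat16)
    weight = weight.to(x.dtype)
    result = F.linear(x, weight)
    if bias is not None:
        result = result + bias.to(x.dtype)
    return result

# Quick check on CPU (the real BF16Linear currently requires an XPU device).
out = bf16_linear_forward(torch.randn(2, 4), torch.randn(3, 4), torch.randn(3))
print(out.dtype)  # torch.bfloat16
```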
@@ -304,7 +304,8 @@ class _BaseAutoModelClass:
        model = model.to("cpu")
        model = ggml_convert_low_bit(model, qtype, optimize_model,
                                     modules_to_not_convert=modules_to_not_convert,
                                     cpu_embedding=cpu_embedding, lightweight_bmm=lightweight_bmm)
                                     cpu_embedding=cpu_embedding, lightweight_bmm=lightweight_bmm,
                                     torch_dtype=kwargs.get("torch_dtype", 'auto'))
        model.config.update({"bigdl_transformers_low_bit": q_k})
        model.config.update({"tie_word_embeddings": False})
@@ -519,6 +519,9 @@ def check_flash_attention_available(query):
        # ipex flash attention is only supported for xetla
        # may update this later
        return False
    if query.dtype not in [torch.float32, torch.float16]:
        # only use flash attention for fp32/fp16 input
        return False
    return True
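The added dtype guard matters because the BF16Linear path above produces bfloat16 activations, and the comment in the hunk states that flash attention is only used for fp32/fp16 input. A small standalone illustration of that guard (the function name below is made up; the real check lives inside `check_flash_attention_available`):

```python
import torch

def dtype_supports_flash_attention(query: torch.Tensor) -> bool:
    # Mirrors the check added above: only fp32/fp16 inputs may use flash attention.
    return query.dtype in [torch.float32, torch.float16]

print(dtype_supports_flash_attention(torch.zeros(1, dtype=torch.float16)))   # True
print(dtype_supports_flash_attention(torch.zeros(1, dtype=torch.bfloat16)))  # False
```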
@@ -49,10 +49,12 @@
# limitations under the License.

import torch
from torch.nn import Linear, Embedding
from bigdl.llm.transformers.low_bit_linear import LowBitLinear, BF16Linear, get_qk_size
from peft.tuners.lora import LoraLayer
from bigdl.llm.utils.common import invalidInputError
from bigdl.llm.transformers.utils import get_autocast_dtype
from bigdl.llm.ggml.quantize import ggml_tensor_qtype
import functools
from bigdl.llm.transformers import training_patch
@@ -275,9 +277,19 @@ def prepare_model_for_kbit_training(model, use_gradient_checkpointing=True):

    if not is_gptq_quantized:
        # cast all non INT8 parameters to fp32
        for param in model.parameters():
            if (param.dtype == torch.float16) or (param.dtype == torch.bfloat16):
                param.data = param.data.to(torch.float32)
        # for param in model.parameters():
        #     if (param.dtype == torch.float16) or (param.dtype == torch.bfloat16):
        #         param.data = param.data.to(torch.float32)

        # change to below way to reduce memory for Linear
        # otherwise lora finetuning on arc may OOM at this convert
        for module in model.modules():
            if list(module.children()) == []:
                # leaf module
                if not isinstance(module, (Linear, Embedding)):
                    for param in module.parameters():
                        if (param.dtype == torch.float16) or (param.dtype == torch.bfloat16):
                            param.data = param.data.to(torch.float32)

    if use_gradient_checkpointing:
        # For backward compatibility
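To make the memory reasoning behind this hunk concrete (only non-Linear/Embedding leaves are upcast to fp32, so the large Linear weights keep their low-precision dtype and the Arc OOM mentioned in the comment is avoided), here is a self-contained toy check using only stock torch modules; it is not code from the PR:

```python
import torch
from torch import nn

# Toy model in bf16: a Linear (stands in for the large weights) plus a small LayerNorm leaf.
model = nn.Sequential(nn.Linear(8, 8), nn.LayerNorm(8)).to(torch.bfloat16)

# Same filtering pattern as the hunk above: upcast only non-Linear/Embedding leaves.
for module in model.modules():
    if list(module.children()) == []:  # leaf module
        if not isinstance(module, (nn.Linear, nn.Embedding)):
            for param in module.parameters():
                if param.dtype in (torch.float16, torch.bfloat16):
                    param.data = param.data.to(torch.float32)

print(model[0].weight.dtype)  # torch.bfloat16 (Linear weights are left alone)
print(model[1].weight.dtype)  # torch.float32 (small leaves are upcast for stability)
```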