Support deepspeed AutoTP (#9230)

* Support deepspeed

* add test script

* refactor convert

* refine example

* refine

* refine example

* fix style

* refine example and adapt to latest ipex

* fix style
Yang Wang 2023-10-25 14:46:28 +08:00 committed by GitHub
parent a6a8afc47e
commit 067c7e8098
5 changed files with 190 additions and 9 deletions


@@ -0,0 +1,103 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import torch
import transformers
import deepspeed

local_rank = int(os.getenv("LOCAL_RANK", "0"))
world_size = int(os.getenv("WORLD_SIZE", "1"))

from bigdl.llm import optimize_model
import intel_extension_for_pytorch as ipex
import time
import argparse
# export AutoModelForCausalLM from transformers so that DeepSpeed uses it
from transformers import AutoModelForCausalLM
from transformers import LlamaTokenizer, AutoTokenizer

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Predict Tokens using `generate()` API for Llama2 model')
    parser.add_argument('--repo-id-or-model-path', type=str, default="meta-llama/Llama-2-7b-chat-hf",
                        help='The huggingface repo id for the Llama2 model (e.g. `meta-llama/Llama-2-7b-chat-hf` and `meta-llama/Llama-2-13b-chat-hf`) to be downloaded'
                             ', or the path to the huggingface checkpoint folder')
    parser.add_argument('--prompt', type=str, default="Once upon a time, there existed a little girl who liked to have adventures. She wanted to go to places and meet new people, and have fun",
                        help='Prompt to infer')
    parser.add_argument('--n-predict', type=int, default=32,
                        help='Max tokens to predict')

    args = parser.parse_args()
    model_path = args.repo_id_or_model_path

    model = AutoModelForCausalLM.from_pretrained(args.repo_id_or_model_path,
                                                 low_cpu_mem_usage=True,
                                                 torch_dtype=torch.float16,
                                                 trust_remote_code=True,
                                                 use_cache=True)

    model = deepspeed.init_inference(
        model,
        mp_size=world_size,
        dtype=torch.float16,
        replace_method="auto",
    )

    # Move the model to CPU and use bigdl-llm `optimize_model` to convert it
    # into an optimized low-bit format; convert the rest of the model to
    # float16 to reduce allreduce traffic.
    model = optimize_model(model.module.to('cpu'), low_bit='sym_int4').to(torch.float16)

    # Move the model back to XPU
    model = model.to(f'xpu:{local_rank}')

    print(model)

    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

    # Generate predicted tokens
    with torch.inference_mode():
        prompt = args.prompt
        input_ids = tokenizer.encode(prompt, return_tensors="pt").to(f'xpu:{local_rank}')
        # The ipex model needs a warmup run before inference time can be measured accurately
        output = model.generate(input_ids,
                                max_new_tokens=args.n_predict,
                                use_cache=True)

        # start inference
        st = time.time()
        # If your selected model is capable of utilizing previous key/value attentions
        # to enhance decoding speed, but has `"use_cache": false` in its model config,
        # it is important to set `use_cache=True` explicitly in the `generate` function
        # to obtain optimal performance with BigDL-LLM INT4 optimizations
        output = model.generate(input_ids,
                                do_sample=False,
                                max_new_tokens=args.n_predict)
        torch.xpu.synchronize()
        end = time.time()

        if local_rank == 0:
            output = output.cpu()
            output_str = tokenizer.decode(output[0], skip_special_tokens=True)
            print(f'Inference time: {end-st} s')
            print('-'*20, 'Prompt', '-'*20)
            print(prompt)
            print('-'*20, 'Output', '-'*20)
            print(output_str)


@@ -0,0 +1,12 @@
source bigdl-llm-init -t -g
export MASTER_ADDR=127.0.0.1
export CCL_ZE_IPC_EXCHANGE=sockets
if [[ -n $OMP_NUM_THREADS ]]; then
  export OMP_NUM_THREADS=$(($OMP_NUM_THREADS / 4))
else
  export OMP_NUM_THREADS=$(($(nproc) / 4))
fi
torchrun --standalone \
         --nnodes=1 \
         --nproc-per-node 4 \
         deepspeed_autotp.py --repo-id-or-model-path "meta-llama/Llama-2-7b-hf"


@@ -45,6 +45,42 @@ from bigdl.llm.ggml.quantize import ggml_tensor_qtype
 from .utils import logger
 
 
+def is_deepspeed_available():
+    return importlib.util.find_spec("deepspeed") is not None
+
+
+def is_linear_module(module):
+
+    in_features = None
+    out_features = None
+    mp_group = None
+
+    if isinstance(module, nn.Linear):
+        in_features = module.in_features
+        out_features = module.out_features
+        mp_group = None
+        result = True
+    else:
+        if is_deepspeed_available():
+            from deepspeed.module_inject.layers import LinearLayer, LinearAllreduce
+            if isinstance(module, LinearLayer):
+                in_features = module.weight.shape[1]
+                out_features = module.weight.shape[0]
+                mp_group = None
+                result = True
+            elif isinstance(module, LinearAllreduce):
+                in_features = module.weight.shape[1]
+                out_features = module.weight.shape[0]
+                mp_group = module.mp_group
+                result = True
+            else:
+                result = False
+        else:
+            result = False
+
+    return result, (in_features, out_features, mp_group)
+
+
 def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                                  current_key_name=None, convert_shape_only=False):
     from bigdl.llm.transformers.low_bit_linear import LowBitLinear, FP4Params, FP16Linear
@@ -54,17 +90,20 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
         if current_key_name is None:
             current_key_name = []
 
-        if isinstance(module, nn.Linear) and name not in modules_to_not_convert:
+        is_linear, linear_args = is_linear_module(module)
+        if is_linear and name not in modules_to_not_convert:
             # Check if the current key is not in the `modules_to_not_convert`
             if not any(key in ".".join(current_key_name) for key in modules_to_not_convert):
+                in_features, out_features, mp_group = linear_args
                 with init_empty_weights():
                     new_linear = None
                     if qtype != ggml_tensor_qtype["fp16"]:
                         new_linear = LowBitLinear(
-                            module.in_features,
-                            module.out_features,
+                            in_features,
+                            out_features,
                             qtype,
                             module.bias is not None,
+                            mp_group=mp_group,
                         )
 
                         device_type = module.weight.data.device.type
@@ -82,10 +121,11 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                         if module.in_features in [4096, 11008]:
                             # esimd fp16 path
                             new_linear = FP16Linear(
-                                module.in_features,
-                                module.out_features,
+                                in_features,
+                                out_features,
                                 qtype,
                                 module.bias is not None,
+                                mp_group=mp_group,
                             )
                             device_type = module.weight.data.device.type
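
Note: the diff view above does not name the modified file. Assuming the new helper lands next to `_replace_with_low_bit_linear` (e.g. importable as `bigdl.llm.transformers.convert.is_linear_module`, an assumed path), a minimal sketch of how it behaves for a plain `nn.Linear` versus a DeepSpeed-sharded layer:

# Illustrative sketch only; the import path below is an assumption, since the
# diff view does not show the file name of the modified module.
import torch.nn as nn
from bigdl.llm.transformers.convert import is_linear_module

layer = nn.Linear(4096, 11008, bias=False)
is_linear, (in_features, out_features, mp_group) = is_linear_module(layer)

# A plain nn.Linear reports its own shapes and carries no tensor-parallel
# group, so no all-reduce is attached to the converted layer.
assert is_linear and (in_features, out_features, mp_group) == (4096, 11008, None)

# A DeepSpeed AutoTP LinearAllreduce instead reports module.weight.shape and
# passes module.mp_group through, which LowBitLinear/FP16Linear later use for
# dist.inference_all_reduce in forward().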


@@ -328,7 +328,7 @@ class MatMulLowBit(torch.autograd.Function):
 
 class LowBitLinear(nn.Linear):
     def __init__(self, input_features, output_features, qtype, bias=True,
-                 conver_to_half=True):
+                 conver_to_half=True, mp_group=None):
         super().__init__(input_features, output_features, bias)
         self.weight = FP4Params(self.weight.data,
                                 requires_grad=False,
@@ -339,6 +339,7 @@ class LowBitLinear(nn.Linear):
         self.weight_length = self.out_len * self.in_len
         self.qtype = qtype
         self.conver_to_half = conver_to_half
+        self.mp_group = mp_group
 
     def forward(self, x: torch.Tensor):
         if self.bias is not None and self.bias.dtype != x.dtype:
@@ -378,6 +379,9 @@ class LowBitLinear(nn.Linear):
                                               input_seq_size)
             new_shape = x_shape[:-1] + (self.out_len,)
             result = result.view(new_shape)
+            if self.mp_group is not None:
+                from deepspeed import comm as dist
+                dist.inference_all_reduce(result, group=self.mp_group)
             if self.bias is not None:
                 result += self.bias
         else:
@@ -400,7 +404,7 @@ class LowBitLinear(nn.Linear):
 
 class FP16Linear(nn.Linear):
     def __init__(self, input_features, output_features, qtype, bias=True,
-                 conver_to_half=True):
+                 conver_to_half=True, mp_group=None):
         super().__init__(input_features, output_features, bias)
         self.in_len = input_features
         self.out_len = output_features
@@ -408,6 +412,7 @@ class FP16Linear(nn.Linear):
         self.weight_length = self.out_len * self.in_len
         self.qtype = qtype
         self.conver_to_half = conver_to_half
+        self.mp_group = mp_group
 
     def forward(self, x: torch.Tensor):
         if self.bias is not None and self.bias.dtype != x.dtype:
@@ -442,6 +447,9 @@ class FP16Linear(nn.Linear):
             new_shape = x_shape[:-1] + (self.out_len,)
             result = result.view(new_shape)
+            if self.mp_group is not None:
+                from deepspeed import comm as dist
+                dist.inference_all_reduce(result, group=self.mp_group)
             if self.bias is not None:
                 result += self.bias
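
Why the new `mp_group is not None` branch calls `dist.inference_all_reduce` before the bias is added: under AutoTP each rank holds only a slice of the weight, so its local matmul produces a partial sum. A small single-process sketch (illustrative only, not part of this commit) of the row-parallel arithmetic:

# Two "ranks" each hold half of the input features and the matching weight
# columns; their partial outputs must be summed to recover the full result.
# In the real layers this sum is dist.inference_all_reduce(result, group=mp_group).
import torch

torch.manual_seed(0)
x = torch.randn(1, 8)        # full activation, in_features = 8
w = torch.randn(4, 8)        # full weight, out_features = 4

x0, x1 = x[:, :4], x[:, 4:]  # rank-local activation shards
w0, w1 = w[:, :4], w[:, 4:]  # rank-local weight shards

partial0 = x0 @ w0.t()       # computed on "rank 0"
partial1 = x1 @ w1.t()       # computed on "rank 1"

reduced = partial0 + partial1            # stand-in for the all-reduce
assert torch.allclose(reduced, x @ w.t(), atol=1e-5)

Adding the bias only after the reduction, as both forward paths above do, keeps it from being summed once per rank.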


@@ -32,6 +32,7 @@
 # limitations under the License.
 
 import torch
+import importlib
 import torch.nn as nn
 from typing import Optional, Tuple
 import math
@@ -58,10 +59,27 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
 KV_CACHE_ALLOC_BLOCK_LENGTH = 256
 
 
+def get_ipex_version():
+    if importlib.util.find_spec("intel_extension_for_pytorch") is not None:
+        import intel_extension_for_pytorch as ipex
+        return ipex.__version__
+    else:
+        return None
+
+
+ipex_version = get_ipex_version()
+
+
 def llama_rms_norm_forward(self, hidden_states):
     if hidden_states.device.type == "xpu" and not (self.training and hidden_states.requires_grad):
-        hidden_states, _ = torch.ops.torch_ipex.rms_norm(hidden_states,
-                                                         [self.weight.size(0)], self.weight)
+        if ipex_version == "2.0.110+xpu":
+            hidden_states, _ = torch.ops.torch_ipex.rms_norm(hidden_states,
+                                                             [self.weight.size(0)], self.weight)
+        else:
+            hidden_states, _ = torch.ops.torch_ipex.rms_norm(hidden_states,
+                                                             [self.weight.size(0)], self.weight,
+                                                             self.variance_epsilon)
     else:
         input_dtype = hidden_states.dtype
         hidden_states = hidden_states.to(torch.float32)
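
An aside on the version gate: `ipex_version == "2.0.110+xpu"` matches only that exact build string. A hedged sketch (not part of this commit, and assuming builds newer than 2.0.110 are the ones that accept the extra epsilon argument) of a range-based check using `packaging`:

# Illustrative alternative only; the behavior split is an assumption read off
# the branch above: 2.0.110+xpu calls rms_norm without an explicit epsilon,
# later builds pass self.variance_epsilon as a third argument.
from packaging import version

def rms_norm_takes_eps(ipex_ver: str) -> bool:
    return version.parse(ipex_ver) > version.parse("2.0.110+xpu")

print(rms_norm_takes_eps("2.0.110+xpu"))  # False -> call without epsilon
print(rms_norm_takes_eps("2.1.10+xpu"))   # True  -> pass variance_epsilon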